1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef NEURAL_NETWORK_RUNTIME_LSTM_BUILDER_H
17 #define NEURAL_NETWORK_RUNTIME_LSTM_BUILDER_H
18 
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <vector>

#include "mindir.h"

#include "ops_builder.h"
#include "ops_registry.h"
23 
24 namespace OHOS {
25 namespace NeuralNetworkRuntime {
26 namespace Ops {
27 class LSTMBuilder : public OpsBuilder {
28 public:
29     typedef OH_NN_ReturnCode (LSTMBuilder::*FuncPtr)(const std::shared_ptr<NNTensor>&);
30 
31     LSTMBuilder();
32     ~LSTMBuilder() override;
33     OH_NN_ReturnCode Build(const std::vector<uint32_t>& paramsIndex,
34                            const std::vector<uint32_t>& inputsIndex,
35                            const std::vector<uint32_t>& outputsIndex,
36                            const std::vector<std::shared_ptr<NNTensor>>& allTensors) override;
37     LiteGraphPrimitvePtr GetPrimitive() override;
38 
39 private:
40     OH_NN_ReturnCode SetBidirectional(const std::shared_ptr<NNTensor>& tensor);
41     OH_NN_ReturnCode SetHasBias(const std::shared_ptr<NNTensor>& tensor);
42     OH_NN_ReturnCode SetInputSize(const std::shared_ptr<NNTensor>& tensor);
43     OH_NN_ReturnCode SetHiddenSize(const std::shared_ptr<NNTensor>& tensor);
44     OH_NN_ReturnCode SetNumLayers(const std::shared_ptr<NNTensor>& tensor);
45     OH_NN_ReturnCode SetNumDirections(const std::shared_ptr<NNTensor>& tensor);
46     OH_NN_ReturnCode SetDropout(const std::shared_ptr<NNTensor>& tensor);
47     OH_NN_ReturnCode SetZoneoutCell(const std::shared_ptr<NNTensor>& tensor);
48     OH_NN_ReturnCode SetZoneoutHidden(const std::shared_ptr<NNTensor>& tensor);
49     OH_NN_ReturnCode SetProjSize(const std::shared_ptr<NNTensor>& tensor);
50     OH_NN_ReturnCode ParseParam(const std::vector<uint32_t>& paramsIndex,
51                                 const std::vector<std::shared_ptr<NNTensor>>& allTensors);
52 
53 private:
54     bool m_bidirectional {false};
55     bool m_hasBias {false};
56     int64_t m_inputSize {0};
57     int64_t m_hiddenSize {0};
58     int64_t m_numLayers {0};
59     int64_t m_numDirections {0};
60     float m_dropout {0.0f};
61     float m_zoneoutCell {0.0f};
62     float m_zoneoutHidden {0.0f};
63     int64_t m_projSize {0};
64     std::unordered_map<OH_NN_TensorType, FuncPtr> m_paramMap = {
65         {OH_NN_LSTM_BIDIRECTIONAL, &LSTMBuilder::SetBidirectional},
66         {OH_NN_LSTM_HAS_BIAS, &LSTMBuilder::SetHasBias},
67         {OH_NN_LSTM_INPUT_SIZE, &LSTMBuilder::SetInputSize},
68         {OH_NN_LSTM_HIDDEN_SIZE, &LSTMBuilder::SetHiddenSize},
69         {OH_NN_LSTM_NUM_LAYERS, &LSTMBuilder::SetNumLayers},
70         {OH_NN_LSTM_NUM_DIRECTIONS, &LSTMBuilder::SetNumDirections},
71         {OH_NN_LSTM_DROPOUT, &LSTMBuilder::SetDropout},
72         {OH_NN_LSTM_ZONEOUT_CELL, &LSTMBuilder::SetZoneoutCell},
73         {OH_NN_LSTM_ZONEOUT_HIDDEN, &LSTMBuilder::SetZoneoutHidden},
74         {OH_NN_LSTM_PROJ_SIZE, &LSTMBuilder::SetProjSize}
75     };
76 };
77 } // namespace Ops
78 } // namespace NeuralNetworkRuntime
79 } // namespace OHOS
80 
81 #endif // NEURAL_NETWORK_RUNTIME_LSTM_BUILDER_H