/*
 * Copyright (c) 2024 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "lstm_builder.h"

namespace OHOS {
namespace NeuralNetworkRuntime {
namespace Ops {
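// Expected numbers of input and output tensors, the maximum number of
// parameter tensors, and the element count required of a scalar parameter.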
static const int INPUT_NUM = 6;
static const int OUTPUT_NUM = 3;
static const int PARAM_MAX_NUM = 10;
static const int SCALAR_LENGTH = 1;
static const std::string OP_NAME = "LSTM";

LSTMBuilder::LSTMBuilder() {}

LSTMBuilder::~LSTMBuilder() {}

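// Each Set* method below follows the same validation pattern: check the
// parameter tensor's data type, require a single-element (scalar) tensor,
// reject a null buffer, and then cache the value in the matching member.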
OH_NN_ReturnCode LSTMBuilder::SetBidirectional(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_BOOL) {
        LOGE("[LSTM] The bidirectional should be type OH_NN_BOOL.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The bidirectional should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_bidirectional = *(static_cast<bool*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetHasBias(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_BOOL) {
        LOGE("[LSTM] The hasBias should be type OH_NN_BOOL.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The hasBias should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_hasBias = *(static_cast<bool*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetInputSize(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[LSTM] The inputSize should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The inputSize should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_inputSize = *(static_cast<const int64_t*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetHiddenSize(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[LSTM] The hiddenSize should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The hiddenSize should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_hiddenSize = *(static_cast<const int64_t*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetNumLayers(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[LSTM] The numLayers should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The numLayers should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_numLayers = *(static_cast<const int64_t*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetNumDirections(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[LSTM] The numDirections should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The numDirections should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_numDirections = *(static_cast<const int64_t*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetDropout(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_FLOAT32) {
        LOGE("[LSTM] The dropout should be type OH_NN_FLOAT32.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The dropout should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_dropout = *(static_cast<const float*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetZoneoutCell(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_FLOAT32) {
        LOGE("[LSTM] The zoneoutCell should be type OH_NN_FLOAT32.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The zoneoutCell should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_zoneoutCell = *(static_cast<const float*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetZoneoutHidden(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_FLOAT32) {
        LOGE("[LSTM] The zoneoutHidden should be type OH_NN_FLOAT32.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The zoneoutHidden should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_zoneoutHidden = *(static_cast<const float*>(buffer));

    return OH_NN_SUCCESS;
}

OH_NN_ReturnCode LSTMBuilder::SetProjSize(const std::shared_ptr<NNTensor>& tensor)
{
    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[LSTM] The projSize should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    if (tensor->GetElementCount() != SCALAR_LENGTH) {
        LOGE("[LSTM] The projSize should be scalar.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[LSTM] Tensor buffer is nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_projSize = *(static_cast<const int64_t*>(buffer));

    return OH_NN_SUCCESS;
}

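// Dispatches every parameter tensor to its Set* handler through m_paramMap,
// keyed by the tensor's parameter type; unknown parameter types fail the build.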
OH_NN_ReturnCode LSTMBuilder::ParseParam(const std::vector<uint32_t>& paramsIndex,
                                         const std::vector<std::shared_ptr<NNTensor>>& allTensors)
{
    OH_NN_ReturnCode returnCode;
    for (int i : paramsIndex) {
        std::shared_ptr<NNTensor> tensor = allTensors[i];
        tensor->IdentifyOpParameter();
        if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
            returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
        } else {
            LOGE("[LSTM] Build failed, param invalid, type=%d", tensor->GetType());
            return OH_NN_INVALID_PARAMETER;
        }

        if (returnCode != OH_NN_SUCCESS) {
            LOGE("[LSTM] Build failed, passed invalid param.");
            return returnCode;
        }
    }
    return OH_NN_SUCCESS;
}

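// Validates the input, output and parameter indices, parses all parameter
// tensors, and marks the builder as built. A builder can only be built once.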
OH_NN_ReturnCode LSTMBuilder::Build(const std::vector<uint32_t>& paramsIndex,
                                    const std::vector<uint32_t>& inputsIndex,
                                    const std::vector<uint32_t>& outputsIndex,
                                    const std::vector<std::shared_ptr<NNTensor>>& allTensors)
{
    if (m_isBuild) {
        LOGE("[LSTM] Build failed, the LSTM operation has already been built; it cannot be built again.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    auto ret = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[LSTM] Build failed, passed invalid input or output index.");
        return ret;
    }

    m_inputsIndex = inputsIndex;
    m_outputsIndex = outputsIndex;

    ret = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[LSTM] Build failed, passed invalid param index.");
        return ret;
    }

    ret = ParseParam(paramsIndex, allTensors);
    if (ret != OH_NN_SUCCESS) {
        LOGE("[LSTM] ParseParam failed, passed invalid param.");
        return ret;
    }

    m_name = OP_NAME;
    m_isBuild = true;
    return OH_NN_SUCCESS;
}

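// Creates the MindSpore Lite LSTM primitive from the cached attributes;
// valid only after a successful Build().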
LiteGraphPrimitvePtr LSTMBuilder::GetPrimitive()
{
    if (!m_isBuild) {
        LOGE("[LSTM] GetPrimitive failed, cannot get primitive before calling Build.");
        return {nullptr, DestroyLiteGraphPrimitive};
    }

    void* primitive = mindspore::lite::MindIR_LSTM_CreatePrimitive(m_bidirectional, m_hasBias, m_inputSize,
        m_hiddenSize, m_numLayers, m_numDirections, m_dropout, m_zoneoutCell, m_zoneoutHidden, m_projSize);
    LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
    return graphPrimitivePtr;
}

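// Register the builder so the runtime can create it for the OH_NN_OPS_LSTM op type.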
REGISTER_OPS(LSTMBuilder, OH_NN_OPS_LSTM);
} // namespace Ops
} // namespace NeuralNetworkRuntime
} // namespace OHOS