/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include "LSTM.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace unidirectional_sequence_lstm {

// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input} if time-major, or
// {n_batch, max_time, n_input} otherwise (see kTimeMajorParam).
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gate bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kNumOutputsWithState = 3;

constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kOutputStateOutTensor = 1;
constexpr uint32_t kCellStateOutTensor = 2;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

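// Optional operands omitted by the caller have no backing buffer, so a null
// input buffer means the tensor is absent.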
inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

inline bool isTimeMajor(IOperationExecutionContext* context) {
    return context->getInputValue<bool>(kTimeMajorParam);
}

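// The LSTM variant is inferred from which optional tensors are present: CIFG
// when the input-to-input weights are absent, peephole connections when the
// cell-to-output weights are present, layer normalization when the output
// layer norm weights are present, and projection when the projection weights
// are present.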
template <typename T>
inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
    LSTMParams params;
    params.activation =
            static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
    params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
    params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
    params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
    params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
    params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
    params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
    return params;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    const uint32_t numOutputs = context->getNumOutputs();
    NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    const OperandType inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
        outExpectedTypes = {OperandType::TENSOR_FLOAT32};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
        outExpectedTypes = {OperandType::TENSOR_FLOAT16};
    } else {
        NN_RET_CHECK_FAIL()
                << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
                << inputType;
    }
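    // The base single-output form of the operation was introduced in
    // Android Q; the variant that also returns the output state and cell
    // state requires Android R.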
    Version minVersionSupported = Version::ANDROID_Q;
    if (context->getNumOutputs() == kNumOutputsWithState) {
        minVersionSupported = Version::ANDROID_R;
        outExpectedTypes.insert(outExpectedTypes.end(), {inputType, inputType});
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
    return minVersionSupported;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;

    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

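    // n_cell and n_output are inferred from the input-to-output and
    // recurrent-to-output weight shapes, which every model must provide.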
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input gate's parameters are either both present (regular
    // LSTM) or both absent (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Make sure the peephole weights are all present or all absent. With CIFG,
    // there is no input gate, so the cell-to-input weights may be omitted.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

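    // With CIFG there is no input gate, so the input layer norm weights must
    // be absent; the remaining three layer norm weights must be used together
    // or not at all.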
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

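    // The output keeps the input's {time, batch} layout; only the innermost
    // dimension changes from n_input to n_output.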
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
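    // The scratch buffer holds one intermediate activation vector per gate:
    // three gates when CIFG couples the input gate to the forget gate, four
    // otherwise.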
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace unidirectional_sequence_lstm

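// allowOmittedOperand is set because many of this operation's operands are
// optional and may legitimately be omitted by the caller.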
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android