1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <vector>
20
21 #include "IndexedShapeWrapper.h"
22 #include "OperationResolver.h"
23 #include "OperationsUtils.h"
24
25 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
26 #include "LSTM.h"
27 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
28
29 namespace android {
30 namespace nn {
31 namespace unidirectional_sequence_lstm {
32
// Inputs
constexpr uint32_t kNumInputs = 28;

// Input tensor of size {max_time, n_batch, n_input}
constexpr uint32_t kInputTensor = 0;

// Input weight tensors of size: {n_cell, n_input}
// The input-to-input weights are absent when CIFG (coupled input/forget gate)
// is used; their absence is what selects the CIFG variant at execution time.
constexpr uint32_t kInputToInputWeightsTensor = 1;  // Optional
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size {n_cell, n_output}
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;  // Optional
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kCellToInputWeightsTensor = 9;    // Optional
constexpr uint32_t kCellToForgetWeightsTensor = 10;  // Optional
constexpr uint32_t kCellToOutputWeightsTensor = 11;  // Optional

// Gates bias tensors of size {n_cell}
constexpr uint32_t kInputGateBiasTensor = 12;  // Optional
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size {n_output, n_cell}
constexpr uint32_t kProjectionWeightsTensor = 16;  // Optional
// Projection bias tensor of size {n_output}
constexpr uint32_t kProjectionBiasTensor = 17;  // Optional

// Input from the output of the previous step, tensor of size {batch_size, n_output}
constexpr uint32_t kOutputStateInTensor = 18;
// Input from the cell state of the previous step, tensor of size {batch_size, n_cell}
constexpr uint32_t kCellStateInTensor = 19;

// Scalar parameters: fused activation, cell/projection clipping thresholds,
// and the flag selecting time-major vs batch-major input layout.
constexpr uint32_t kActivationParam = 20;
constexpr uint32_t kCellClipParam = 21;
constexpr uint32_t kProjClipParam = 22;
constexpr uint32_t kTimeMajorParam = 23;

// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
constexpr uint32_t kInputLayerNormWeightsTensor = 24;   // Optional
constexpr uint32_t kForgetLayerNormWeightsTensor = 25;  // Optional
constexpr uint32_t kCellLayerNormWeightsTensor = 26;    // Optional
constexpr uint32_t kOutputLayerNormWeightsTensor = 27;  // Optional

// Output tensors.
constexpr uint32_t kNumOutputs = 1;
// With three outputs the op additionally returns the final output state and
// cell state (supported since Android R).
constexpr uint32_t kNumOutputsWithState = 3;

constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kOutputStateOutTensor = 1;
constexpr uint32_t kCellStateOutTensor = 2;
90
91 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
92 namespace {
93
hasTensor(IOperationExecutionContext * context,const uint32_t tensor)94 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
95 return context->getInputBuffer(tensor) != nullptr;
96 }
97
isTimeMajor(IOperationExecutionContext * context)98 inline bool isTimeMajor(IOperationExecutionContext* context) {
99 return context->getInputValue<bool>(kTimeMajorParam);
100 }
101
102 template <typename T>
getLSTMParams(IOperationExecutionContext * context)103 inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
104 LSTMParams params;
105 params.activation =
106 static_cast<TfLiteFusedActivation>(context->getInputValue<int32_t>(kActivationParam));
107 params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
108 params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
109 params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
110 params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
111 params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
112 params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
113 params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
114 return params;
115 }
116
117 } // namespace
118 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
119
validate(const IOperationValidationContext * context)120 Result<Version> validate(const IOperationValidationContext* context) {
121 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
122 const uint32_t numOutputs = context->getNumOutputs();
123 NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
124 const OperandType inputType = context->getInputType(kInputTensor);
125 std::vector<OperandType> inExpectedTypes;
126 std::vector<OperandType> outExpectedTypes;
127 if (inputType == OperandType::TENSOR_FLOAT32) {
128 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
129 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
130 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
131 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
132 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
133 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
134 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
135 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
136 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
137 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
138 OperandType::INT32, OperandType::FLOAT32,
139 OperandType::FLOAT32, OperandType::BOOL,
140 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
141 OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32};
142 outExpectedTypes = {OperandType::TENSOR_FLOAT32};
143 } else if (inputType == OperandType::TENSOR_FLOAT16) {
144 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
145 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
146 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
147 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
148 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
149 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
150 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
151 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
152 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
153 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
154 OperandType::INT32, OperandType::FLOAT16,
155 OperandType::FLOAT16, OperandType::BOOL,
156 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
157 OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16};
158 outExpectedTypes = {OperandType::TENSOR_FLOAT16};
159 } else {
160 NN_RET_CHECK_FAIL()
161 << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_LSTM op: "
162 << inputType;
163 }
164 Version minVersionSupported = Version::ANDROID_Q;
165 if (context->getNumOutputs() == kNumOutputsWithState) {
166 minVersionSupported = Version::ANDROID_R;
167 outExpectedTypes.insert(outExpectedTypes.end(), {inputType, inputType});
168 }
169 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
170 NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
171 return minVersionSupported;
172 }
173
174 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// Shape-preparation pass: verifies that every required operand is present,
// that all provided tensors agree on the LSTM geometry (max_time, batch
// size, input size, number of cells, output size), and that the optional
// tensors for each variant (CIFG, peephole, projection, layer norm) are
// supplied all-or-none. On success, sets the output tensor shape(s).
// Returns false (via NN_RET_CHECK) on the first violation found.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3) << "Invalid input tensor rank: " << inputRank;

    // The first two dimensions are {max_time, n_batch} in time-major layout
    // and {n_batch, max_time} otherwise; the last is always the input size.
    const uint32_t maxTime = getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    // The input-to-output weights {n_cell, n_input} define the cell count.
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    // The recurrent-to-output weights {n_cell, n_output} define the output size.
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none.
    // With CIFG there is no input gate, so the cell-to-input weights are
    // allowed to be absent even when the other peephole weights are present.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    // Layer norm weights must be supplied all-or-none; with CIFG the input
    // layer norm weights are excluded from the set (no input gate exists).
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Output keeps the input's {time, batch} layout; only the last dimension
    // changes from n_input to n_output.
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}
399
// Runs the LSTM over the whole sequence by dispatching to the typed
// LSTMCell::LSTMEval* kernel matching the input tensor's data type
// (TENSOR_FLOAT32 or TENSOR_FLOAT16). Returns false on unsupported types.
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    // With CIFG the input gate is derived from the forget gate, so one fewer
    // gate's worth of scratch space is needed.
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                // Write the final states directly into the caller's outputs.
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                // No state outputs requested: the kernel still needs writable
                // state buffers, so use local scratch vectors instead.
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            // NOTE: arguments are positional; omitted optional tensors yield
            // null buffers, which the kernel interprets via the LSTMParams flags.
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            // Mirrors the float32 branch above with _Float16 buffers.
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
521 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
522
523 } // namespace unidirectional_sequence_lstm
524
// allowOmittedOperand = true: many of the 28 inputs are optional (CIFG,
// peephole, projection, and layer-norm tensors) and may be omitted by callers.
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_LSTM, "UNIDIRECTIONAL_SEQUENCE_LSTM",
                      unidirectional_sequence_lstm::validate, unidirectional_sequence_lstm::prepare,
                      unidirectional_sequence_lstm::execute, .allowOmittedOperand = true);
528
529 } // namespace nn
530 } // namespace android
531