1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <memory>
19 #include <vector>
20 
21 #include "CpuExecutor.h"
22 #include "OperationsUtils.h"
23 
24 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
25 #include "QuantUtils.h"
26 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
27 
28 namespace android {
29 namespace nn {
30 namespace qlstm {
31 
32 namespace {
33 
// Inputs
constexpr uint32_t kNumInputs = 32;

constexpr uint32_t kInputTensor = 0;
// Input weight tensors of size: [numUnits, inputSize].
// kInputToInputWeightsTensor is optional: omitted when CIFG
// (Coupled Input and Forget Gate) is used.
constexpr uint32_t kInputToInputWeightsTensor = 1;
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size [numUnits, outputSize].
// kRecurrentToInputWeightsTensor is optional (omitted for CIFG).
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// For peephole (optional).
// Cell to input/forget/output weights of size [numUnits].
constexpr uint32_t kCellToInputWeightsTensor = 9;
constexpr uint32_t kCellToForgetWeightsTensor = 10;
constexpr uint32_t kCellToOutputWeightsTensor = 11;

// Gates bias tensors of size [numUnits].
// kInputGateBiasTensor is optional (must be omitted for CIFG).
constexpr uint32_t kInputGateBiasTensor = 12;
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size [outputSize, numUnits] (optional).
constexpr uint32_t kProjectionWeightsTensor = 16;
// Projection bias tensor of size [outputSize] (optional).
constexpr uint32_t kProjectionBiasTensor = 17;

// Output from the previous time step, as tensor
// of size [numBatches, outputSize].
constexpr uint32_t kPrevOutputTensor = 18;

// Cell state from the previous time step, as tensor
// of size [numBatches, numUnits].
constexpr uint32_t kPrevCellStateTensor = 19;

// Layer normalization tensors of size [numUnits] (optional;
// kInputLayerNormTensor must be omitted for CIFG).
constexpr uint32_t kInputLayerNormTensor = 20;
constexpr uint32_t kForgetLayerNormTensor = 21;
constexpr uint32_t kCellLayerNormTensor = 22;
constexpr uint32_t kOutputLayerNormTensor = 23;

// Clipping (scalar FLOAT32; a value of 0 disables clipping).
constexpr uint32_t kCellClip = 24;
constexpr uint32_t kProjectionClip = 25;

// Scales of the result of matmul, i.e. input to layer normalization.
constexpr uint32_t kInputIntermediateScale = 26;
constexpr uint32_t kForgetIntermediateScale = 27;
constexpr uint32_t kCellIntermediateScale = 28;
constexpr uint32_t kOutputIntermediateScale = 29;

// Zero point and scale of hidden state.
constexpr uint32_t kHiddenStateZeroPoint = 30;
constexpr uint32_t kHiddenStateScale = 31;

// Outputs:
constexpr uint32_t kNumOutputs = 3;
constexpr uint32_t kOutputStateOutTensor = 0;
constexpr uint32_t kCellStateOutTensor = 1;
constexpr uint32_t kOutputTensor = 2;
100 
hasTensor(IOperationExecutionContext * context,const uint32_t tensor)101 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
102     return context->getInputBuffer(tensor) != nullptr;
103 }
104 
105 }  // namespace
106 
validate(const IOperationValidationContext * context)107 Result<Version> validate(const IOperationValidationContext* context) {
108     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
109     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
110 
111     std::vector<OperandType> inExpectedTypes;
112     // Input.
113     inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
114     // Input-to-* and recurrent-to-* weights.
115     for (int i = 0; i < 8; ++i) {
116         inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM);
117     }
118     // Cell-to-* weights.
119     for (int i = 0; i < 3; ++i) {
120         inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
121     }
122     // Gate biases.
123     for (int i = 0; i < 4; ++i) {
124         inExpectedTypes.push_back(OperandType::TENSOR_INT32);
125     }
126     // Projection.
127     inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM);
128     inExpectedTypes.push_back(OperandType::TENSOR_INT32);
129     // Previous output.
130     inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
131     // Previous cell state.
132     inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
133     // Layer norm weights
134     for (int i = 0; i < 4; ++i) {
135         inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
136     }
137     // Cell/projection clipping and scales of intermediate results at the 4 gates.
138     for (int i = 0; i < 6; ++i) {
139         inExpectedTypes.push_back(OperandType::FLOAT32);
140     }
141     // Zero point and scale of the hidden state.
142     inExpectedTypes.push_back(OperandType::INT32);
143     inExpectedTypes.push_back(OperandType::FLOAT32);
144     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
145 
146     std::vector<OperandType> outExpectedTypes;
147     // Output state (out).
148     outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
149     // Cell state (out).
150     outExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
151     // Output.
152     outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
153     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
154 
155     return Version::ANDROID_R;
156 }
157 
// Validates the shapes of all supplied operands, derives the dimensions
// (batchSize, inputSize, numUnits, outputSize) from the mandatory tensors,
// enforces the all-or-none constraints on the optional CIFG, peephole,
// projection and layer-normalization inputs, and sets the shapes of the
// three outputs. Returns false (via NN_RET_CHECK) on any violation.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredTensorInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kPrevOutputTensor,
            kPrevCellStateTensor,
    };
    for (const int tensor : requiredTensorInputs) {
        NN_RET_CHECK(!context->isOmittedInput(tensor))
                << "required input " << tensor << " is omitted";
    }

    // Input is [batchSize, inputSize]; these two dimensions anchor all of
    // the shape checks below.
    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 2) << "Invalid input tensor rank: " << inputRank;

    const uint32_t batchSize = getSizeOfDimension(inputShape, 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 1);

    // numUnits comes from the input-to-output weights [numUnits, inputSize].
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0);

    // outputSize comes from the recurrent-to-output weights
    // [numUnits, outputSize].
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Optional (absent when CIFG is used): input-to-input weights
    // [numUnits, inputSize].
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    // Remaining mandatory input weights, both [numUnits, inputSize].
    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Optional (absent when CIFG is used): recurrent-to-input weights
    // [numUnits, outputSize].
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    // Remaining mandatory recurrent weights, both [numUnits, outputSize].
    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate's parameters are either all present (non-CIFG) or
    // not at all (CIFG).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Optional peephole weights, each a vector of [numUnits].
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits);
    }

    // Making sure the peephole weights are there all or none.
    // When CIFG is used there is no input gate, so the cell-to-input
    // peephole weight is allowed to be absent even if the other two are set.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // The input gate bias must be present exactly when the input gate exists
    // (non-CIFG), and must be a vector of [numUnits].
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    // Mandatory gate biases, each [numUnits].
    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits);

    // Optional projection weights [outputSize, numUnits] and bias
    // [outputSize].
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Previous output state [batchSize, outputSize] and previous cell state
    // [batchSize, numUnits].
    const Shape outputStateShape = context->getInputShape(kPrevOutputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits);

    // Optional layer-normalization weights, each a vector of [numUnits].
    if (hasTensor(context, kInputLayerNormTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kForgetLayerNormTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kCellLayerNormTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kOutputLayerNormTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits);
    }

    // Layer-norm weights are all-or-none; with CIFG the input layer-norm
    // weight must be absent and the remaining three are all-or-none.
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) &&
                                                    hasTensor(context, kCellLayerNormTensor) &&
                                                    hasTensor(context, kOutputLayerNormTensor)) ||
                                                   (!hasTensor(context, kForgetLayerNormTensor) &&
                                                    !hasTensor(context, kCellLayerNormTensor) &&
                                                    !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) &&
                                                hasTensor(context, kForgetLayerNormTensor) &&
                                                hasTensor(context, kCellLayerNormTensor) &&
                                                hasTensor(context, kOutputLayerNormTensor)) ||
                                               (!hasTensor(context, kInputLayerNormTensor) &&
                                                !hasTensor(context, kForgetLayerNormTensor) &&
                                                !hasTensor(context, kCellLayerNormTensor) &&
                                                !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Outputs inherit their dimensions from the corresponding state inputs;
    // only the dimensions are overwritten, so the outputs keep their own
    // scale/offset from getOutputShape().
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);
    outputShape.dimensions = prevOutputShape.dimensions;

    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
    Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor);
    cellStateOutShape.dimensions = prevCellStateShape.dimensions;

    return context->setOutputShape(kOutputStateOutTensor, outputShape) &&
           context->setOutputShape(kCellStateOutTensor, cellStateOutShape) &&
           context->setOutputShape(kOutputTensor, outputShape);
}
365 
366 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
execute(IOperationExecutionContext * context)367 bool execute(IOperationExecutionContext* context) {
368     // Gets the inputs.
369     const Shape inputShape = context->getInputShape(kInputTensor);
370     const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor);
371     const Shape recurrentToInputWeightsShape =
372             context->getInputShape(kRecurrentToInputWeightsTensor);
373     const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
374     const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
375     const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor);
376     const Shape recurrentToForgetWeightsShape =
377             context->getInputShape(kRecurrentToForgetWeightsTensor);
378     const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
379     const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
380     const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor);
381     const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor);
382     const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
383     const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor);
384     const Shape recurrentToOutputWeightsShape =
385             context->getInputShape(kRecurrentToOutputWeightsTensor);
386     const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
387     const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
388     const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor);
389     const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
390     const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
391 
392     const uint32_t batchSize = inputShape.dimensions[0];
393     const uint32_t inputSize = inputShape.dimensions[1];
394     const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0];
395     const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1];
396 
397     const float cellClip = context->getInputValue<float>(kCellClip);
398     const float projectionClip = context->getInputValue<float>(kProjectionClip);
399     const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale);
400     const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale);
401     const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale);
402     const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale);
403     const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint);
404     const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale);
405 
406     const int8_t* inputBuffer =
407             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor));
408 
409     const int8_t* inputToInputWeightsBuffer =
410             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor));
411     const bool useCifg = (inputToInputWeightsBuffer == nullptr);
412     const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast<const int8_t*>(
413             context->getInputBuffer(kRecurrentToInputWeightsTensor));
414     const int16_t* cellToInputBuffer =
415             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor));
416     const int16_t* inputLayerNormBuffer =
417             reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor));
418     const int32_t* inputBiasBuffer =
419             reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor));
420 
421     const int8_t* inputToForgetWeightsBuffer =
422             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor));
423     const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast<const int8_t*>(
424             context->getInputBuffer(kRecurrentToForgetWeightsTensor));
425     const int16_t* cellToForgetBuffer =
426             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor));
427     const int16_t* forgetLayerNormBuffer =
428             reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor));
429     const int32_t* forgetBiasBuffer =
430             reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor));
431 
432     const int8_t* inputToCellWeightsBuffer =
433             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor));
434     const int8_t* recurrentToCellWeightsBuffer =
435             reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor));
436     const int16_t* cellLayerNormBuffer =
437             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor));
438     const int32_t* cellBiasBuffer =
439             reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor));
440 
441     const int8_t* inputToOutputWeightsBuffer =
442             reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor));
443     const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast<const int8_t*>(
444             context->getInputBuffer(kRecurrentToOutputWeightsTensor));
445     const int16_t* cellToOutputBuffer =
446             reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor));
447     const int16_t* outputLayerNormBuffer =
448             reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor));
449     const int32_t* outputBiasBuffer =
450             reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor));
451 
452     const int8_t* projectionWeightsBuffer =
453             reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor));
454     const int32_t* projectionBiasBuffer =
455             reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor));
456 
457     const int8_t* prevOutputBuffer =
458             reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor));
459     const int16_t* prevCellStateBuffer =
460             reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor));
461 
462     uint8_t* outputStateBuffer =
463             reinterpret_cast<uint8_t*>(context->getOutputBuffer(kOutputStateOutTensor));
464     int16_t* cellStateBuffer =
465             reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor));
466     int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor));
467 
468     // Calculates and decomposes effective scales.
469     // This is for optimizing the matmul calculation.
470     int cellShift;
471     NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift));
472     NN_RET_CHECK(cellShift <= -9);
473 
474     int32_t inputToInputEffectiveScaleA;
475     int32_t inputToInputEffectiveScaleB;
476     int32_t recurrentToInputEffectiveScaleA;
477     int32_t recurrentToInputEffectiveScaleB;
478     int32_t cellToInputEffectiveScaleA;
479     int32_t cellToInputEffectiveScaleB;
480     if (!useCifg) {
481         const float inputToInputEffectiveScale =
482                 inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale;
483         NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA,
484                                         &inputToInputEffectiveScaleB));
485         const float recurrentToInputEffectiveScale =
486                 recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale;
487         NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale,
488                                         &recurrentToInputEffectiveScaleA,
489                                         &recurrentToInputEffectiveScaleB));
490         if (cellToInputBuffer != nullptr) {
491             const float cellToInputEffectiveScale =
492                     std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale;
493             NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA,
494                                             &cellToInputEffectiveScaleB));
495         }
496     }
497 
498     int32_t inputLayerNormScaleA;
499     int32_t inputLayerNormScaleB;
500     if (inputLayerNormBuffer != nullptr) {
501         NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA,
502                                         &inputLayerNormScaleB));
503     }
504 
505     const float inputToForgetEffectiveScale =
506             inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale;
507     int32_t inputToForgetEffectiveScaleA;
508     int32_t inputToForgetEffectiveScaleB;
509     NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA,
510                                     &inputToForgetEffectiveScaleB));
511     const float recurrentToForgetEffectiveScale =
512             recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale;
513     int32_t recurrentToForgetEffectiveScaleA;
514     int32_t recurrentToForgetEffectiveScaleB;
515     NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale,
516                                     &recurrentToForgetEffectiveScaleA,
517                                     &recurrentToForgetEffectiveScaleB));
518     int32_t cellToForgetEffectiveScaleA;
519     int32_t cellToForgetEffectiveScaleB;
520     if (cellToForgetBuffer != nullptr) {
521         const float cellToForgetEffectiveScale =
522                 std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale;
523         NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA,
524                                         &cellToForgetEffectiveScaleB));
525     }
526     int32_t forgetLayerNormScaleA;
527     int32_t forgetLayerNormScaleB;
528     if (forgetLayerNormBuffer != nullptr) {
529         NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA,
530                                         &forgetLayerNormScaleB));
531     }
532 
533     const float inputToCellEffectiveScale =
534             inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale;
535     int32_t inputToCellEffectiveScaleA;
536     int32_t inputToCellEffectiveScaleB;
537     NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA,
538                                     &inputToCellEffectiveScaleB));
539     const float recurrentToCellEffectiveScale =
540             recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale;
541     int32_t recurrentToCellEffectiveScaleA;
542     int32_t recurrentToCellEffectiveScaleB;
543     NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA,
544                                     &recurrentToCellEffectiveScaleB));
545 
546     int32_t cellLayerNormScaleA;
547     int32_t cellLayerNormScaleB;
548     if (cellLayerNormBuffer != nullptr) {
549         NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA,
550                                         &cellLayerNormScaleB));
551     }
552 
553     const float inputToOutputEffectiveScale =
554             inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale;
555     int32_t inputToOutputEffectiveScaleA;
556     int32_t inputToOutputEffectiveScaleB;
557     NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA,
558                                     &inputToOutputEffectiveScaleB));
559     const float recurrentToOutputEffectiveScale =
560             recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale;
561     int32_t recurrentToOutputEffectiveScaleA;
562     int32_t recurrentToOutputEffectiveScaleB;
563     NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale,
564                                     &recurrentToOutputEffectiveScaleA,
565                                     &recurrentToOutputEffectiveScaleB));
566     int32_t cellToOutputEffectiveScaleA;
567     int32_t cellToOutputEffectiveScaleB;
568     if (cellToOutputBuffer != nullptr) {
569         const float cellToOutputEffectiveScale =
570                 std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale;
571         NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA,
572                                         &cellToOutputEffectiveScaleB));
573     }
574     int32_t outputLayerNormScaleA;
575     int32_t outputLayerNormScaleB;
576     if (outputLayerNormBuffer != nullptr) {
577         NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA,
578                                         &outputLayerNormScaleB));
579     }
580 
581     const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15);
582     int32_t hiddenStateEffectiveScaleA;
583     int32_t hiddenStateEffectiveScaleB;
584     NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA,
585                                     &hiddenStateEffectiveScaleB));
586 
587     int32_t projectionEffectiveScaleA;
588     int32_t projectionEffectiveScaleB;
589     if (projectionWeightsBuffer != nullptr) {
590         const float projectionEffectiveScale =
591                 projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale;
592         NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA,
593                                         &projectionEffectiveScaleB));
594     }
595 
596     // Calculates quantized clipping parameters.
597     int16_t quantizedCellClip = 0;
598     if (cellClip > 0.0) {
599         quantizedCellClip = static_cast<int32_t>(
600                 std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f));
601     }
602     int8_t quantizedProjectionClip = 0;
603     if (projectionClip > 0.0) {
604         quantizedProjectionClip = static_cast<int32_t>(
605                 std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f));
606     }
607 
608     // Calculates effective bias.
609     // This is for optimizing the matmul calculation.
610     std::unique_ptr<int32_t[]> inputToInputEffectiveBias;
611     std::unique_ptr<int32_t[]> recurrentToInputEffectiveBias;
612     if (!useCifg) {
613         NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
614                 -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape,
615                 /*bias=*/nullptr, &inputToInputEffectiveBias));
616         NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
617                 -prevOutputShape.offset, recurrentToInputWeightsBuffer,
618                 recurrentToInputWeightsShape,
619                 /*bias=*/nullptr, &recurrentToInputEffectiveBias));
620     }
621 
622     std::unique_ptr<int32_t[]> inputToForgetEffectiveBias;
623     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
624             -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape,
625             /*bias=*/nullptr, &inputToForgetEffectiveBias));
626     std::unique_ptr<int32_t[]> recurrentToForgetEffectiveBias;
627     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
628             -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape,
629             /*bias=*/nullptr, &recurrentToForgetEffectiveBias));
630 
631     std::unique_ptr<int32_t[]> inputToCellEffectiveBias;
632     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
633             -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape,
634             /*bias=*/nullptr, &inputToCellEffectiveBias));
635     std::unique_ptr<int32_t[]> recurrentToCellEffectiveBias;
636     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
637             -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape,
638             /*bias=*/nullptr, &recurrentToCellEffectiveBias));
639 
640     std::unique_ptr<int32_t[]> inputToOutputEffectiveBias;
641     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
642             -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape,
643             /*bias=*/nullptr, &inputToOutputEffectiveBias));
644     std::unique_ptr<int32_t[]> recurrentToOutputEffectiveBias;
645     NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
646             -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape,
647             /*bias=*/nullptr, &recurrentToOutputEffectiveBias));
648 
649     std::unique_ptr<int32_t[]> projectionEffectiveBias;
650     if (projectionBiasBuffer != nullptr) {
651         NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
652                 hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape,
653                 projectionBiasBuffer, &projectionEffectiveBias));
654     }
655 
656     // Temporary buffers.
657     std::vector<int16_t> inputGateBuffer(batchSize * numUnits);
658     std::vector<int16_t> forgetGateBuffer(batchSize * numUnits);
659     std::vector<int16_t> cellGateBuffer(batchSize * numUnits);
660     std::vector<int16_t> outputGateBuffer(batchSize * numUnits);
661     std::vector<int8_t> buffer8(batchSize * numUnits);
662 
663     // To avoid overflow when calculating layer norm.
664     const int32_t inputInvLargeValue =
665             std::min(1, static_cast<int32_t>(10000 * inputLayerNormShape.scale));
666     const int32_t forgetInvLargeValue =
667             std::min(1, static_cast<int32_t>(10000 * forgetLayerNormShape.scale));
668     const int32_t cellInvLargeValue =
669             std::min(1, static_cast<int32_t>(10000 * cellLayerNormShape.scale));
670     const int32_t outputInvLargeValue =
671             std::min(1, static_cast<int32_t>(10000 * outputLayerNormShape.scale));
672 
673     // Forget gate.
674     MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(),
675                                         inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA,
676                                         inputToForgetEffectiveScaleB, batchSize, inputSize,
677                                         numUnits,
678                                         /*outputZeroPoint=*/0, forgetGateBuffer.data());
679     MatrixBatchVectorMultiplyAccumulate(
680             prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer,
681             recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize,
682             outputSize, numUnits,
683             /*outputZeroPoint=*/0, forgetGateBuffer.data());
684     if (cellToForgetBuffer != nullptr) {
685         VectorBatchVectorCwiseProductAccumulate(
686                 cellToForgetBuffer, outputSize, cellStateBuffer, batchSize,
687                 cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data());
688     }
689     if (forgetLayerNormBuffer != nullptr) {
690         ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer,
691                        forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize,
692                        numUnits, forgetGateBuffer.data());
693     }
694     ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data());
695 
696     // Modulation gate.
697     MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(),
698                                         inputToCellWeightsBuffer, inputToCellEffectiveScaleA,
699                                         inputToCellEffectiveScaleB, batchSize, inputSize, numUnits,
700                                         /*outputZeroPoint=*/0, cellGateBuffer.data());
701     MatrixBatchVectorMultiplyAccumulate(
702             prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer,
703             recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize,
704             numUnits,
705             /*outputZeroPoint=*/0, cellGateBuffer.data());
706     if (cellLayerNormBuffer != nullptr) {
707         ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer,
708                        cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize,
709                        numUnits, cellGateBuffer.data());
710     }
711     ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data());
712 
713     // Input gate.
714     if (useCifg) {
715         Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data());
716     } else {
717         MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(),
718                                             inputToInputWeightsBuffer, inputToInputEffectiveScaleA,
719                                             inputToInputEffectiveScaleB, batchSize, inputSize,
720                                             numUnits,
721                                             /*outputZeroPoint=*/0, inputGateBuffer.data());
722         MatrixBatchVectorMultiplyAccumulate(
723                 prevOutputBuffer, recurrentToInputEffectiveBias.get(),
724                 recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA,
725                 recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits,
726                 /*outputZeroPoint=*/0, inputGateBuffer.data());
727         if (cellToInputBuffer != nullptr) {
728             VectorBatchVectorCwiseProductAccumulate(
729                     cellToInputBuffer, outputSize, cellStateBuffer, batchSize,
730                     cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data());
731         }
732         if (inputLayerNormBuffer != nullptr) {
733             ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer,
734                            inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue,
735                            batchSize, numUnits, inputGateBuffer.data());
736         }
737         ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data());
738     }
739 
740     // Cell.
741     CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits,
742              /*shift=*/15, forgetGateBuffer.data());
743     CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift,
744              cellGateBuffer.data());
745     CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer);
746     if (quantizedCellClip > 0) {
747         CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits);
748     }
749 
750     // Output gate.
751     MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(),
752                                         inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA,
753                                         inputToOutputEffectiveScaleB, batchSize, inputSize,
754                                         numUnits,
755                                         /*outputZeroPoint=*/0, outputGateBuffer.data());
756     MatrixBatchVectorMultiplyAccumulate(
757             prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer,
758             recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize,
759             outputSize, numUnits,
760             /*outputZeroPoint=*/0, outputGateBuffer.data());
761     if (cellToOutputBuffer != nullptr) {
762         VectorBatchVectorCwiseProductAccumulate(
763                 cellToOutputBuffer, outputSize, cellStateBuffer, batchSize,
764                 cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data());
765     }
766     if (outputLayerNormBuffer != nullptr) {
767         ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer,
768                        outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize,
769                        numUnits, outputGateBuffer.data());
770     }
771     ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data());
772 
773     // Hidden.
774     ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data());
775     CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA,
776              hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data());
777 
778     // Projection.
779     if (projectionWeightsBuffer != nullptr) {
780         memset(outputBuffer, 0, batchSize * outputSize * sizeof(int8_t));
781         MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(),
782                                             projectionWeightsBuffer, projectionEffectiveScaleA,
783                                             projectionEffectiveScaleB, batchSize, numUnits,
784                                             outputSize, prevOutputShape.offset, outputBuffer);
785         if (quantizedProjectionClip > 0) {
786             CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize);
787         }
788     } else {
789         std::copy_n(buffer8.data(), batchSize * outputSize, outputBuffer);
790     }
791 
792     // Copy output to output state out.
793     for (unsigned int i = 0; i < batchSize * outputSize; ++i) {
794         outputStateBuffer[i] = outputBuffer[i];
795     }
796 
797     return true;
798 }
799 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
800 
801 }  // namespace qlstm
802 
// Registers the QUANTIZED_LSTM operation with the NNAPI operation resolver,
// wiring up this file's validate/prepare/execute entry points.
// allowOmittedOperand is set because several of the operation's 32 inputs are
// optional (e.g. the peephole cell-to-gate weight tensors declared above) and
// callers may legitimately omit them.
NN_REGISTER_OPERATION(QUANTIZED_LSTM, "QUANTIZED_LSTM", qlstm::validate, qlstm::prepare,
                      qlstm::execute, .allowOmittedOperand = true);
805 
806 }  // namespace nn
807 }  // namespace android
808