/*
 * Copyright (C) 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <vector>

#include "CpuExecutor.h"
#include "OperationsUtils.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include "QuantUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace qlstm {

namespace {

// Inputs
constexpr uint32_t kNumInputs = 32;

constexpr uint32_t kInputTensor = 0;
// Input weight tensors of size: [numUnits, inputSize].
constexpr uint32_t kInputToInputWeightsTensor = 1;
constexpr uint32_t kInputToForgetWeightsTensor = 2;
constexpr uint32_t kInputToCellWeightsTensor = 3;
constexpr uint32_t kInputToOutputWeightsTensor = 4;

// Recurrent weight tensors of size [numUnits, outputSize].
constexpr uint32_t kRecurrentToInputWeightsTensor = 5;
constexpr uint32_t kRecurrentToForgetWeightsTensor = 6;
constexpr uint32_t kRecurrentToCellWeightsTensor = 7;
constexpr uint32_t kRecurrentToOutputWeightsTensor = 8;

// For peephole (optional).
// Cell to input/forget/output weights of size [numUnits].
constexpr uint32_t kCellToInputWeightsTensor = 9;
constexpr uint32_t kCellToForgetWeightsTensor = 10;
constexpr uint32_t kCellToOutputWeightsTensor = 11;
// Gate bias tensors of size [numUnits].
constexpr uint32_t kInputGateBiasTensor = 12;
constexpr uint32_t kForgetGateBiasTensor = 13;
constexpr uint32_t kCellGateBiasTensor = 14;
constexpr uint32_t kOutputGateBiasTensor = 15;

// Projection weight tensor of size [outputSize, numUnits].
constexpr uint32_t kProjectionWeightsTensor = 16;
// Projection bias tensor of size [outputSize].
constexpr uint32_t kProjectionBiasTensor = 17;

// Output from the previous time step, as a tensor
// of size [numBatches, outputSize].
constexpr uint32_t kPrevOutputTensor = 18;

// Cell state from the previous time step, as a tensor
// of size [numBatches, numUnits].
constexpr uint32_t kPrevCellStateTensor = 19;

// Layer normalization tensors of size [numUnits].
constexpr uint32_t kInputLayerNormTensor = 20;
constexpr uint32_t kForgetLayerNormTensor = 21;
constexpr uint32_t kCellLayerNormTensor = 22;
constexpr uint32_t kOutputLayerNormTensor = 23;

// Clipping.
constexpr uint32_t kCellClip = 24;
constexpr uint32_t kProjectionClip = 25;

// Scales of the result of matmul, i.e. input to layer normalization.
constexpr uint32_t kInputIntermediateScale = 26;
constexpr uint32_t kForgetIntermediateScale = 27;
constexpr uint32_t kCellIntermediateScale = 28;
constexpr uint32_t kOutputIntermediateScale = 29;

// Zero point and scale of hidden state.
constexpr uint32_t kHiddenStateZeroPoint = 30;
constexpr uint32_t kHiddenStateScale = 31;

// Outputs:
constexpr uint32_t kNumOutputs = 3;
constexpr uint32_t kOutputStateOutTensor = 0;
constexpr uint32_t kCellStateOutTensor = 1;
constexpr uint32_t kOutputTensor = 2;

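// Returns whether the caller provided a buffer for an optional tensor input.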
inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
    return context->getInputBuffer(tensor) != nullptr;
}

}  // namespace

Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);

    std::vector<OperandType> inExpectedTypes;
    // Input.
    inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
    // Input-to-* and recurrent-to-* weights.
    for (int i = 0; i < 8; ++i) {
        inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM);
    }
    // Cell-to-* weights.
    for (int i = 0; i < 3; ++i) {
        inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
    }
    // Gate biases.
    for (int i = 0; i < 4; ++i) {
        inExpectedTypes.push_back(OperandType::TENSOR_INT32);
    }
    // Projection.
    inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_SYMM);
    inExpectedTypes.push_back(OperandType::TENSOR_INT32);
    // Previous output.
    inExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
    // Previous cell state.
    inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
    // Layer norm weights.
    for (int i = 0; i < 4; ++i) {
        inExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
    }
    // Cell/projection clipping and scales of intermediate results at the 4 gates.
    for (int i = 0; i < 6; ++i) {
        inExpectedTypes.push_back(OperandType::FLOAT32);
    }
    // Zero point and scale of the hidden state.
    inExpectedTypes.push_back(OperandType::INT32);
    inExpectedTypes.push_back(OperandType::FLOAT32);
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));

    std::vector<OperandType> outExpectedTypes;
    // Output state (out).
    outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
    // Cell state (out).
    outExpectedTypes.push_back(OperandType::TENSOR_QUANT16_SYMM);
    // Output.
    outExpectedTypes.push_back(OperandType::TENSOR_QUANT8_ASYMM_SIGNED);
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));

    return Version::ANDROID_R;
}

bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredTensorInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kPrevOutputTensor,
            kPrevCellStateTensor,
    };
    for (const int tensor : requiredTensorInputs) {
        NN_RET_CHECK(!context->isOmittedInput(tensor))
                << "required input " << tensor << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 2) << "Invalid input tensor rank: " << inputRank;

    const uint32_t batchSize = getSizeOfDimension(inputShape, 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 1);

    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numUnits = getSizeOfDimension(inputToOutputShape, 0);

    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numUnits);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numUnits);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // Make sure the input-gate's parameters are either all present (non-CIFG)
    // or all omitted (CIFG).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numUnits);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numUnits);
    }

    // Make sure the peephole weights are either all present or all omitted.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numUnits);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numUnits);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numUnits);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numUnits);

    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numUnits);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    const Shape outputStateShape = context->getInputShape(kPrevOutputTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kPrevCellStateTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numUnits);

    if (hasTensor(context, kInputLayerNormTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kForgetLayerNormTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kCellLayerNormTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numUnits);
    }

    if (hasTensor(context, kOutputLayerNormTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numUnits);
    }

    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg = (hasTensor(context, kForgetLayerNormTensor) &&
                                                    hasTensor(context, kCellLayerNormTensor) &&
                                                    hasTensor(context, kOutputLayerNormTensor)) ||
                                                   (!hasTensor(context, kForgetLayerNormTensor) &&
                                                    !hasTensor(context, kCellLayerNormTensor) &&
                                                    !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone = (hasTensor(context, kInputLayerNormTensor) &&
                                                hasTensor(context, kForgetLayerNormTensor) &&
                                                hasTensor(context, kCellLayerNormTensor) &&
                                                hasTensor(context, kOutputLayerNormTensor)) ||
                                               (!hasTensor(context, kInputLayerNormTensor) &&
                                                !hasTensor(context, kForgetLayerNormTensor) &&
                                                !hasTensor(context, kCellLayerNormTensor) &&
                                                !hasTensor(context, kOutputLayerNormTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

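    // The output and the output state out take the shape of the previous
    // output; the cell state out takes the shape of the previous cell state.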
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    Shape outputShape = context->getOutputShape(kOutputTensor);
    outputShape.dimensions = prevOutputShape.dimensions;

    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);
    Shape cellStateOutShape = context->getOutputShape(kCellStateOutTensor);
    cellStateOutShape.dimensions = prevCellStateShape.dimensions;

    return context->setOutputShape(kOutputStateOutTensor, outputShape) &&
           context->setOutputShape(kCellStateOutTensor, cellStateOutShape) &&
           context->setOutputShape(kOutputTensor, outputShape);
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
bool execute(IOperationExecutionContext* context) {
    // Gets the inputs.
    const Shape inputShape = context->getInputShape(kInputTensor);
    const Shape inputToInputWeightsShape = context->getInputShape(kInputToInputWeightsTensor);
    const Shape recurrentToInputWeightsShape =
            context->getInputShape(kRecurrentToInputWeightsTensor);
    const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
    const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormTensor);
    const Shape inputToForgetWeightsShape = context->getInputShape(kInputToForgetWeightsTensor);
    const Shape recurrentToForgetWeightsShape =
            context->getInputShape(kRecurrentToForgetWeightsTensor);
    const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
    const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormTensor);
    const Shape inputToCellWeightsShape = context->getInputShape(kInputToCellWeightsTensor);
    const Shape recurrentToCellWeightsShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormTensor);
    const Shape inputToOutputWeightsShape = context->getInputShape(kInputToOutputWeightsTensor);
    const Shape recurrentToOutputWeightsShape =
            context->getInputShape(kRecurrentToOutputWeightsTensor);
    const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
    const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormTensor);
    const Shape projectionWeightsShape = context->getInputShape(kProjectionWeightsTensor);
    const Shape prevOutputShape = context->getInputShape(kPrevOutputTensor);
    const Shape prevCellStateShape = context->getInputShape(kPrevCellStateTensor);

    const uint32_t batchSize = inputShape.dimensions[0];
    const uint32_t inputSize = inputShape.dimensions[1];
    const uint32_t numUnits = inputToOutputWeightsShape.dimensions[0];
    const uint32_t outputSize = recurrentToOutputWeightsShape.dimensions[1];

    const float cellClip = context->getInputValue<float>(kCellClip);
    const float projectionClip = context->getInputValue<float>(kProjectionClip);
    const float inputIntermediateScale = context->getInputValue<float>(kInputIntermediateScale);
    const float forgetIntermediateScale = context->getInputValue<float>(kForgetIntermediateScale);
    const float cellIntermediateScale = context->getInputValue<float>(kCellIntermediateScale);
    const float outputIntermediateScale = context->getInputValue<float>(kOutputIntermediateScale);
    const int8_t hiddenStateZeroPoint = context->getInputValue<int8_t>(kHiddenStateZeroPoint);
    const float hiddenStateScale = context->getInputValue<float>(kHiddenStateScale);

    const int8_t* inputBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputTensor));

    const int8_t* inputToInputWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToInputWeightsTensor));
    const bool useCifg = (inputToInputWeightsBuffer == nullptr);
    const int8_t* recurrentToInputWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToInputWeightsTensor));
    const int16_t* cellToInputBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToInputWeightsTensor));
    const int16_t* inputLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kInputLayerNormTensor));
    const int32_t* inputBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kInputGateBiasTensor));

    const int8_t* inputToForgetWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToForgetWeightsTensor));
    const int8_t* recurrentToForgetWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToForgetWeightsTensor));
    const int16_t* cellToForgetBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToForgetWeightsTensor));
    const int16_t* forgetLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kForgetLayerNormTensor));
    const int32_t* forgetBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kForgetGateBiasTensor));

    const int8_t* inputToCellWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToCellWeightsTensor));
    const int8_t* recurrentToCellWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kRecurrentToCellWeightsTensor));
    const int16_t* cellLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellLayerNormTensor));
    const int32_t* cellBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kCellGateBiasTensor));

    const int8_t* inputToOutputWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kInputToOutputWeightsTensor));
    const int8_t* recurrentToOutputWeightsBuffer = reinterpret_cast<const int8_t*>(
            context->getInputBuffer(kRecurrentToOutputWeightsTensor));
    const int16_t* cellToOutputBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kCellToOutputWeightsTensor));
    const int16_t* outputLayerNormBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kOutputLayerNormTensor));
    const int32_t* outputBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kOutputGateBiasTensor));

    const int8_t* projectionWeightsBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kProjectionWeightsTensor));
    const int32_t* projectionBiasBuffer =
            reinterpret_cast<const int32_t*>(context->getInputBuffer(kProjectionBiasTensor));

    const int8_t* prevOutputBuffer =
            reinterpret_cast<const int8_t*>(context->getInputBuffer(kPrevOutputTensor));
    const int16_t* prevCellStateBuffer =
            reinterpret_cast<const int16_t*>(context->getInputBuffer(kPrevCellStateTensor));

    int8_t* outputStateBuffer =
            reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputStateOutTensor));
    int16_t* cellStateBuffer =
            reinterpret_cast<int16_t*>(context->getOutputBuffer(kCellStateOutTensor));
    int8_t* outputBuffer = reinterpret_cast<int8_t*>(context->getOutputBuffer(kOutputTensor));

    // Calculates and decomposes effective scales.
    // This is for optimizing the matmul calculation.
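    // Each effective scale is weightScale * activationScale divided by the
    // corresponding intermediate scale, so the requantized matmul result lands
    // directly on the scale expected at the input of layer normalization.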
    int cellShift;
    NN_RET_CHECK(CheckedLog2(prevCellStateShape.scale, &cellShift));
    NN_RET_CHECK(cellShift <= -9);
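    // The cell state scale must be a power of two, 2^cellShift, with at least
    // 9 fractional bits; the fixed-point arithmetic below treats the 16-bit
    // cell state as having -cellShift fractional bits.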

    int32_t inputToInputEffectiveScaleA;
    int32_t inputToInputEffectiveScaleB;
    int32_t recurrentToInputEffectiveScaleA;
    int32_t recurrentToInputEffectiveScaleB;
    int32_t cellToInputEffectiveScaleA;
    int32_t cellToInputEffectiveScaleB;
    if (!useCifg) {
        const float inputToInputEffectiveScale =
                inputToInputWeightsShape.scale * inputShape.scale / inputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(inputToInputEffectiveScale, &inputToInputEffectiveScaleA,
                                        &inputToInputEffectiveScaleB));
        const float recurrentToInputEffectiveScale =
                recurrentToInputWeightsShape.scale * prevOutputShape.scale / inputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(recurrentToInputEffectiveScale,
                                        &recurrentToInputEffectiveScaleA,
                                        &recurrentToInputEffectiveScaleB));
        if (cellToInputBuffer != nullptr) {
            const float cellToInputEffectiveScale =
                    std::pow(2, cellShift) * cellToInputShape.scale / inputIntermediateScale;
            NN_RET_CHECK(QuantizeMultiplier(cellToInputEffectiveScale, &cellToInputEffectiveScaleA,
                                            &cellToInputEffectiveScaleB));
        }
    }

    int32_t inputLayerNormScaleA;
    int32_t inputLayerNormScaleB;
    if (inputLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(inputLayerNormShape.scale, &inputLayerNormScaleA,
                                        &inputLayerNormScaleB));
    }

    const float inputToForgetEffectiveScale =
            inputToForgetWeightsShape.scale * inputShape.scale / forgetIntermediateScale;
    int32_t inputToForgetEffectiveScaleA;
    int32_t inputToForgetEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToForgetEffectiveScale, &inputToForgetEffectiveScaleA,
                                    &inputToForgetEffectiveScaleB));
    const float recurrentToForgetEffectiveScale =
            recurrentToForgetWeightsShape.scale * prevOutputShape.scale / forgetIntermediateScale;
    int32_t recurrentToForgetEffectiveScaleA;
    int32_t recurrentToForgetEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToForgetEffectiveScale,
                                    &recurrentToForgetEffectiveScaleA,
                                    &recurrentToForgetEffectiveScaleB));
    int32_t cellToForgetEffectiveScaleA;
    int32_t cellToForgetEffectiveScaleB;
    if (cellToForgetBuffer != nullptr) {
        const float cellToForgetEffectiveScale =
                std::pow(2, cellShift) * cellToForgetShape.scale / forgetIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(cellToForgetEffectiveScale, &cellToForgetEffectiveScaleA,
                                        &cellToForgetEffectiveScaleB));
    }
    int32_t forgetLayerNormScaleA;
    int32_t forgetLayerNormScaleB;
    if (forgetLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(forgetLayerNormShape.scale, &forgetLayerNormScaleA,
                                        &forgetLayerNormScaleB));
    }

    const float inputToCellEffectiveScale =
            inputToCellWeightsShape.scale * inputShape.scale / cellIntermediateScale;
    int32_t inputToCellEffectiveScaleA;
    int32_t inputToCellEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToCellEffectiveScale, &inputToCellEffectiveScaleA,
                                    &inputToCellEffectiveScaleB));
    const float recurrentToCellEffectiveScale =
            recurrentToCellWeightsShape.scale * prevOutputShape.scale / cellIntermediateScale;
    int32_t recurrentToCellEffectiveScaleA;
    int32_t recurrentToCellEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToCellEffectiveScale, &recurrentToCellEffectiveScaleA,
                                    &recurrentToCellEffectiveScaleB));

    int32_t cellLayerNormScaleA;
    int32_t cellLayerNormScaleB;
    if (cellLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(cellLayerNormShape.scale, &cellLayerNormScaleA,
                                        &cellLayerNormScaleB));
    }

    const float inputToOutputEffectiveScale =
            inputToOutputWeightsShape.scale * inputShape.scale / outputIntermediateScale;
    int32_t inputToOutputEffectiveScaleA;
    int32_t inputToOutputEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(inputToOutputEffectiveScale, &inputToOutputEffectiveScaleA,
                                    &inputToOutputEffectiveScaleB));
    const float recurrentToOutputEffectiveScale =
            recurrentToOutputWeightsShape.scale * prevOutputShape.scale / outputIntermediateScale;
    int32_t recurrentToOutputEffectiveScaleA;
    int32_t recurrentToOutputEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(recurrentToOutputEffectiveScale,
                                    &recurrentToOutputEffectiveScaleA,
                                    &recurrentToOutputEffectiveScaleB));
    int32_t cellToOutputEffectiveScaleA;
    int32_t cellToOutputEffectiveScaleB;
    if (cellToOutputBuffer != nullptr) {
        const float cellToOutputEffectiveScale =
                std::pow(2, cellShift) * cellToOutputShape.scale / outputIntermediateScale;
        NN_RET_CHECK(QuantizeMultiplier(cellToOutputEffectiveScale, &cellToOutputEffectiveScaleA,
                                        &cellToOutputEffectiveScaleB));
    }
    int32_t outputLayerNormScaleA;
    int32_t outputLayerNormScaleB;
    if (outputLayerNormBuffer != nullptr) {
        NN_RET_CHECK(QuantizeMultiplier(outputLayerNormShape.scale, &outputLayerNormScaleA,
                                        &outputLayerNormScaleB));
    }

    const float hiddenStateEffectiveScale = std::pow(2, -15) / hiddenStateScale * std::pow(2, -15);
    int32_t hiddenStateEffectiveScaleA;
    int32_t hiddenStateEffectiveScaleB;
    NN_RET_CHECK(QuantizeMultiplier(hiddenStateEffectiveScale, &hiddenStateEffectiveScaleA,
                                    &hiddenStateEffectiveScaleB));
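    // The two 2^-15 factors reflect that the output gate sigmoid and the cell
    // state tanh below both produce Q0.15 values; dividing by hiddenStateScale
    // requantizes their product to the hidden state's scale.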

    int32_t projectionEffectiveScaleA;
    int32_t projectionEffectiveScaleB;
    if (projectionWeightsBuffer != nullptr) {
        const float projectionEffectiveScale =
                projectionWeightsShape.scale * hiddenStateScale / prevOutputShape.scale;
        NN_RET_CHECK(QuantizeMultiplier(projectionEffectiveScale, &projectionEffectiveScaleA,
                                        &projectionEffectiveScaleB));
    }

    // Calculates quantized clipping parameters.
    int16_t quantizedCellClip = 0;
    if (cellClip > 0.0) {
        quantizedCellClip = static_cast<int32_t>(
                std::min(std::max(cellClip / prevCellStateShape.scale, -32768.0f), 32767.0f));
    }
    int8_t quantizedProjectionClip = 0;
    if (projectionClip > 0.0) {
        quantizedProjectionClip = static_cast<int32_t>(
                std::min(std::max(projectionClip / projectionWeightsShape.scale, -128.0f), 127.0f));
    }

    // Calculates effective bias.
    // This is for optimizing the matmul calculation.
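    // For an asymmetric activation x and symmetric weights W, (x - zeroPoint) * W
    // expands to x * W - zeroPoint * W; the second term does not depend on the
    // data, so it is precomputed per output row and applied as a bias.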
    std::unique_ptr<int32_t[]> inputToInputEffectiveBias;
    std::unique_ptr<int32_t[]> recurrentToInputEffectiveBias;
    if (!useCifg) {
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                -inputShape.offset, inputToInputWeightsBuffer, inputToInputWeightsShape,
                /*bias=*/nullptr, &inputToInputEffectiveBias));
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                -prevOutputShape.offset, recurrentToInputWeightsBuffer,
                recurrentToInputWeightsShape,
                /*bias=*/nullptr, &recurrentToInputEffectiveBias));
    }

    std::unique_ptr<int32_t[]> inputToForgetEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToForgetWeightsBuffer, inputToForgetWeightsShape,
            /*bias=*/nullptr, &inputToForgetEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToForgetEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToForgetWeightsBuffer, recurrentToForgetWeightsShape,
            /*bias=*/nullptr, &recurrentToForgetEffectiveBias));

    std::unique_ptr<int32_t[]> inputToCellEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToCellWeightsBuffer, inputToCellWeightsShape,
            /*bias=*/nullptr, &inputToCellEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToCellEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToCellWeightsBuffer, recurrentToCellWeightsShape,
            /*bias=*/nullptr, &recurrentToCellEffectiveBias));

    std::unique_ptr<int32_t[]> inputToOutputEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -inputShape.offset, inputToOutputWeightsBuffer, inputToOutputWeightsShape,
            /*bias=*/nullptr, &inputToOutputEffectiveBias));
    std::unique_ptr<int32_t[]> recurrentToOutputEffectiveBias;
    NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
            -prevOutputShape.offset, recurrentToOutputWeightsBuffer, recurrentToOutputWeightsShape,
            /*bias=*/nullptr, &recurrentToOutputEffectiveBias));

    std::unique_ptr<int32_t[]> projectionEffectiveBias;
    if (projectionBiasBuffer != nullptr) {
        NN_RET_CHECK(PrecomputeZeroPointTimesWeightWithBias(
                hiddenStateZeroPoint, projectionWeightsBuffer, projectionWeightsShape,
                projectionBiasBuffer, &projectionEffectiveBias));
    }

    // Temporary buffers.
    std::vector<int16_t> inputGateBuffer(batchSize * numUnits);
    std::vector<int16_t> forgetGateBuffer(batchSize * numUnits);
    std::vector<int16_t> cellGateBuffer(batchSize * numUnits);
    std::vector<int16_t> outputGateBuffer(batchSize * numUnits);
    std::vector<int8_t> buffer8(batchSize * numUnits);

    // To avoid overflow when calculating layer norm.
    const int32_t inputInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * inputLayerNormShape.scale));
    const int32_t forgetInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * forgetLayerNormShape.scale));
    const int32_t cellInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * cellLayerNormShape.scale));
    const int32_t outputInvLargeValue =
            std::min(1, static_cast<int32_t>(10000 * outputLayerNormShape.scale));

    // Forget gate.
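    // Each gate computes layerNorm(W_x * x + W_h * hPrev [+ peephole * cell]),
    // with both matmuls accumulating into the same gate buffer, followed by a
    // sigmoid (or, for the modulation gate, tanh) activation.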
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToForgetEffectiveBias.get(),
                                        inputToForgetWeightsBuffer, inputToForgetEffectiveScaleA,
                                        inputToForgetEffectiveScaleB, batchSize, inputSize,
                                        numUnits,
                                        /*outputZeroPoint=*/0, forgetGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToForgetEffectiveBias.get(), recurrentToForgetWeightsBuffer,
            recurrentToForgetEffectiveScaleA, recurrentToForgetEffectiveScaleB, batchSize,
            outputSize, numUnits,
            /*outputZeroPoint=*/0, forgetGateBuffer.data());
    if (cellToForgetBuffer != nullptr) {
        VectorBatchVectorCwiseProductAccumulate(
                cellToForgetBuffer, outputSize, cellStateBuffer, batchSize,
                cellToForgetEffectiveScaleA, cellToForgetEffectiveScaleB, forgetGateBuffer.data());
    }
    if (forgetLayerNormBuffer != nullptr) {
        ApplyLayerNorm(forgetGateBuffer.data(), forgetLayerNormBuffer, forgetBiasBuffer,
                       forgetLayerNormScaleA, forgetLayerNormScaleB, forgetInvLargeValue, batchSize,
                       numUnits, forgetGateBuffer.data());
    }
    ApplySigmoid(forgetGateBuffer.data(), batchSize, numUnits, forgetGateBuffer.data());

    // Modulation gate.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToCellEffectiveBias.get(),
                                        inputToCellWeightsBuffer, inputToCellEffectiveScaleA,
                                        inputToCellEffectiveScaleB, batchSize, inputSize, numUnits,
                                        /*outputZeroPoint=*/0, cellGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToCellEffectiveBias.get(), recurrentToCellWeightsBuffer,
            recurrentToCellEffectiveScaleA, recurrentToCellEffectiveScaleB, batchSize, outputSize,
            numUnits,
            /*outputZeroPoint=*/0, cellGateBuffer.data());
    if (cellLayerNormBuffer != nullptr) {
        ApplyLayerNorm(cellGateBuffer.data(), cellLayerNormBuffer, cellBiasBuffer,
                       cellLayerNormScaleA, cellLayerNormScaleB, cellInvLargeValue, batchSize,
                       numUnits, cellGateBuffer.data());
    }
    ApplyTanh<3>(cellGateBuffer.data(), batchSize, numUnits, cellGateBuffer.data());

    // Input gate.
    if (useCifg) {
        Sub1Vector(forgetGateBuffer.data(), batchSize * numUnits, inputGateBuffer.data());
    } else {
        MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToInputEffectiveBias.get(),
                                            inputToInputWeightsBuffer, inputToInputEffectiveScaleA,
                                            inputToInputEffectiveScaleB, batchSize, inputSize,
                                            numUnits,
                                            /*outputZeroPoint=*/0, inputGateBuffer.data());
        MatrixBatchVectorMultiplyAccumulate(
                prevOutputBuffer, recurrentToInputEffectiveBias.get(),
                recurrentToInputWeightsBuffer, recurrentToInputEffectiveScaleA,
                recurrentToInputEffectiveScaleB, batchSize, outputSize, numUnits,
                /*outputZeroPoint=*/0, inputGateBuffer.data());
        if (cellToInputBuffer != nullptr) {
            VectorBatchVectorCwiseProductAccumulate(
                    cellToInputBuffer, outputSize, cellStateBuffer, batchSize,
                    cellToInputEffectiveScaleA, cellToInputEffectiveScaleB, inputGateBuffer.data());
        }
        if (inputLayerNormBuffer != nullptr) {
            ApplyLayerNorm(inputGateBuffer.data(), inputLayerNormBuffer, inputBiasBuffer,
                           inputLayerNormScaleA, inputLayerNormScaleB, inputInvLargeValue,
                           batchSize, numUnits, inputGateBuffer.data());
        }
        ApplySigmoid(inputGateBuffer.data(), batchSize, numUnits, inputGateBuffer.data());
    }

    // Cell.
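    // newCell = forgetGate * prevCell + inputGate * cellGate. The gates are in
    // Q0.15 and the cell state uses scale 2^cellShift, so the first product is
    // shifted by 15 and the second (a Q0.30 product of two Q0.15 values) by
    // 30 + cellShift to express both terms in the cell state's format.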
    CwiseMul(forgetGateBuffer.data(), prevCellStateBuffer, batchSize, numUnits,
             /*shift=*/15, forgetGateBuffer.data());
    CwiseMul(inputGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, 30 + cellShift,
             cellGateBuffer.data());
    CwiseAdd(forgetGateBuffer.data(), cellGateBuffer.data(), batchSize, numUnits, cellStateBuffer);
    if (quantizedCellClip > 0) {
        CwiseClipping(cellStateBuffer, quantizedCellClip, batchSize, numUnits);
    }

    // Output gate.
    MatrixBatchVectorMultiplyAccumulate(inputBuffer, inputToOutputEffectiveBias.get(),
                                        inputToOutputWeightsBuffer, inputToOutputEffectiveScaleA,
                                        inputToOutputEffectiveScaleB, batchSize, inputSize,
                                        numUnits,
                                        /*outputZeroPoint=*/0, outputGateBuffer.data());
    MatrixBatchVectorMultiplyAccumulate(
            prevOutputBuffer, recurrentToOutputEffectiveBias.get(), recurrentToOutputWeightsBuffer,
            recurrentToOutputEffectiveScaleA, recurrentToOutputEffectiveScaleB, batchSize,
            outputSize, numUnits,
            /*outputZeroPoint=*/0, outputGateBuffer.data());
    if (cellToOutputBuffer != nullptr) {
        VectorBatchVectorCwiseProductAccumulate(
                cellToOutputBuffer, outputSize, cellStateBuffer, batchSize,
                cellToOutputEffectiveScaleA, cellToOutputEffectiveScaleB, outputGateBuffer.data());
    }
    if (outputLayerNormBuffer != nullptr) {
        ApplyLayerNorm(outputGateBuffer.data(), outputLayerNormBuffer, outputBiasBuffer,
                       outputLayerNormScaleA, outputLayerNormScaleB, outputInvLargeValue, batchSize,
                       numUnits, outputGateBuffer.data());
    }
    ApplySigmoid(outputGateBuffer.data(), batchSize, numUnits, outputGateBuffer.data());

    // Hidden.
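    // hidden = outputGate * tanh(newCell); cellShift + 15 is the number of
    // integer bits of the cell state viewed as 16-bit fixed point, and
    // inputGateBuffer is reused here as scratch for the tanh result.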
    ApplyTanh(cellShift + 15, cellStateBuffer, batchSize, numUnits, inputGateBuffer.data());
    CwiseMul(outputGateBuffer.data(), inputGateBuffer.data(), hiddenStateEffectiveScaleA,
             hiddenStateEffectiveScaleB, batchSize, numUnits, hiddenStateZeroPoint, buffer8.data());

    // Projection.
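    // With projection weights, the hidden state (buffer8, of width numUnits)
    // is projected down to outputSize; otherwise it is copied through as the
    // output, which assumes outputSize equals numUnits.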
    if (projectionWeightsBuffer != nullptr) {
        memset(outputBuffer, 0, batchSize * outputSize * sizeof(int8_t));
        MatrixBatchVectorMultiplyAccumulate(buffer8.data(), projectionEffectiveBias.get(),
                                            projectionWeightsBuffer, projectionEffectiveScaleA,
                                            projectionEffectiveScaleB, batchSize, numUnits,
                                            outputSize, prevOutputShape.offset, outputBuffer);
        if (quantizedProjectionClip > 0) {
            CwiseClipping(outputBuffer, quantizedProjectionClip, batchSize, outputSize);
        }
    } else {
        std::copy_n(buffer8.data(), batchSize * outputSize, outputBuffer);
    }

    // Copy output to output state out.
    for (unsigned int i = 0; i < batchSize * outputSize; ++i) {
        outputStateBuffer[i] = outputBuffer[i];
    }

    return true;
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace qlstm

NN_REGISTER_OPERATION(QUANTIZED_LSTM, "QUANTIZED_LSTM", qlstm::validate, qlstm::prepare,
                      qlstm::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android