/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <iostream>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "QuantizedLSTM.h"

namespace android {
namespace nn {
namespace wrapper {

namespace {

struct OperandTypeParams {
    Type type;
    std::vector<uint32_t> shape;
    float scale;
    int32_t zeroPoint;

    OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
        : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
};

}  // namespace

using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::FloatNear;
using ::testing::Matcher;

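// Test harness: builds an NNAPI model consisting of a single
// ANEURALNETWORKS_QUANTIZED_16BIT_LSTM operation, with every operand exposed
// as a model input or output so that each test step can feed fresh data and
// read back the cell state and output.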
class QuantizedLSTMOpModel {
   public:
    QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
        std::vector<uint32_t> inputs;

        for (int i = 0; i < NUM_INPUTS; ++i) {
            const auto& curOTP = inputOperandTypeParams[i];
            OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
            inputs.push_back(model_.addOperand(&curType));
        }

        // The input operand has shape {numBatches, inputSize}.
        const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
        inputSize_ = inputOperandTypeParams[0].shape[1];
        const uint32_t outputSize =
                inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
        outputSize_ = outputSize;

        std::vector<uint32_t> outputs;
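        // Matching the scales the QUANTIZED_16BIT_LSTM operation requires: the
        // cell state output is TENSOR_QUANT16_SYMM with scale 2^-11 = 1/2048
        // (zeroPoint 0) and the output is TENSOR_QUANT8_ASYMM with scale 1/128
        // and zeroPoint 128, so a quantized value q represents
        // scale * (q - zeroPoint).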
        OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
                                            1. / 2048., 0);
        outputs.push_back(model_.addOperand(&cellStateOutOperandType));
        OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
                                      1. / 128., 128);
        outputs.push_back(model_.addOperand(&outputOperandType));

        model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);
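        // Every operand above is bound as a model input or output, which lets
        // invoke() rebind fresh data on each timestep of the sequence.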

        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
                            &prevOutput_);
        initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
                            &prevCellState_);

        cellStateOut_.resize(numBatches * outputSize, 0);
        output_.resize(numBatches * outputSize, 0);

        model_.finish();
    }

    void invoke() {
        ASSERT_TRUE(model_.isValid());

        Compilation compilation(&model_);
        ASSERT_EQ(compilation.finish(), Result::NO_ERROR);
        Execution execution(&compilation);

        // Set all the inputs.
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
                                 inputToInputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
                                 inputToForgetWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
                                 inputToCellWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
                                 inputToOutputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
                                 recurrentToInputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
                                 recurrentToForgetWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
                                 recurrentToCellWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
                                 recurrentToOutputWeights_),
                  Result::NO_ERROR);
        ASSERT_EQ(
                setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
                Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
                                 forgetGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
                                 outputGateBias_),
                  Result::NO_ERROR);
        ASSERT_EQ(
                setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
                Result::NO_ERROR);
        ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
                  Result::NO_ERROR);
        // Set all the outputs.
        ASSERT_EQ(
                setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
                Result::NO_ERROR);
        ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);

        // Put state outputs into inputs for the next step.
        prevOutput_ = output_;
        prevCellState_ = cellStateOut_;
    }

    int inputSize() { return inputSize_; }

    int outputSize() { return outputSize_; }

    void setInput(const std::vector<uint8_t>& input) { input_ = input; }

    void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
                             std::vector<uint8_t> inputToForgetWeights,
                             std::vector<uint8_t> inputToCellWeights,
                             std::vector<uint8_t> inputToOutputWeights,
                             std::vector<uint8_t> recurrentToInputWeights,
                             std::vector<uint8_t> recurrentToForgetWeights,
                             std::vector<uint8_t> recurrentToCellWeights,
                             std::vector<uint8_t> recurrentToOutputWeights,
                             std::vector<int32_t> inputGateBias,
                             std::vector<int32_t> forgetGateBias,
                             std::vector<int32_t> cellGateBias,
                             std::vector<int32_t> outputGateBias) {
        inputToInputWeights_ = inputToInputWeights;
        inputToForgetWeights_ = inputToForgetWeights;
        inputToCellWeights_ = inputToCellWeights;
        inputToOutputWeights_ = inputToOutputWeights;
        recurrentToInputWeights_ = recurrentToInputWeights;
        recurrentToForgetWeights_ = recurrentToForgetWeights;
        recurrentToCellWeights_ = recurrentToCellWeights;
        recurrentToOutputWeights_ = recurrentToOutputWeights;
        inputGateBias_ = inputGateBias;
        forgetGateBias_ = forgetGateBias;
        cellGateBias_ = cellGateBias;
        outputGateBias_ = outputGateBias;
    }

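    // Fills *vec with the operand's zeroPoint, i.e. the quantized encoding of
    // the real value 0, so the first invoke() starts from an all-zero state.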
    template <typename T>
    void initializeInputData(OperandTypeParams params, std::vector<T>* vec) {
        int size = 1;
        for (int d : params.shape) {
            size *= d;
        }
        vec->clear();
        vec->resize(size, params.zeroPoint);
    }

    std::vector<uint8_t> getOutput() { return output_; }

   private:
    static constexpr int NUM_INPUTS = 15;
    static constexpr int NUM_OUTPUTS = 2;

    Model model_;
    // Inputs
    std::vector<uint8_t> input_;
    std::vector<uint8_t> inputToInputWeights_;
    std::vector<uint8_t> inputToForgetWeights_;
    std::vector<uint8_t> inputToCellWeights_;
    std::vector<uint8_t> inputToOutputWeights_;
    std::vector<uint8_t> recurrentToInputWeights_;
    std::vector<uint8_t> recurrentToForgetWeights_;
    std::vector<uint8_t> recurrentToCellWeights_;
    std::vector<uint8_t> recurrentToOutputWeights_;
    std::vector<int32_t> inputGateBias_;
    std::vector<int32_t> forgetGateBias_;
    std::vector<int32_t> cellGateBias_;
    std::vector<int32_t> outputGateBias_;
    std::vector<int16_t> prevCellState_;
    std::vector<uint8_t> prevOutput_;
    // Outputs
    std::vector<int16_t> cellStateOut_;
    std::vector<uint8_t> output_;

    int inputSize_;
    int outputSize_;

    template <typename T>
    Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
        return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
    }
    template <typename T>
    Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
        return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
    }
};
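
// For reference when reading the goldens below: a quantized value q represents
// scale * (q - zeroPoint). For the output operand (scale 1/128, zeroPoint 128),
// e.g. q = 136 decodes to (136 - 128) / 128 = 0.0625.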

class QuantizedLstmTest : public ::testing::Test {
   protected:
    void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
                       const std::vector<std::vector<uint8_t>>& output,
                       QuantizedLSTMOpModel* lstm) {
        const int numBatches = input.size();
        EXPECT_GT(numBatches, 0);
        const int inputSize = lstm->inputSize();
        EXPECT_GT(inputSize, 0);
        const int inputSequenceSize = input[0].size() / inputSize;
        EXPECT_GT(inputSequenceSize, 0);
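        // For each timestep, gather the i-th slice of every batch into one
        // batch-major {numBatches, inputSize} tensor and run a single LSTM step.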
        for (int i = 0; i < inputSequenceSize; ++i) {
            std::vector<uint8_t> inputStep;
            for (int b = 0; b < numBatches; ++b) {
                const uint8_t* batchStart = input[b].data() + i * inputSize;
                const uint8_t* batchEnd = batchStart + inputSize;
                inputStep.insert(inputStep.end(), batchStart, batchEnd);
            }
            lstm->setInput(inputStep);
            lstm->invoke();

            const int outputSize = lstm->outputSize();
            std::vector<float> expected;
            for (int b = 0; b < numBatches; ++b) {
                const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
                const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
                expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
            }
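            // Quantized arithmetic is deterministic, so the test expects the
            // outputs to match the goldens exactly, not just within a tolerance.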
            EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
        }
    }
};

// Inputs and weights in this test are random; the test only checks that the
// outputs match the outputs obtained by running the TF Lite version of the
// quantized LSTM on the same inputs.
TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
    const int numBatches = 2;
    const int inputSize = 2;
    const int outputSize = 4;

    float weightsScale = 0.00408021;
    int weightsZeroPoint = 100;
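    // With asymmetric quantization, a stored weight byte w encodes the real
    // value weightsScale * (w - weightsZeroPoint), e.g. 146 encodes
    // 0.00408021 * 46 ~= 0.188.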
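    // The operand scales below follow the QUANTIZED_16BIT_LSTM conventions:
    // the bias scale is inputScale * weightsScale = weightsScale / 128 with
    // zeroPoint 0, the cell state uses TENSOR_QUANT16_SYMM with scale 1/2048,
    // and the 8-bit activations use scale 1/128 with zeroPoint 128.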
    QuantizedLSTMOpModel lstm({
            // input
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
            // inputToInputWeights
            // inputToForgetWeights
            // inputToCellWeights
            // inputToOutputWeights
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
                              weightsZeroPoint),
            // recurrentToInputWeights
            // recurrentToForgetWeights
            // recurrentToCellWeights
            // recurrentToOutputWeights
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
                              weightsZeroPoint),
            // inputGateBias
            // forgetGateBias
            // cellGateBias
            // outputGateBias
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
            // prevCellState
            OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
            // prevOutput
            OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
    });

    lstm.setWeightsAndBiases(
            // inputToInputWeights
            {146, 250, 235, 171, 10, 218, 171, 108},
            // inputToForgetWeights
            {24, 50, 132, 179, 158, 110, 3, 169},
            // inputToCellWeights
            {133, 34, 29, 49, 206, 109, 54, 183},
            // inputToOutputWeights
            {195, 187, 11, 99, 109, 10, 218, 48},
            // recurrentToInputWeights
            {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
            // recurrentToForgetWeights
            {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
            // recurrentToCellWeights
            {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
            // recurrentToOutputWeights
            {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
            // inputGateBias
            {-7876, 13488, -726, 32839},
            // forgetGateBias
            {9206, -46884, -11693, -38724},
            // cellGateBias
            {39481, 48624, 48976, -21419},
            // outputGateBias
            {-58999, -17050, -41852, -40538});

    // LSTM input is stored as a numBatches x (sequenceLength x inputSize) vector.
    std::vector<std::vector<uint8_t>> lstmInput;
    // clang-format off
    lstmInput = {{154, 166,
                  166, 179,
                  141, 141},
                 {100, 200,
                   50, 150,
                  111, 222}};
    // clang-format on
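    // For example, batch 0 feeds the three timesteps {154, 166}, {166, 179},
    // {141, 141}; each invoke() consumes one timestep from every batch.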

    // LSTM output is stored as a numBatches x (sequenceLength x outputSize) vector.
    std::vector<std::vector<uint8_t>> lstmGoldenOutput;
    // clang-format off
    lstmGoldenOutput = {{136, 150, 140, 115,
                         140, 151, 146, 112,
                         139, 153, 146, 114},
                        {135, 152, 138, 112,
                         136, 156, 142, 112,
                         141, 154, 146, 108}};
    // clang-format on
    VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android