1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "QuantizedLSTM.h"
25 
26 namespace android {
27 namespace nn {
28 namespace wrapper {
29 
30 namespace {
31 
32 struct OperandTypeParams {
33     Type type;
34     std::vector<uint32_t> shape;
35     float scale;
36     int32_t zeroPoint;
37 
OperandTypeParamsandroid::nn::wrapper::__anon48a6dcc50110::OperandTypeParams38     OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
39         : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
40 };
41 
42 }  // namespace
43 
44 using ::testing::Each;
45 using ::testing::ElementsAreArray;
46 using ::testing::FloatNear;
47 using ::testing::Matcher;
48 
49 class QuantizedLSTMOpModel {
50    public:
QuantizedLSTMOpModel(const std::vector<OperandTypeParams> & inputOperandTypeParams)51     QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
52         std::vector<uint32_t> inputs;
53 
54         for (int i = 0; i < NUM_INPUTS; ++i) {
55             const auto& curOTP = inputOperandTypeParams[i];
56             OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
57             inputs.push_back(model_.addOperand(&curType));
58         }
59 
60         const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
61         inputSize_ = inputOperandTypeParams[0].shape[0];
62         const uint32_t outputSize =
63                 inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
64         outputSize_ = outputSize;
65 
66         std::vector<uint32_t> outputs;
67         OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
68                                             1. / 2048., 0);
69         outputs.push_back(model_.addOperand(&cellStateOutOperandType));
70         OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
71                                       1. / 128., 128);
72         outputs.push_back(model_.addOperand(&outputOperandType));
73 
74         model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
75         model_.identifyInputsAndOutputs(inputs, outputs);
76 
77         initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
78         initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
79                             &prevOutput_);
80         initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
81                             &prevCellState_);
82 
83         cellStateOut_.resize(numBatches * outputSize, 0);
84         output_.resize(numBatches * outputSize, 0);
85 
86         model_.finish();
87     }
88 
invoke()89     void invoke() {
90         ASSERT_TRUE(model_.isValid());
91 
92         Compilation compilation(&model_);
93         compilation.finish();
94         Execution execution(&compilation);
95 
96         // Set all the inputs.
97         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
98                   Result::NO_ERROR);
99         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
100                                  inputToInputWeights_),
101                   Result::NO_ERROR);
102         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
103                                  inputToForgetWeights_),
104                   Result::NO_ERROR);
105         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
106                                  inputToCellWeights_),
107                   Result::NO_ERROR);
108         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
109                                  inputToOutputWeights_),
110                   Result::NO_ERROR);
111         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
112                                  recurrentToInputWeights_),
113                   Result::NO_ERROR);
114         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
115                                  recurrentToForgetWeights_),
116                   Result::NO_ERROR);
117         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
118                                  recurrentToCellWeights_),
119                   Result::NO_ERROR);
120         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
121                                  recurrentToOutputWeights_),
122                   Result::NO_ERROR);
123         ASSERT_EQ(
124                 setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
125                 Result::NO_ERROR);
126         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
127                                  forgetGateBias_),
128                   Result::NO_ERROR);
129         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
130                   Result::NO_ERROR);
131         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
132                                  outputGateBias_),
133                   Result::NO_ERROR);
134         ASSERT_EQ(
135                 setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
136                 Result::NO_ERROR);
137         ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
138                   Result::NO_ERROR);
139         // Set all the outputs.
140         ASSERT_EQ(
141                 setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
142                 Result::NO_ERROR);
143         ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
144                   Result::NO_ERROR);
145 
146         ASSERT_EQ(execution.compute(), Result::NO_ERROR);
147 
148         // Put state outputs into inputs for the next step
149         prevOutput_ = output_;
150         prevCellState_ = cellStateOut_;
151     }
152 
inputSize()153     int inputSize() { return inputSize_; }
154 
outputSize()155     int outputSize() { return outputSize_; }
156 
setInput(const std::vector<uint8_t> & input)157     void setInput(const std::vector<uint8_t>& input) { input_ = input; }
158 
setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,std::vector<uint8_t> inputToForgetWeights,std::vector<uint8_t> inputToCellWeights,std::vector<uint8_t> inputToOutputWeights,std::vector<uint8_t> recurrentToInputWeights,std::vector<uint8_t> recurrentToForgetWeights,std::vector<uint8_t> recurrentToCellWeights,std::vector<uint8_t> recurrentToOutputWeights,std::vector<int32_t> inputGateBias,std::vector<int32_t> forgetGateBias,std::vector<int32_t> cellGateBias,std::vector<int32_t> outputGateBias)159     void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
160                              std::vector<uint8_t> inputToForgetWeights,
161                              std::vector<uint8_t> inputToCellWeights,
162                              std::vector<uint8_t> inputToOutputWeights,
163                              std::vector<uint8_t> recurrentToInputWeights,
164                              std::vector<uint8_t> recurrentToForgetWeights,
165                              std::vector<uint8_t> recurrentToCellWeights,
166                              std::vector<uint8_t> recurrentToOutputWeights,
167                              std::vector<int32_t> inputGateBias,
168                              std::vector<int32_t> forgetGateBias,
169                              std::vector<int32_t> cellGateBias,  //
170                              std::vector<int32_t> outputGateBias) {
171         inputToInputWeights_ = inputToInputWeights;
172         inputToForgetWeights_ = inputToForgetWeights;
173         inputToCellWeights_ = inputToCellWeights;
174         inputToOutputWeights_ = inputToOutputWeights;
175         recurrentToInputWeights_ = recurrentToInputWeights;
176         recurrentToForgetWeights_ = recurrentToForgetWeights;
177         recurrentToCellWeights_ = recurrentToCellWeights;
178         recurrentToOutputWeights_ = recurrentToOutputWeights;
179         inputGateBias_ = inputGateBias;
180         forgetGateBias_ = forgetGateBias;
181         cellGateBias_ = cellGateBias;
182         outputGateBias_ = outputGateBias;
183     }
184 
185     template <typename T>
initializeInputData(OperandTypeParams params,std::vector<T> * vec)186     void initializeInputData(OperandTypeParams params, std::vector<T>* vec) {
187         int size = 1;
188         for (int d : params.shape) {
189             size *= d;
190         }
191         vec->clear();
192         vec->resize(size, params.zeroPoint);
193     }
194 
getOutput()195     std::vector<uint8_t> getOutput() { return output_; }
196 
197    private:
198     static constexpr int NUM_INPUTS = 15;
199     static constexpr int NUM_OUTPUTS = 2;
200 
201     Model model_;
202     // Inputs
203     std::vector<uint8_t> input_;
204     std::vector<uint8_t> inputToInputWeights_;
205     std::vector<uint8_t> inputToForgetWeights_;
206     std::vector<uint8_t> inputToCellWeights_;
207     std::vector<uint8_t> inputToOutputWeights_;
208     std::vector<uint8_t> recurrentToInputWeights_;
209     std::vector<uint8_t> recurrentToForgetWeights_;
210     std::vector<uint8_t> recurrentToCellWeights_;
211     std::vector<uint8_t> recurrentToOutputWeights_;
212     std::vector<int32_t> inputGateBias_;
213     std::vector<int32_t> forgetGateBias_;
214     std::vector<int32_t> cellGateBias_;
215     std::vector<int32_t> outputGateBias_;
216     std::vector<int16_t> prevCellState_;
217     std::vector<uint8_t> prevOutput_;
218     // Outputs
219     std::vector<int16_t> cellStateOut_;
220     std::vector<uint8_t> output_;
221 
222     int inputSize_;
223     int outputSize_;
224 
225     template <typename T>
setInputTensor(Execution * execution,int tensor,const std::vector<T> & data)226     Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
227         return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
228     }
229     template <typename T>
setOutputTensor(Execution * execution,int tensor,std::vector<T> * data)230     Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
231         return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
232     }
233 };
234 
235 class QuantizedLstmTest : public ::testing::Test {
236    protected:
VerifyGoldens(const std::vector<std::vector<uint8_t>> & input,const std::vector<std::vector<uint8_t>> & output,QuantizedLSTMOpModel * lstm)237     void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
238                        const std::vector<std::vector<uint8_t>>& output,
239                        QuantizedLSTMOpModel* lstm) {
240         const int numBatches = input.size();
241         EXPECT_GT(numBatches, 0);
242         const int inputSize = lstm->inputSize();
243         EXPECT_GT(inputSize, 0);
244         const int inputSequenceSize = input[0].size() / inputSize;
245         EXPECT_GT(inputSequenceSize, 0);
246         for (int i = 0; i < inputSequenceSize; ++i) {
247             std::vector<uint8_t> inputStep;
248             for (int b = 0; b < numBatches; ++b) {
249                 const uint8_t* batchStart = input[b].data() + i * inputSize;
250                 const uint8_t* batchEnd = batchStart + inputSize;
251                 inputStep.insert(inputStep.end(), batchStart, batchEnd);
252             }
253             lstm->setInput(inputStep);
254             lstm->invoke();
255 
256             const int outputSize = lstm->outputSize();
257             std::vector<float> expected;
258             for (int b = 0; b < numBatches; ++b) {
259                 const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
260                 const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
261                 expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
262             }
263             EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
264         }
265     }
266 };
267 
268 // Inputs and weights in this test are random and the test only checks that the
269 // outputs are equal to outputs obtained from running TF Lite version of
270 // quantized LSTM on the same inputs.
TEST_F(QuantizedLstmTest,BasicQuantizedLstmTest)271 TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
272     const int numBatches = 2;
273     const int inputSize = 2;
274     const int outputSize = 4;
275 
276     float weightsScale = 0.00408021;
277     int weightsZeroPoint = 100;
278     // OperandType biasOperandType(Type::TENSOR_INT32, input_shapes[3],
279     // weightsScale / 128., 0);
280     // inputs.push_back(model_.addOperand(&biasOperandType));
281     // OperandType prevCellStateOperandType(Type::TENSOR_QUANT16_SYMM, input_shapes[4],
282     // 1. / 2048., 0);
283     // inputs.push_back(model_.addOperand(&prevCellStateOperandType));
284 
285     QuantizedLSTMOpModel lstm({
286             // input
287             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
288             // inputToInputWeights
289             // inputToForgetWeights
290             // inputToCellWeights
291             // inputToOutputWeights
292             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
293                               weightsZeroPoint),
294             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
295                               weightsZeroPoint),
296             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
297                               weightsZeroPoint),
298             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
299                               weightsZeroPoint),
300             // recurrentToInputWeights
301             // recurrentToForgetWeights
302             // recurrentToCellWeights
303             // recurrentToOutputWeights
304             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
305                               weightsZeroPoint),
306             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
307                               weightsZeroPoint),
308             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
309                               weightsZeroPoint),
310             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
311                               weightsZeroPoint),
312             // inputGateBias
313             // forgetGateBias
314             // cellGateBias
315             // outputGateBias
316             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
317             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
318             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
319             OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
320             // prevCellState
321             OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
322             // prevOutput
323             OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
324     });
325 
326     lstm.setWeightsAndBiases(
327             // inputToInputWeights
328             {146, 250, 235, 171, 10, 218, 171, 108},
329             // inputToForgetWeights
330             {24, 50, 132, 179, 158, 110, 3, 169},
331             // inputToCellWeights
332             {133, 34, 29, 49, 206, 109, 54, 183},
333             // inputToOutputWeights
334             {195, 187, 11, 99, 109, 10, 218, 48},
335             // recurrentToInputWeights
336             {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
337             // recurrentToForgetWeights
338             {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
339             // recurrentToCellWeights
340             {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
341             // recurrentToOutputWeights
342             {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
343             // inputGateBias
344             {-7876, 13488, -726, 32839},
345             // forgetGateBias
346             {9206, -46884, -11693, -38724},
347             // cellGateBias
348             {39481, 48624, 48976, -21419},
349             // outputGateBias
350             {-58999, -17050, -41852, -40538});
351 
352     // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
353     std::vector<std::vector<uint8_t>> lstmInput;
354     // clang-format off
355     lstmInput = {{154, 166,
356                   166, 179,
357                   141, 141},
358                  {100, 200,
359                   50,  150,
360                   111, 222}};
361     // clang-format on
362 
363     // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
364     std::vector<std::vector<uint8_t>> lstmGoldenOutput;
365     // clang-format off
366     lstmGoldenOutput = {{136, 150, 140, 115,
367                          140, 151, 146, 112,
368                          139, 153, 146, 114},
369                         {135, 152, 138, 112,
370                          136, 156, 142, 112,
371                          141, 154, 146, 108}};
372     // clang-format on
373     VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
374 };
375 
376 }  // namespace wrapper
377 }  // namespace nn
378 }  // namespace android
379