/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <mutex>
#include <vector>

#include "OperationResolver.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
#include <tensorflow/lite/kernels/internal/types.h>

#include "CpuOperationUtils.h"
#endif // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace fully_connected {

constexpr char kOperationName[] = "FULLY_CONNECTED";

constexpr uint32_t kNumInputs = 4;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kBiasTensor = 2;
constexpr uint32_t kActivationScalar = 3;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
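
// Tensor shape conventions used throughout this file (see validateShapes() below):
//   input:   rank 2 to 4, treated as [batch_size, input_size] with
//            batch_size = getNumberOfElements(input) / input_size
//   weights: [num_units, input_size]
//   bias:    [num_units]
//   output:  [batch_size, num_units]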

namespace {

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// executionMutex is used to protect concurrent access to non-threadsafe resources
// like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;

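// Float32 kernel: computes output = activation(input * transpose(weights) + bias)
// using the TFLite kernels. The reference implementation is used for the
// b/80425683 corner case below; the optimized implementation is used otherwise.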
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape, int32_t activation,
                           float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // b/80425683: the optimized implementation produces incorrect results when the
    // number of input elements is the square of batch_size.
    uint32_t batch_size = getSizeOfDimension(outputShape, 0);
    uint32_t input_n_elements = getNumberOfElements(inputShape);
    if (batch_size * batch_size == input_n_elements) {
        NNTRACE_COMP_SWITCH("reference_ops::FullyConnected");
        tflite::reference_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
        tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    }
    return true;
}

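// Float16 kernel: converts the operands to float32, runs the float32 kernel, and
// converts the result back to float16.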
bool fullyConnectedFloat16(const _Float16* inputData, const Shape& inputShape,
                           const _Float16* weightsData, const Shape& weightsShape,
                           const _Float16* biasData, const Shape& biasShape, int32_t activation,
                           _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    std::vector<float> weightsDataFloat32(getNumberOfElements(weightsShape));
    convertFloat16ToFloat32(weightsData, &weightsDataFloat32);
    std::vector<float> biasDataFloat32(getNumberOfElements(biasShape));
    convertFloat16ToFloat32(biasData, &biasDataFloat32);

    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    fullyConnectedFloat32(inputDataFloat32.data(), inputShape, weightsDataFloat32.data(),
                          weightsShape, biasDataFloat32.data(), biasShape, activation,
                          outputDataFloat32.data(), outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);

    return true;
}

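// Asymmetric uint8 kernel: requantizes the int32 accumulator using the
// (outputMultiplier, outputShift) pair derived from the input, weights, and output
// scales, then clamps to the activation range. The gemmlowp-backed optimized kernel
// is not thread-safe, so executions are serialized with executionMutex.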
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8");
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape, outputShape,
                                                  &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    outputShift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &outputActivationMin,
                                  &outputActivationMax);

    static gemmlowp::GemmContext gemmContext;

    // Prevent concurrent executions that access gemmContext.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to automatically decide how many threads to use.
    gemmContext.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
    tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape), inputOffset,
                                          weightsData, convertShapeToDims(weightsShape),
                                          weightsOffset, biasData, convertShapeToDims(biasShape),
                                          outputOffset, outputMultiplier, outputShift,
                                          outputActivationMin, outputActivationMax, outputData,
                                          convertShapeToDims(outputShape), &gemmContext);

    return true;
}

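// Asymmetric int8 kernel (TENSOR_QUANT8_ASYMM_SIGNED): uses the TFLite reference
// integer kernel with per-tensor requantization parameters.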
bool fullyConnectedQuant8(const int8_t* inputData, const Shape& inputShape,
                          const int8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          int8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8Signed");

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape, outputShape,
                                                  &realMultiplier));
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &outputShift));
    CalculateActivationRangeInt8(activation, outputShape, &outputActivationMin,
                                 &outputActivationMax);

    tflite::FullyConnectedParams params;
    params.input_offset = -inputShape.offset;
    params.weights_offset = -weightsShape.offset;
    params.output_offset = outputShape.offset;
    params.output_multiplier = outputMultiplier;
    params.output_shift = outputShift;
    params.quantized_activation_min = outputActivationMin;
    params.quantized_activation_max = outputActivationMax;

    NNTRACE_COMP_SWITCH("reference_integer_ops::FullyConnected");
    tflite::reference_integer_ops::FullyConnected(
            params, convertShapeToTflshape(inputShape), inputData,
            convertShapeToTflshape(weightsShape), weightsData, convertShapeToTflshape(biasShape),
            biasData, convertShapeToTflshape(outputShape), outputData);

    return true;
}
#endif // NN_INCLUDE_CPU_IMPLEMENTATION

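// Checks that the input, weights, and bias shapes are mutually consistent. If
// `output` is non-null, also computes the output shape as [batch_size, num_units];
// in that case num_units and input_size must be known (only batch_size may be zero).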
bool validateShapes(const Shape& input, const Shape& weights, const Shape& bias,
                    Shape* output = nullptr) {
    // Check that all tensor parameters are consistent with each other and with
    // the input configuration.
    NN_RET_CHECK(weights.type == input.type);
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(bias.type == input.type);
    }
    // The TensorFlow fully connected layer specification says that the input should
    // be of at least rank 2, so we check that here. TFLite does not check it.
    NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
    NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
    uint32_t input_n_elements = getNumberOfElements(input);
    uint32_t num_units = getSizeOfDimension(weights, 0);
    uint32_t input_size = getSizeOfDimension(weights, 1);
    uint32_t bias_len = getSizeOfDimension(bias, 0);
    uint32_t batch_size = input_size == 0 ? 0 : input_n_elements / input_size;
    if (batch_size != 0) {
        NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements);
    }
    if (num_units != 0 && bias_len != 0) {
        NN_RET_CHECK_EQ(bias_len, num_units);
    }
    if (output != nullptr) {
        // Only batch_size can be 0.
        NN_RET_CHECK_GT(num_units, 0);
        NN_RET_CHECK_GT(input_size, 0);
        output->type = input.type;
        output->dimensions = {batch_size, num_units};
    }
    return true;
}

} // namespace

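// Validates the operand types of a FULLY_CONNECTED operation and returns the
// minimum NNAPI version that supports the given operand configuration.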
Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = Version::ANDROID_OC_MR1;
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        minSupportedVersion = Version::ANDROID_Q;
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // NeuralNetworks.h specifies that ANEURALNETWORKS_FULLY_CONNECTED's output must
        // meet "outputScale > inputScale * weightsScale" for the operand type
        // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29.
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float weightsScale = context->getInputShape(kWeightsTensor).scale;
        const float outputScale = context->getOutputShape(kOutputTensor).scale;
        bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);

        if (!meetsQuantizedScaleConstraintBeforeV1_2) {
            minSupportedVersion = Version::ANDROID_Q;
        } else {
            minSupportedVersion = Version::ANDROID_OC_MR1;
        }

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minSupportedVersion = Version::ANDROID_R;

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
                OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));

    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    if (hasKnownRank(input) && hasKnownRank(weights) && hasKnownRank(bias)) {
        NN_RET_CHECK(validateShapes(input, weights, bias));
    }

    return minSupportedVersion;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(validateShapes(input, weights, bias, &output));
    return context->setOutputShape(kOutputTensor, output);
}

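// Dispatches execution to the kernel matching the input operand type. A zero-sized
// output is treated as a no-op and reported as success.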
bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return fullyConnectedFloat32(context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<float>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<float>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<float>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return fullyConnectedFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<_Float16>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<_Float16>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<_Float16>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return fullyConnectedQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<uint8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return fullyConnectedQuant8(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<int8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
#endif // NN_INCLUDE_CPU_IMPLEMENTATION

} // namespace fully_connected

NN_REGISTER_OPERATION(FULLY_CONNECTED, fully_connected::kOperationName, fully_connected::validate,
                      fully_connected::prepare, fully_connected::execute,
                      .allowZeroSizedInput = true);

} // namespace nn
} // namespace android