/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include <algorithm>
#include <vector>

#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"
#include "Tracing.h"
#include "nnapi/Types.h"
#include "nnapi/Validation.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/add.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/mul.h>
#include <tensorflow/lite/kernels/internal/types.h>

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace broadcast {

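// ADD, MUL, SUB, and DIV share the same operand layout: two input tensors, a fused
// activation scalar, and a single output tensor.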
constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;
constexpr uint32_t kActivationScalar = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

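// Maps the runtime activation code to the matching compile-time
// tflite::FusedActivationFunctionType and invokes the given macro with it; an
// unsupported activation value logs an error and makes the enclosing kernel return false.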
#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
    switch (activation) {                                               \
        case static_cast<int32_t>(FusedActivationFunc::NONE):           \
            macro(kNone);                                               \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU):           \
            macro(kRelu);                                               \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU1):          \
            macro(kRelu1);                                              \
            break;                                                      \
        case static_cast<int32_t>(FusedActivationFunc::RELU6):          \
            macro(kRelu6);                                              \
            break;                                                      \
        default:                                                        \
            LOG(ERROR) << "Unsupported fused activation function type"; \
            return false;                                               \
    }

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;

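// _Float16 kernels are implemented by widening both inputs to float32, invoking the
// corresponding float32 kernel, and narrowing the result back to _Float16.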
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                 \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

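// addQuant8 prepares the parameters that TFLite's quantized Add kernels expect: both
// inputs are rescaled to the common scale 2 * max(scale1, scale2) via fixed-point
// multipliers (with a 20-bit left shift of integer headroom), summed, then rescaled to
// the output scale and zero point and clamped to the fused-activation range. In
// real-valued terms this approximates out ≈ (in1_real + in2_real) / scaleOut + zeroPointOut.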
template <typename T>
bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    const bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    int32_t output_activation_min;
    int32_t output_activation_max;
    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if (needBroadcast) {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
            tflite::reference_integer_ops::BroadcastAdd4DSlow(
                    op_params, convertShapeToTflshape(shape1), in1, convertShapeToTflshape(shape2),
                    in2, convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
            tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                      in1, convertShapeToTflshape(shape2), in2,
                                                      convertShapeToTflshape(shapeOut), out);
        }
    } else {
        if constexpr (isSignedOp) {
            NNTRACE_COMP_SWITCH("optimized_integer_ops::Add");
            tflite::optimized_integer_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                               convertShapeToTflshape(shape2), in2,
                                               convertShapeToTflshape(shapeOut), out);
        } else {
            NNTRACE_COMP_SWITCH("optimized_ops::Add");
            tflite::optimized_ops::Add(op_params, convertShapeToTflshape(shape1), in1,
                                       convertShapeToTflshape(shape2), in2,
                                       convertShapeToTflshape(shapeOut), out);
        }
    }

    return true;
}

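// executeInt32 evaluates the binary function element by element over the output index
// space, mapping each output coordinate back to (possibly broadcast) flat indices into
// both inputs. Only FusedActivationFunc::NONE is supported for TENSOR_INT32.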
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
                  const Shape& bShape, int32_t activation, int32_t* outputData,
                  const Shape& outputShape, int32_t func(int32_t, int32_t)) {
    NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE);
    IndexedShapeWrapper aShapeIndexed(aShape);
    IndexedShapeWrapper bShapeIndexed(bShape);
    IndexedShapeWrapper outputShapeIndexed(outputShape);
    std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
    bool lastIndex = false;
    do {
        uint32_t outputFlatIndex;
        NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
        uint32_t aFlatIndex;
        NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
        uint32_t bFlatIndex;
        NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));

        outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);

        NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
    } while (!lastIndex);
    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))

        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

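// For quantized multiplication the zero-point-adjusted product is rescaled by a single
// fixed-point multiplier encoding scale1 * scale2 / scaleOut, then offset by the output
// zero point and clamped to the fused-activation range.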
template <typename T>
bool mulQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.input1_offset = input1_offset;
    op_params.input2_offset = input2_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastMul4DSlow");
        tflite::reference_integer_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastMul4DSlow");
        tflite::reference_ops::BroadcastMul4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }
    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

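// Quantized subtraction reuses the quantized add rescaling scheme: negating the second
// input's fixed-point multiplier turns the Add kernels into a subtraction.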
template <typename T>
bool subQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& shape2,
               int32_t activation, T* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);

    int32_t input1_multiplier;
    int32_t input1_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, &input1_multiplier,
                                                     &input1_shift));
    int32_t input2_multiplier;
    int32_t input2_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, &input2_multiplier,
                                                     &input2_shift));
    // Negate multiplier of the second input, so that we can use Add kernels.
    input2_multiplier *= -1;

    int32_t output_multiplier;
    int32_t output_shift;
    NN_RET_CHECK(QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, &output_multiplier,
                                                     &output_shift));

    constexpr bool isSignedOp = std::is_same<T, int8_t>::value;
    int32_t output_activation_min;
    int32_t output_activation_max;
    if constexpr (isSignedOp) {
        CalculateActivationRangeInt8(activation, shapeOut, &output_activation_min,
                                     &output_activation_max);
    } else {
        CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                      &output_activation_max);
    }

    tflite::ArithmeticParams op_params;
    op_params.left_shift = left_shift;
    op_params.input1_offset = input1_offset;
    op_params.input1_multiplier = input1_multiplier;
    op_params.input1_shift = input1_shift;
    op_params.input2_offset = input2_offset;
    op_params.input2_multiplier = input2_multiplier;
    op_params.input2_shift = input2_shift;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = output_multiplier;
    op_params.output_shift = output_shift;
    tflite::SetActivationParams(output_activation_min, output_activation_max, &op_params);

    // The broadcast (4DSlow) Add kernels are used unconditionally here because the
    // non-broadcast Add kernels fail some of the sub_quantized_different_scales tests.
    if constexpr (isSignedOp) {
        NNTRACE_COMP_SWITCH("reference_integer_ops::BroadcastAdd4DSlow");
        tflite::reference_integer_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1),
                                                          in1, convertShapeToTflshape(shape2), in2,
                                                          convertShapeToTflshape(shapeOut), out);
    } else {
        NNTRACE_COMP_SWITCH("reference_ops::BroadcastAdd4DSlow");
        tflite::reference_ops::BroadcastAdd4DSlow(op_params, convertShapeToTflshape(shape1), in1,
                                                  convertShapeToTflshape(shape2), in2,
                                                  convertShapeToTflshape(shapeOut), out);
    }

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }
    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

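// ADD and MUL have been available since Android O MR1; SUB and DIV were introduced in
// Android P. The operand type can raise the minimum version further: TENSOR_FLOAT16 and
// quantized SUB require Android Q, while TENSOR_QUANT8_ASYMM_SIGNED and TENSOR_INT32
// require Android R.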
Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
    auto minSupportedVersion = (opType == OperationType::DIV || opType == OperationType::SUB)
                                       ? Version::ANDROID_P
                                       : Version::ANDROID_OC_MR1;
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor1);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::SUB) {
            minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_Q);
        } else if (opType == OperationType::DIV) {
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
        } else if (opType == OperationType::MUL) {
            Shape output = context->getOutputShape(kOutputTensor);
            Shape input1 = context->getInputShape(kInputTensor1);
            Shape input2 = context->getInputShape(kInputTensor2);
            NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
            minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
        } else {
            minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_OC_MR1);
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
               inputType == OperandType::TENSOR_INT32) {
        minSupportedVersion = combineVersions(minSupportedVersion, Version::ANDROID_R);
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType;
    }
    const Shape& input1 = context->getInputShape(kInputTensor1);
    const Shape& input2 = context->getInputShape(kInputTensor2);
    if (hasKnownRank(input1) && hasKnownRank(input2)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
        NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, OperandType::INT32}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return minSupportedVersion;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
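// prepare() derives the output shape by broadcasting the two input shapes (both of rank
// at most 4) and propagates it to the output operand.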
bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return addQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return mulQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a * b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return subQuant8(context->getInputBuffer<int8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<int8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<int8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor),
                                [](int32_t a, int32_t b) { return a - b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

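// DIV has no quantized variant (validation rejects it). For TENSOR_INT32 it implements
// floor division, and since division by zero is undefined in NNAPI, it returns 0 rather
// than crashing.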
bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return executeInt32(context->getInputBuffer<int32_t>(kInputTensor1),
                                context->getInputShape(kInputTensor1),
                                context->getInputBuffer<int32_t>(kInputTensor2),
                                context->getInputShape(kInputTensor2),
                                context->getInputValue<int32_t>(kActivationScalar),
                                context->getOutputBuffer<int32_t>(kOutputTensor),
                                context->getOutputShape(kOutputTensor), [](int32_t a, int32_t b) {
                                    // In NNAPI, DIV by zero is undefined, but should not crash.
                                    if (b == 0) return 0;
                                    int32_t result = a / b;
                                    if (a % b != 0 && ((a < 0) != (b < 0))) {
                                        // Implement "floor division".
                                        --result;
                                    }
                                    return result;
                                });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace broadcast

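// Registration binds each OperationType into the shared validate() and hooks up the
// common prepare/execute entry points; zero-sized inputs are explicitly allowed.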
using std::placeholders::_1;
NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android