/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "ActivationFunctor.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace activation {

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

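// Shared ReLU-family kernel for float types: clamps every element to
// [reluMin, reluMax]. RELU uses [0, +inf), RELU1 uses [-1, 1] and RELU6 uses
// [0, 6]; the wrappers below simply fix the bounds.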
template <typename T>
bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape,
               float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) {
    NNTRACE_COMP("reluX");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(
                std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
    }
    return true;
}
template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
                               const Shape& outputShape, float reluMin, float reluMax);
template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                  _Float16* outputData, const Shape& outputShape, float reluMin,
                                  float reluMax);

template <typename T>
bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
}
template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

template <typename T>
bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
}
template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("tanhFloat16");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
    }
    return true;
}

bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("tanhFloat32");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::tanh(*inputData);
    }
    return true;
}

template <typename T>
bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
                   const Shape& outputShape) {
    NNTRACE_COMP("logisticFloat");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
    }
    return true;
}
template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
                                   float* outputData, const Shape& outputShape);
template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                      _Float16* outputData, const Shape& outputShape);

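// Quantized ReLU variants clamp directly in the quantized domain:
// CalculateActivationRangeUint8 converts the real-valued bounds (e.g. [0, 6])
// into quantized levels using the input's scale and zero point, so no
// per-element dequantization is needed.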
template <ActivationFn activation>
inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                        const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,
                                  &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((uint8_t)output_activation_max,
                               std::max((uint8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    return reluXQuant8<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    return reluXQuant8<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    return reluXQuant8<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

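// Quantized tanh. The output quantization is fixed by the NNAPI spec
// (scale = 1/128, zeroPoint = 128 for unsigned), so the tanh range [-1, 1)
// maps onto the full uint8 range. The input is rescaled into a fixed-point
// domain with kInputIntegerBits = 4 integer bits: the real multiplier is
// scale * 2^(31 - 4), e.g. an input scale of 0.125 gives 0.125 * 2^27 = 2^24.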
bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8");
    if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
    tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
                                input_range_radius, input_multiplier, input_left_shift, outputData,
                                convertShapeToTflshape(outputShape));

    return true;
}

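// Quantized logistic (sigmoid). Output quantization is likewise fixed by the
// spec: scale = 1/256, zeroPoint = 0 for unsigned, so the sigmoid range [0, 1)
// maps onto the full uint8 range. The fixed-point setup mirrors tanhQuant8
// above.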
bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                    const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
    tflite::optimized_ops::Logistic(
            inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
            input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));

    return true;
}

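// Signed (TENSOR_QUANT8_ASYMM_SIGNED) counterparts of the kernels above. They
// are identical except for the int8 value range and the spec-mandated output
// quantization (tanh: scale 1/128, zeroPoint 0; logistic: scale 1/256,
// zeroPoint -128).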
template <ActivationFn activation>
inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                              const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeInt8(activation, inputShape, &output_activation_min,
                                 &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((int8_t)output_activation_max,
                               std::max((int8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8Signed");
    return reluXQuant8Signed<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8Signed");
    return reluXQuant8Signed<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8Signed");
    return reluXQuant8Signed<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8Signed");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh");
    tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier,
                                        input_left_shift, convertShapeToTflshape(inputShape),
                                        inputData, convertShapeToTflshape(outputShape), outputData);

    return true;
}

bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                          const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8Signed");
    if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic");
    tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier,
                                            input_left_shift, numElements, inputData, outputData);

    return true;
}

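// Rounds a non-negative Q31 fixed-point multiplier to the nearest Q15 value,
// saturating at INT16_MAX. For example, 1 << 30 (0.5 in Q31) becomes
// (2^30 + 2^15) >> 16 = 16384, i.e. 0.5 in Q15.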
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) {
    TFLITE_DCHECK_GE(multiplier_int32, 0);
    static constexpr int32_t kRoundingOffset = 1 << 15;
    if (multiplier_int32 >= std::numeric_limits<int32_t>::max() - kRoundingOffset) {
        *multiplier_int16 = std::numeric_limits<int16_t>::max();
        return;
    }
    const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
    TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = result;
    TFLITE_DCHECK_EQ(*multiplier_int16, result);
}

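// Quantized HARD_SWISH (x * relu6(x + 3) / 6). Roughly: the input is rescaled
// into a higher-resolution intermediate domain (hires_input_scale) so that the
// "reluish" clamp and the final rescale can both be carried out in 16-bit
// fixed point by TFLite's reference HardSwish; the two multipliers below are
// derived from the input and output scales and then narrowed to int16.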
template <typename T>
bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
                    const Shape& outputShape) {
    tflite::HardSwishParams params;
    params.input_zero_point = inputShape.offset;
    params.output_zero_point = outputShape.offset;
    const float input_scale = inputShape.scale;
    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
    const float reluish_scale = 3.0f / 32768.0f;
    const float output_scale = outputShape.scale;

    const float output_multiplier = hires_input_scale / output_scale;

    int32_t output_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_fixedpoint_int16);
    NN_RET_CHECK(params.output_multiplier_exponent <= 0);

    const float reluish_multiplier = hires_input_scale / reluish_scale;
    int32_t reluish_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_fixedpoint_int16);

    tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData,
                                     convertShapeToTflshape(outputShape), outputData);
    return true;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

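// Validation for RELU, RELU1, RELU6, LOGISTIC and TANH. The minimum supported
// HAL version depends on the input type: float16 and quantized TANH require
// Android Q, signed quantized tensors require Android R, and everything else
// has been supported since O-MR1.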
Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = Version::ANDROID_OC_MR1;
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        minSupportedVersion = Version::ANDROID_Q;
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::TANH) {
            minSupportedVersion = Version::ANDROID_Q;
        } else {
            minSupportedVersion = Version::ANDROID_OC_MR1;
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minSupportedVersion = Version::ANDROID_R;
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType;
    }
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return minSupportedVersion;
}

Result<Version> validateHardSwish(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minSupportedVersion = Version::ANDROID_R;
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return minSupportedVersion;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
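// Output shape preparation. For quantized tensors the output quantization is
// dictated by the operation: LOGISTIC and TANH overwrite scale/offset with the
// spec-mandated values (matching the checks in the quantized kernels above),
// HARD_SWISH keeps whatever the model requested, and the ReLU family reuses
// the input quantization unchanged.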
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    if (opType != OperationType::HARD_SWISH) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    Shape output = input;
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        switch (opType) {
            case OperationType::HARD_SWISH: {
                auto outputShape = context->getOutputShape(kOutputTensor);
                output.scale = outputShape.scale;
                output.offset = outputShape.offset;
            } break;
            case OperationType::RELU:
            case OperationType::RELU1:
            case OperationType::RELU6:
                break;
            case OperationType::LOGISTIC:
                output.scale = 1.f / 256;
                output.offset = isSigned ? -128 : 0;
                break;
            case OperationType::TANH:
                output.scale = 1.f / 128;
                output.offset = isSigned ? 0 : 128;
                break;
            default:
                NN_RET_CHECK_FAIL() << "Unsupported operation type";
        }
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool executeRelu(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<_Float16>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return reluFloat(context->getInputBuffer<float>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<float>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
    }
}

bool executeRelu1(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu1Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
    }
}

bool executeRelu6(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu6Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
    }
}

bool executeLogistic(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return logisticFloat(context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
    }
}

bool executeTanh(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<_Float16>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<float>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
    }
}

bool executeHardSwish(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16: {
            const Shape& inputShape = context->getInputShape(kInputTensor);
            const Shape& outputShape = context->getOutputShape(kOutputTensor);
            std::vector<float> inputFloat(getNumberOfElements(inputShape));
            std::vector<float> outputFloat(getNumberOfElements(outputShape));
            convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat);
            tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(),
                                             convertShapeToTflshape(outputShape),
                                             outputFloat.data());
            convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_FLOAT32: {
            tflite::reference_ops::HardSwish(
                    convertShapeToTflshape(context->getInputShape(kInputTensor)),
                    context->getInputBuffer<float>(kInputTensor),
                    convertShapeToTflshape(context->getOutputShape(kOutputTensor)),
                    context->getOutputBuffer<float>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_QUANT8_ASYMM:
            return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<int8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace activation

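// Registration. Each entry binds the shared validate/prepare helpers to a
// concrete OperationType; allowZeroSizedInput lets the execute functions
// short-circuit on empty tensors.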
using std::placeholders::_1;
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
                      std::bind(activation::prepare, OperationType::RELU, _1),
                      activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
                      std::bind(activation::prepare, OperationType::RELU1, _1),
                      activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
                      std::bind(activation::prepare, OperationType::RELU6, _1),
                      activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
                      std::bind(activation::validate, OperationType::LOGISTIC, _1),
                      std::bind(activation::prepare, OperationType::LOGISTIC, _1),
                      activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
                      std::bind(activation::prepare, OperationType::TANH, _1),
                      activation::executeTanh, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(HARD_SWISH, "HARD_SWISH", activation::validateHardSwish,
                      std::bind(activation::prepare, OperationType::HARD_SWISH, _1),
                      activation::executeHardSwish, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android