/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

#include "ActivationFunctor.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
#include <tensorflow/lite/kernels/internal/optimized/optimized_ops.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h>
#include <tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h>
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>

#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {

namespace activation {

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

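// Reference float implementations. reluFloat clamps each element to
// [reluMin, reluMax]; relu1Float and relu6Float specialize it to [-1, 1] and
// [0, 6]. The _Float16 variants widen each element to fp32, apply the
// activation, then narrow the result back to fp16.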
template <typename T>
bool reluFloat(const T* inputData, const Shape& inputShape, T* outputData, const Shape& outputShape,
               float reluMin = 0.f, float reluMax = std::numeric_limits<float>::max()) {
    NNTRACE_COMP("reluX");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(
                std::min(std::max(reluMin, static_cast<float>(*inputData)), reluMax));
    }
    return true;
}
template bool reluFloat<float>(const float* inputData, const Shape& inputShape, float* outputData,
                               const Shape& outputShape, float reluMin, float reluMax);
template bool reluFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                  _Float16* outputData, const Shape& outputShape, float reluMin,
                                  float reluMax);

template <typename T>
bool relu1Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, -1.f, 1.f);
}
template bool relu1Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu1Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

template <typename T>
bool relu6Float(const T* inputData, const Shape& inputShape, T* outputData,
                const Shape& outputShape) {
    return reluFloat(inputData, inputShape, outputData, outputShape, 0.f, 6.f);
}
template bool relu6Float<float>(const float* inputData, const Shape& inputShape, float* outputData,
                                const Shape& outputShape);
template bool relu6Float<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   _Float16* outputData, const Shape& outputShape);

bool tanhFloat16(const _Float16* inputData, const Shape& inputShape, _Float16* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("tanhFloat16");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<_Float16>(std::tanh(static_cast<float>(*inputData)));
    }
    return true;
}

bool tanhFloat32(const float* inputData, const Shape& inputShape, float* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("tanhFloat32");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::tanh(*inputData);
    }
    return true;
}

template <typename T>
bool logisticFloat(const T* inputData, const Shape& inputShape, T* outputData,
                   const Shape& outputShape) {
    NNTRACE_COMP("logisticFloat");
    int numElements = getNumberOfElements(inputShape);
    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = static_cast<T>(1.f / (1.f + std::exp(static_cast<float>(-*inputData))));
    }
    return true;
}
template bool logisticFloat<float>(const float* inputData, const Shape& inputShape,
                                   float* outputData, const Shape& outputShape);
template bool logisticFloat<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                      _Float16* outputData, const Shape& outputShape);

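// Quantized ReLU family. Because ReLU is monotonic and the output reuses the
// input's quantization parameters, the clamp can be applied directly to the
// stored uint8 codes: CalculateActivationRangeUint8 converts the real-valued
// bounds (e.g. [0, 6] for RELU6) into quantized-domain limits from the
// tensor's scale and zero point.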
template <ActivationFn activation>
inline bool reluXQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                        const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeUint8(activation, inputShape, &output_activation_min,
                                  &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((uint8_t)output_activation_max,
                               std::max((uint8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8");
    return reluXQuant8<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8");
    return reluXQuant8<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                 const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8");
    return reluXQuant8<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

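// Quantized tanh. The output quantization is fixed (scale = 1/128,
// zero point = 128 for uint8) so that the result range [-1, 1] exactly spans
// the 8-bit codes; prepare() below sets the same values. The input is
// rescaled into a Q4.27 fixed-point representation (kInputIntegerBits = 4,
// hence the 1 << (31 - 4) factor) before invoking the TFLite kernel.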
bool tanhQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8");
    if (outputShape.offset != 128 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Tanh");
    tflite::optimized_ops::Tanh(inputData, convertShapeToTflshape(inputShape), inputShape.offset,
                                input_range_radius, input_multiplier, input_left_shift, outputData,
                                convertShapeToTflshape(outputShape));

    return true;
}

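// Quantized logistic (sigmoid). The output range (0, 1) maps onto uint8 with
// the mandated scale = 1/256 and zero point = 0; the input is rescaled to the
// same Q4.27 fixed-point form as in tanhQuant8 above.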
bool logisticQuant8(const uint8_t* inputData, const Shape& inputShape, uint8_t* outputData,
                    const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("optimized_ops::Logistic");
    tflite::optimized_ops::Logistic(
            inputData, convertShapeToTflshape(inputShape), inputShape.offset, input_range_radius,
            input_multiplier, input_left_shift, outputData, convertShapeToTflshape(outputShape));

    return true;
}

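// Signed variants for TENSOR_QUANT8_ASYMM_SIGNED (int8 storage, added in
// Android R). They mirror the uint8 paths above; only the clamp limits, the
// mandated output zero points, and the TFLite kernels invoked differ.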
template <ActivationFn activation>
inline bool reluXQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                              const Shape& outputShape) {
    int numElements = getNumberOfElements(inputShape);
    int32_t output_activation_min = 0;
    int32_t output_activation_max = 0;

    CalculateActivationRangeInt8(activation, inputShape, &output_activation_min,
                                 &output_activation_max);

    for (int i = 0; i < numElements; i++, inputData++, outputData++) {
        *outputData = std::min((int8_t)output_activation_max,
                               std::max((int8_t)output_activation_min, *inputData));
    }
    return true;
}

bool reluQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_COMP("reluQuant8Signed");
    return reluXQuant8Signed<kActivationRelu>(inputData, inputShape, outputData, outputShape);
}

bool relu1Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu1Quant8Signed");
    return reluXQuant8Signed<kActivationRelu1>(inputData, inputShape, outputData, outputShape);
}

bool relu6Quant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                       const Shape& outputShape) {
    NNTRACE_COMP("relu6Quant8Signed");
    return reluXQuant8Signed<kActivationRelu6>(inputData, inputShape, outputData, outputShape);
}

bool tanhQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                      const Shape& outputShape) {
    NNTRACE_TRANS("tanhQuant8Signed");
    if (outputShape.offset != 0 || outputShape.scale != 1.f / 128) {
        LOG(ERROR) << "incorrect scale or offset for TANH output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Tanh");
    tflite::reference_integer_ops::Tanh(inputShape.offset, input_range_radius, input_multiplier,
                                        input_left_shift, convertShapeToTflshape(inputShape),
                                        inputData, convertShapeToTflshape(outputShape), outputData);

    return true;
}

bool logisticQuant8Signed(const int8_t* inputData, const Shape& inputShape, int8_t* outputData,
                          const Shape& outputShape) {
    NNTRACE_TRANS("logisticQuant8Signed");
    if (outputShape.offset != -128 || outputShape.scale != 1.f / 256) {
        LOG(ERROR) << "incorrect scale / offset for output";
        return false;
    }

    int numElements = getNumberOfElements(inputShape);
    static constexpr int kInputIntegerBits = 4;

    const double input_real_multiplier =
            inputShape.scale * static_cast<double>(1 << (31 - kInputIntegerBits));

    int32_t input_multiplier = 0;
    int32_t input_left_shift = 0;
    if (!QuantizeMultiplierGreaterThanOne(input_real_multiplier, &input_multiplier,
                                          &input_left_shift)) {
        return false;
    }
    int32_t input_range_radius = CalculateInputRadius(kInputIntegerBits, input_left_shift);

    NNTRACE_COMP_SWITCH("reference_integer_ops::Logistic");
    tflite::reference_integer_ops::Logistic(inputShape.offset, input_range_radius, input_multiplier,
                                            input_left_shift, numElements, inputData, outputData);

    return true;
}

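// Converts a non-negative Q31 fixed-point multiplier to Q15 by rounding to
// nearest, saturating values that would overflow int16. For example,
// 0x2C000000 (0.34375 in Q31) becomes (0x2C000000 + 0x8000) >> 16 = 0x2C00,
// the same value in Q15.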
void DownScaleInt32ToInt16Multiplier(int32_t multiplier_int32, int16_t* multiplier_int16) {
    TFLITE_DCHECK_GE(multiplier_int32, 0);
    static constexpr int32_t kRoundingOffset = 1 << 15;
    if (multiplier_int32 >= std::numeric_limits<int32_t>::max() - kRoundingOffset) {
        *multiplier_int16 = std::numeric_limits<int16_t>::max();
        return;
    }
    const int32_t result = (multiplier_int32 + kRoundingOffset) >> 16;
    TFLITE_DCHECK_LE(result << 16, multiplier_int32 + kRoundingOffset);
    TFLITE_DCHECK_GT(result << 16, multiplier_int32 - kRoundingOffset);
    *multiplier_int16 = result;
    TFLITE_DCHECK_EQ(*multiplier_int16, result);
}

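// Quantized hard-swish, x * relu6(x + 3) / 6. The TFLite kernel operates on
// two internal fixed-point scales: a high-resolution copy of the input
// (hires_input_scale = input_scale / 128) and a "reluish" scale of 3/32768
// for the relu6(x + 3) / 6 factor. The two real multipliers computed below
// fold those scales into the quantized output transform.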
template <typename T>
bool hardSwishQuant(const T* inputData, const Shape& inputShape, T* outputData,
                    const Shape& outputShape) {
    tflite::HardSwishParams params;
    params.input_zero_point = inputShape.offset;
    params.output_zero_point = outputShape.offset;
    const float input_scale = inputShape.scale;
    const float hires_input_scale = (1.0f / 128.0f) * input_scale;
    const float reluish_scale = 3.0f / 32768.0f;
    const float output_scale = outputShape.scale;

    const float output_multiplier = hires_input_scale / output_scale;

    int32_t output_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32,
                                    &params.output_multiplier_fixedpoint_int16);
    NN_RET_CHECK(params.output_multiplier_exponent <= 0);

    const float reluish_multiplier = hires_input_scale / reluish_scale;
    int32_t reluish_multiplier_fixedpoint_int32;
    NN_RET_CHECK(QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_exponent));
    DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32,
                                    &params.reluish_multiplier_fixedpoint_int16);

    tflite::reference_ops::HardSwish(params, convertShapeToTflshape(inputShape), inputData,
                                     convertShapeToTflshape(outputShape), outputData);
    return true;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

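// Validates RELU, RELU1, RELU6, LOGISTIC, and TANH, returning the minimum
// NNAPI version that supports the input type: float32 and unsigned quant8
// since Android O-MR1 (quantized TANH since Q), float16 since Q, and signed
// quant8 since R.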
Result<Version> validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        minSupportedVersion = Version::ANDROID_OC_MR1;
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        minSupportedVersion = Version::ANDROID_Q;
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::TANH) {
            minSupportedVersion = Version::ANDROID_Q;
        } else {
            minSupportedVersion = Version::ANDROID_OC_MR1;
        }
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minSupportedVersion = Version::ANDROID_R;
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << opType;
    }
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return minSupportedVersion;
}

Result<Version> validateHardSwish(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    auto minSupportedVersion = Version::ANDROID_OC_MR1;
    if (inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minSupportedVersion = Version::ANDROID_R;
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return minSupportedVersion;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
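// Computes the output shape: same dimensions and type as the input. For
// quantized tensors the output scale and zero point are set per operation:
// LOGISTIC uses scale = 1/256 with zero point 0 (uint8) or -128 (int8), TANH
// uses scale = 1/128 with zero point 128 (uint8) or 0 (int8), the ReLU
// variants reuse the input's parameters, and HARD_SWISH keeps whatever
// quantization the model declared on its output operand.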
bool prepare(OperationType opType, IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    if (opType != OperationType::HARD_SWISH) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    Shape output = input;
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM ||
        input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        bool isSigned = input.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED;
        switch (opType) {
            case OperationType::HARD_SWISH: {
                auto outputShape = context->getOutputShape(kOutputTensor);
                output.scale = outputShape.scale;
                output.offset = outputShape.offset;
            } break;
            case OperationType::RELU:
            case OperationType::RELU1:
            case OperationType::RELU6:
                break;
            case OperationType::LOGISTIC:
                output.scale = 1.f / 256;
                output.offset = isSigned ? -128 : 0;
                break;
            case OperationType::TANH:
                output.scale = 1.f / 128;
                output.offset = isSigned ? 0 : 128;
                break;
            default:
                NN_RET_CHECK_FAIL() << "Unsupported operation type";
        }
    }
    return context->setOutputShape(kOutputTensor, output);
}

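// The execute* functions below share one pattern: skip zero-sized inputs
// (permitted for these operations), then dispatch on the operand type to the
// matching typed implementation.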
bool executeRelu(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return reluFloat(context->getInputBuffer<_Float16>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<_Float16>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return reluFloat(context->getInputBuffer<float>(kInputTensor),
                             context->getInputShape(kInputTensor),
                             context->getOutputBuffer<float>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return reluQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return reluQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU";
    }
}

bool executeRelu1(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu1Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu1Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu1Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu1Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU1";
    }
}

bool executeRelu6(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return relu6Float(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return relu6Float(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return relu6Quant8(context->getInputBuffer<uint8_t>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<uint8_t>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return relu6Quant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                     context->getInputShape(kInputTensor),
                                     context->getOutputBuffer<int8_t>(kOutputTensor),
                                     context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation RELU6";
    }
}

bool executeLogistic(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return logisticFloat(context->getInputBuffer<_Float16>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<_Float16>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return logisticFloat(context->getInputBuffer<float>(kInputTensor),
                                 context->getInputShape(kInputTensor),
                                 context->getOutputBuffer<float>(kOutputTensor),
                                 context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return logisticQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return logisticQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getOutputBuffer<int8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
    }
}

bool executeTanh(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return tanhFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<_Float16>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return tanhFloat32(context->getInputBuffer<float>(kInputTensor),
                               context->getInputShape(kInputTensor),
                               context->getOutputBuffer<float>(kOutputTensor),
                               context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return tanhQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return tanhQuant8Signed(context->getInputBuffer<int8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getOutputBuffer<int8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
    }
}

bool executeHardSwish(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16: {
            const Shape& inputShape = context->getInputShape(kInputTensor);
            const Shape& outputShape = context->getOutputShape(kOutputTensor);
            std::vector<float> inputFloat(getNumberOfElements(inputShape));
            std::vector<float> outputFloat(getNumberOfElements(outputShape));
            convertFloat16ToFloat32(context->getInputBuffer<_Float16>(kInputTensor), &inputFloat);
            tflite::reference_ops::HardSwish(convertShapeToTflshape(inputShape), inputFloat.data(),
                                             convertShapeToTflshape(outputShape),
                                             outputFloat.data());
            convertFloat32ToFloat16(outputFloat, context->getOutputBuffer<_Float16>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_FLOAT32: {
            tflite::reference_ops::HardSwish(
                    convertShapeToTflshape(context->getInputShape(kInputTensor)),
                    context->getInputBuffer<float>(kInputTensor),
                    convertShapeToTflshape(context->getOutputShape(kOutputTensor)),
                    context->getOutputBuffer<float>(kOutputTensor));
            return true;
        }
        case OperandType::TENSOR_QUANT8_ASYMM:
            return hardSwishQuant(context->getInputBuffer<uint8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<uint8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return hardSwishQuant(context->getInputBuffer<int8_t>(kInputTensor),
                                  context->getInputShape(kInputTensor),
                                  context->getOutputBuffer<int8_t>(kOutputTensor),
                                  context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation HARD_SWISH";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace activation

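// Register each operation, binding its OperationType into the shared
// validate/prepare entry points. allowZeroSizedInput lets zero-element
// tensors reach the execute* early-outs above instead of being rejected
// during validation.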
using std::placeholders::_1;
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
                      std::bind(activation::prepare, OperationType::RELU, _1),
                      activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
                      std::bind(activation::prepare, OperationType::RELU1, _1),
                      activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
                      std::bind(activation::prepare, OperationType::RELU6, _1),
                      activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
                      std::bind(activation::validate, OperationType::LOGISTIC, _1),
                      std::bind(activation::prepare, OperationType::LOGISTIC, _1),
                      activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
                      std::bind(activation::prepare, OperationType::TANH, _1),
                      activation::executeTanh, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(HARD_SWISH, "HARD_SWISH", activation::validateHardSwish,
                      std::bind(activation::prepare, OperationType::HARD_SWISH, _1),
                      activation::executeHardSwish, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android