/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <cmath>

#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

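// Element-wise unary operations: ABS, EXP, FLOOR, LOG, RSQRT, SIN, and SQRT.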
namespace android {
namespace nn {
namespace elementwise {

constexpr uint32_t kNumInputs = 1;
constexpr uint32_t kInputTensor = 0;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

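// Applies |func| to every element of |input|, casting each value through
// IntermediateType, and writes the results to |output|.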
template <typename IntermediateType, typename T>
inline bool compute(IntermediateType func(IntermediateType), const T* input, const Shape& shape,
                    T* output) {
    const auto size = getNumberOfElements(shape);
    for (uint32_t i = 0; i < size; ++i) {
        output[i] = static_cast<T>(func(static_cast<IntermediateType>(input[i])));
    }
    return true;
}

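// Dispatches on the input operand type and applies |func| element-wise to
// TENSOR_FLOAT16 or TENSOR_FLOAT32 data; other tensor types are rejected.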
bool execute(IOperationExecutionContext* context, float func(float)) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<float, _Float16>(func, context->getInputBuffer<_Float16>(kInputTensor),
                                            context->getInputShape(kInputTensor),
                                            context->getOutputBuffer<_Float16>(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return compute<float, float>(func, context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getOutputBuffer<float>(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
    }
}

}  // namespace

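// ABS has its own entry point because, unlike the other operations in this file,
// it also accepts TENSOR_INT32 inputs.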
bool executeAbs(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<float, _Float16>(std::abs,
                                            context->getInputBuffer<_Float16>(kInputTensor),
                                            context->getInputShape(kInputTensor),
                                            context->getOutputBuffer<_Float16>(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return compute<float, float>(std::abs, context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getOutputBuffer<float>(kOutputTensor));
        case OperandType::TENSOR_INT32:
            return compute<int32_t, int32_t>(std::abs,
                                             context->getInputBuffer<int32_t>(kInputTensor),
                                             context->getInputShape(kInputTensor),
                                             context->getOutputBuffer<int32_t>(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ABS";
    }
}

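// Shared validation for the float-only operations: exactly one input, one output,
// and a TENSOR_FLOAT16 or TENSOR_FLOAT32 operand type.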
Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for elementwise operation";
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return Version::ANDROID_Q;
}

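// Like validate(), but ABS additionally accepts TENSOR_INT32; the integer variant
// requires ANDROID_R.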
Result<Version> validateAbs(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32)
            << "Unsupported tensor type for operation ABS";
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return inputType == OperandType::TENSOR_INT32 ? Version::ANDROID_R : Version::ANDROID_Q;
}

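// FLOOR is validated separately: it is float-only, limits tensors of known rank to
// at most 4 dimensions, and its TENSOR_FLOAT32 variant is available from ANDROID_OC_MR1.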
Result<Version> validateFloor(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);

    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for operation FLOOR";
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));

    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }

    return inputType == OperandType::TENSOR_FLOAT16 ? Version::ANDROID_Q : Version::ANDROID_OC_MR1;
}

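// Sets the output shape to match the input shape.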
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(SetShape(input, &output));
    return context->setOutputShape(kOutputTensor, output);
}

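// Same as prepare(), but additionally rejects inputs with more than 4 dimensions.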
bool prepareFloor(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    NN_RET_CHECK(SetShape(input, &output));
    return context->setOutputShape(kOutputTensor, output);
}

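// Thin wrappers that bind each operation to the corresponding scalar float function.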
bool executeExp(IOperationExecutionContext* context) {
    return execute(context, std::exp);
}

bool executeFloor(IOperationExecutionContext* context) {
    return execute(context, std::floor);
}

bool executeLog(IOperationExecutionContext* context) {
    return execute(context, std::log);
}

bool executeRsqrt(IOperationExecutionContext* context) {
    return execute(context, [](float x) { return 1.f / std::sqrt(x); });
}

bool executeSin(IOperationExecutionContext* context) {
    return execute(context, std::sin);
}

bool executeSqrt(IOperationExecutionContext* context) {
    return execute(context, std::sqrt);
}

}  // namespace elementwise

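// Registers each operation together with its validation, preparation, and execution functions.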
NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validateAbs, elementwise::prepare,
                      elementwise::executeAbs);
NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
                      elementwise::executeExp);
NN_REGISTER_OPERATION(FLOOR, "FLOOR", elementwise::validateFloor, elementwise::prepareFloor,
                      elementwise::executeFloor);
NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
                      elementwise::executeLog);
NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeRsqrt);
NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
                      elementwise::executeSin);
NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeSqrt);

}  // namespace nn
}  // namespace android