1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <cmath>
20
21 #include "OperationResolver.h"
22 #include "OperationsUtils.h"
23 #include "Tracing.h"
24
25 namespace android {
26 namespace nn {
27 namespace elementwise {
28
// Every operation in this file is unary: one input operand and one output operand,
// each located at index 0 of its operand list.
constexpr uint32_t kNumInputs{1};
constexpr uint32_t kNumOutputs{1};

constexpr uint32_t kInputTensor{0};
constexpr uint32_t kOutputTensor{0};
34
35 namespace {
36
37 template <typename IntermediateType, typename T>
compute(IntermediateType func (IntermediateType),const T * input,const Shape & shape,T * output)38 inline bool compute(IntermediateType func(IntermediateType), const T* input, const Shape& shape,
39 T* output) {
40 const auto size = getNumberOfElements(shape);
41 for (uint32_t i = 0; i < size; ++i) {
42 output[i] = static_cast<T>(func(static_cast<IntermediateType>(input[i])));
43 }
44 return true;
45 }
46
execute(IOperationExecutionContext * context,float func (float))47 bool execute(IOperationExecutionContext* context, float func(float)) {
48 switch (context->getInputType(kInputTensor)) {
49 case OperandType::TENSOR_FLOAT16:
50 return compute<float, _Float16>(func, context->getInputBuffer<_Float16>(kInputTensor),
51 context->getInputShape(kInputTensor),
52 context->getOutputBuffer<_Float16>(kOutputTensor));
53 case OperandType::TENSOR_FLOAT32:
54 return compute<float, float>(func, context->getInputBuffer<float>(kInputTensor),
55 context->getInputShape(kInputTensor),
56 context->getOutputBuffer<float>(kOutputTensor));
57 default:
58 NN_RET_CHECK_FAIL() << "Unsupported tensor type for elementwise operation";
59 }
60 }
61
62 } // namespace
63
executeAbs(IOperationExecutionContext * context)64 bool executeAbs(IOperationExecutionContext* context) {
65 switch (context->getInputType(kInputTensor)) {
66 case OperandType::TENSOR_FLOAT16:
67 return compute<float, _Float16>(std::abs,
68 context->getInputBuffer<_Float16>(kInputTensor),
69 context->getInputShape(kInputTensor),
70 context->getOutputBuffer<_Float16>(kOutputTensor));
71 case OperandType::TENSOR_FLOAT32:
72 return compute<float, float>(std::abs, context->getInputBuffer<float>(kInputTensor),
73 context->getInputShape(kInputTensor),
74 context->getOutputBuffer<float>(kOutputTensor));
75 case OperandType::TENSOR_INT32:
76 return compute<int32_t, int32_t>(std::abs,
77 context->getInputBuffer<int32_t>(kInputTensor),
78 context->getInputShape(kInputTensor),
79 context->getOutputBuffer<int32_t>(kOutputTensor));
80 default:
81 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ABS";
82 }
83 }
84
validate(const IOperationValidationContext * context)85 Result<Version> validate(const IOperationValidationContext* context) {
86 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
87 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
88 OperandType inputType = context->getInputType(kInputTensor);
89 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
90 inputType == OperandType::TENSOR_FLOAT32)
91 << "Unsupported tensor type for elementwise operation";
92 NN_RET_CHECK(validateInputTypes(context, {inputType}));
93 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
94 return Version::ANDROID_Q;
95 }
96
validateAbs(const IOperationValidationContext * context)97 Result<Version> validateAbs(const IOperationValidationContext* context) {
98 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
99 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
100 OperandType inputType = context->getInputType(kInputTensor);
101 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
102 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32)
103 << "Unsupported tensor type for operation ABS";
104 NN_RET_CHECK(validateInputTypes(context, {inputType}));
105 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
106 return inputType == OperandType::TENSOR_INT32 ? Version::ANDROID_R : Version::ANDROID_Q;
107 }
108
validateFloor(const IOperationValidationContext * context)109 Result<Version> validateFloor(const IOperationValidationContext* context) {
110 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
111 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
112
113 OperandType inputType = context->getInputType(kInputTensor);
114 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
115 inputType == OperandType::TENSOR_FLOAT32)
116 << "Unsupported tensor type for operation FLOOR";
117 NN_RET_CHECK(validateInputTypes(context, {inputType}));
118 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
119
120 const Shape& input = context->getInputShape(kInputTensor);
121 if (hasKnownRank(input)) {
122 NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
123 }
124
125 return inputType == OperandType::TENSOR_FLOAT16 ? Version::ANDROID_Q : Version::ANDROID_OC_MR1;
126 }
127
prepare(IOperationExecutionContext * context)128 bool prepare(IOperationExecutionContext* context) {
129 Shape input = context->getInputShape(kInputTensor);
130 Shape output = context->getOutputShape(kOutputTensor);
131 NN_RET_CHECK(SetShape(input, &output));
132 return context->setOutputShape(kOutputTensor, output);
133 }
134
prepareFloor(IOperationExecutionContext * context)135 bool prepareFloor(IOperationExecutionContext* context) {
136 Shape input = context->getInputShape(kInputTensor);
137 Shape output = context->getOutputShape(kOutputTensor);
138 NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
139 NN_RET_CHECK(SetShape(input, &output));
140 return context->setOutputShape(kOutputTensor, output);
141 }
142
executeExp(IOperationExecutionContext * context)143 bool executeExp(IOperationExecutionContext* context) {
144 return execute(context, std::exp);
145 }
146
executeFloor(IOperationExecutionContext * context)147 bool executeFloor(IOperationExecutionContext* context) {
148 return execute(context, std::floor);
149 }
150
executeLog(IOperationExecutionContext * context)151 bool executeLog(IOperationExecutionContext* context) {
152 return execute(context, std::log);
153 }
154
executeRsqrt(IOperationExecutionContext * context)155 bool executeRsqrt(IOperationExecutionContext* context) {
156 return execute(context, [](float x) { return 1.f / std::sqrt(x); });
157 }
158
executeSin(IOperationExecutionContext * context)159 bool executeSin(IOperationExecutionContext* context) {
160 return execute(context, std::sin);
161 }
162
executeSqrt(IOperationExecutionContext * context)163 bool executeSqrt(IOperationExecutionContext* context) {
164 return execute(context, std::sqrt);
165 }
166
167 } // namespace elementwise
168
// Registers each elementwise operation with its validate, prepare and execute functions.
// ABS and FLOOR have dedicated validators: ABS additionally accepts TENSOR_INT32, and
// FLOOR limits the input rank to 4. FLOOR also has a dedicated prepare step. EXP, LOG,
// RSQRT, SIN and SQRT share the float-only validate and the shape-copying prepare.
NN_REGISTER_OPERATION(ABS, "ABS", elementwise::validateAbs, elementwise::prepare,
                      elementwise::executeAbs);
NN_REGISTER_OPERATION(EXP, "EXP", elementwise::validate, elementwise::prepare,
                      elementwise::executeExp);
NN_REGISTER_OPERATION(FLOOR, "FLOOR", elementwise::validateFloor, elementwise::prepareFloor,
                      elementwise::executeFloor);
NN_REGISTER_OPERATION(LOG, "LOG", elementwise::validate, elementwise::prepare,
                      elementwise::executeLog);
NN_REGISTER_OPERATION(RSQRT, "RSQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeRsqrt);
NN_REGISTER_OPERATION(SIN, "SIN", elementwise::validate, elementwise::prepare,
                      elementwise::executeSin);
NN_REGISTER_OPERATION(SQRT, "SQRT", elementwise::validate, elementwise::prepare,
                      elementwise::executeSqrt);
183
184 } // namespace nn
185 } // namespace android
186