1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <functional>
20 #include <vector>
21
22 #include "IndexedShapeWrapper.h"
23 #include "OperationResolver.h"
24 #include "OperationsUtils.h"
25
26 namespace android {
27 namespace nn {
28 namespace comparisons {
29
// Input operand indices: the two tensors being compared element-wise.
constexpr uint32_t kNumInputs = 2;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;

// Output operand index: a single TENSOR_BOOL8 tensor of comparison results.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
36
37 namespace {
38
39 template <typename DataType, typename ComparisonType>
compute(const std::function<bool (ComparisonType,ComparisonType)> & func,const DataType * aData,const Shape & aShape,const DataType * bData,const Shape & bShape,bool8 * outputData,const Shape & outputShape)40 bool compute(const std::function<bool(ComparisonType, ComparisonType)>& func, const DataType* aData,
41 const Shape& aShape, const DataType* bData, const Shape& bShape, bool8* outputData,
42 const Shape& outputShape) {
43 IndexedShapeWrapper aShapeIndexed(aShape);
44 IndexedShapeWrapper bShapeIndexed(bShape);
45 IndexedShapeWrapper outputShapeIndexed(outputShape);
46 std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
47 bool lastIndex = false;
48 do {
49 uint32_t outputFlatIndex;
50 NN_RET_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
51 uint32_t aFlatIndex;
52 NN_RET_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
53 uint32_t bFlatIndex;
54 NN_RET_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
55
56 if (aShape.type == OperandType::TENSOR_QUANT8_ASYMM ||
57 aShape.type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
58 const float realA = (aData[aFlatIndex] - aShape.offset) * aShape.scale;
59 const float realB = (bData[bFlatIndex] - bShape.offset) * bShape.scale;
60 outputData[outputFlatIndex] = func(realA, realB);
61 } else {
62 outputData[outputFlatIndex] = func(aData[aFlatIndex], bData[bFlatIndex]);
63 }
64
65 NN_RET_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
66 } while (!lastIndex);
67 return true;
68 }
69
70 template <typename DataType, typename ComparisonType>
executeLessTyped(IOperationExecutionContext * context)71 bool executeLessTyped(IOperationExecutionContext* context) {
72 return compute<DataType, ComparisonType>(
73 std::less<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
74 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
75 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
76 context->getOutputShape(kOutputTensor));
77 }
78
79 template <typename DataType, typename ComparisonType>
executeLessEqualTyped(IOperationExecutionContext * context)80 bool executeLessEqualTyped(IOperationExecutionContext* context) {
81 return compute<DataType, ComparisonType>(
82 std::less_equal<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
83 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
84 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
85 context->getOutputShape(kOutputTensor));
86 }
87
88 template <typename DataType, typename ComparisonType>
executeEqualTyped(IOperationExecutionContext * context)89 bool executeEqualTyped(IOperationExecutionContext* context) {
90 return compute<DataType, ComparisonType>(
91 std::equal_to<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
92 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
93 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
94 context->getOutputShape(kOutputTensor));
95 }
96
97 template <typename DataType, typename ComparisonType>
executeNotEqualTyped(IOperationExecutionContext * context)98 bool executeNotEqualTyped(IOperationExecutionContext* context) {
99 return compute<DataType, ComparisonType>(
100 std::not_equal_to<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
101 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
102 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
103 context->getOutputShape(kOutputTensor));
104 }
105
106 template <typename DataType, typename ComparisonType>
executeGreaterEqualTyped(IOperationExecutionContext * context)107 bool executeGreaterEqualTyped(IOperationExecutionContext* context) {
108 return compute<DataType, ComparisonType>(
109 std::greater_equal<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
110 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
111 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
112 context->getOutputShape(kOutputTensor));
113 }
114
115 template <typename DataType, typename ComparisonType>
executeGreaterTyped(IOperationExecutionContext * context)116 bool executeGreaterTyped(IOperationExecutionContext* context) {
117 return compute<DataType, ComparisonType>(
118 std::greater<ComparisonType>(), context->getInputBuffer<DataType>(kInputTensor1),
119 context->getInputShape(kInputTensor1), context->getInputBuffer<DataType>(kInputTensor2),
120 context->getInputShape(kInputTensor2), context->getOutputBuffer<bool8>(kOutputTensor),
121 context->getOutputShape(kOutputTensor));
122 }
123
124 } // namespace
125
validate(const IOperationValidationContext * context)126 Result<Version> validate(const IOperationValidationContext* context) {
127 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
128 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
129 OperandType inputType = context->getInputType(kInputTensor1);
130 NN_RET_CHECK(
131 inputType == OperandType::TENSOR_BOOL8 || inputType == OperandType::TENSOR_FLOAT16 ||
132 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32 ||
133 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
134 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
135 << "Unsupported input operand type for comparison op: " << inputType;
136 NN_RET_CHECK(validateInputTypes(context, {inputType, inputType}));
137 NN_RET_CHECK(validateOutputTypes(context, {OperandType::TENSOR_BOOL8}));
138 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
139 return Version::ANDROID_R;
140 } else {
141 return Version::ANDROID_Q;
142 }
143 }
144
prepare(IOperationExecutionContext * context)145 bool prepare(IOperationExecutionContext* context) {
146 Shape input1 = context->getInputShape(kInputTensor1);
147 Shape input2 = context->getInputShape(kInputTensor2);
148 Shape output = context->getOutputShape(kOutputTensor);
149 NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
150 return context->setOutputShape(kOutputTensor, output);
151 }
152
executeLess(IOperationExecutionContext * context)153 bool executeLess(IOperationExecutionContext* context) {
154 switch (context->getInputType(kInputTensor1)) {
155 case OperandType::TENSOR_FLOAT16:
156 return executeLessTyped<_Float16, _Float16>(context);
157 case OperandType::TENSOR_FLOAT32:
158 return executeLessTyped<float, float>(context);
159 case OperandType::TENSOR_INT32:
160 return executeLessTyped<int32_t, int32_t>(context);
161 case OperandType::TENSOR_QUANT8_ASYMM:
162 return executeLessTyped<uint8_t, float>(context);
163 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
164 return executeLessTyped<int8_t, float>(context);
165 case OperandType::TENSOR_BOOL8:
166 return executeLessTyped<bool8, bool8>(context);
167 default:
168 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
169 }
170 }
171
executeLessEqual(IOperationExecutionContext * context)172 bool executeLessEqual(IOperationExecutionContext* context) {
173 switch (context->getInputType(kInputTensor1)) {
174 case OperandType::TENSOR_FLOAT16:
175 return executeLessEqualTyped<_Float16, _Float16>(context);
176 case OperandType::TENSOR_FLOAT32:
177 return executeLessEqualTyped<float, float>(context);
178 case OperandType::TENSOR_INT32:
179 return executeLessEqualTyped<int32_t, int32_t>(context);
180 case OperandType::TENSOR_QUANT8_ASYMM:
181 return executeLessEqualTyped<uint8_t, float>(context);
182 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
183 return executeLessEqualTyped<int8_t, float>(context);
184 case OperandType::TENSOR_BOOL8:
185 return executeLessEqualTyped<bool8, bool8>(context);
186 default:
187 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
188 }
189 }
190
executeEqual(IOperationExecutionContext * context)191 bool executeEqual(IOperationExecutionContext* context) {
192 switch (context->getInputType(kInputTensor1)) {
193 case OperandType::TENSOR_FLOAT16:
194 return executeEqualTyped<_Float16, _Float16>(context);
195 case OperandType::TENSOR_FLOAT32:
196 return executeEqualTyped<float, float>(context);
197 case OperandType::TENSOR_INT32:
198 return executeEqualTyped<int32_t, int32_t>(context);
199 case OperandType::TENSOR_QUANT8_ASYMM:
200 return executeEqualTyped<uint8_t, float>(context);
201 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
202 return executeEqualTyped<int8_t, float>(context);
203 case OperandType::TENSOR_BOOL8:
204 return executeEqualTyped<bool8, bool8>(context);
205 default:
206 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
207 }
208 }
209
executeNotEqual(IOperationExecutionContext * context)210 bool executeNotEqual(IOperationExecutionContext* context) {
211 switch (context->getInputType(kInputTensor1)) {
212 case OperandType::TENSOR_FLOAT16:
213 return executeNotEqualTyped<_Float16, _Float16>(context);
214 case OperandType::TENSOR_FLOAT32:
215 return executeNotEqualTyped<float, float>(context);
216 case OperandType::TENSOR_INT32:
217 return executeNotEqualTyped<int32_t, int32_t>(context);
218 case OperandType::TENSOR_QUANT8_ASYMM:
219 return executeNotEqualTyped<uint8_t, float>(context);
220 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
221 return executeNotEqualTyped<int8_t, float>(context);
222 case OperandType::TENSOR_BOOL8:
223 return executeNotEqualTyped<bool8, bool8>(context);
224 default:
225 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
226 }
227 }
228
executeGreaterEqual(IOperationExecutionContext * context)229 bool executeGreaterEqual(IOperationExecutionContext* context) {
230 switch (context->getInputType(kInputTensor1)) {
231 case OperandType::TENSOR_FLOAT16:
232 return executeGreaterEqualTyped<_Float16, _Float16>(context);
233 case OperandType::TENSOR_FLOAT32:
234 return executeGreaterEqualTyped<float, float>(context);
235 case OperandType::TENSOR_INT32:
236 return executeGreaterEqualTyped<int32_t, int32_t>(context);
237 case OperandType::TENSOR_QUANT8_ASYMM:
238 return executeGreaterEqualTyped<uint8_t, float>(context);
239 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
240 return executeGreaterEqualTyped<int8_t, float>(context);
241 case OperandType::TENSOR_BOOL8:
242 return executeGreaterEqualTyped<bool8, bool8>(context);
243 default:
244 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
245 }
246 }
247
executeGreater(IOperationExecutionContext * context)248 bool executeGreater(IOperationExecutionContext* context) {
249 switch (context->getInputType(kInputTensor1)) {
250 case OperandType::TENSOR_FLOAT16:
251 return executeGreaterTyped<_Float16, _Float16>(context);
252 case OperandType::TENSOR_FLOAT32:
253 return executeGreaterTyped<float, float>(context);
254 case OperandType::TENSOR_INT32:
255 return executeGreaterTyped<int32_t, int32_t>(context);
256 case OperandType::TENSOR_QUANT8_ASYMM:
257 return executeGreaterTyped<uint8_t, float>(context);
258 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
259 return executeGreaterTyped<int8_t, float>(context);
260 case OperandType::TENSOR_BOOL8:
261 return executeGreaterTyped<bool8, bool8>(context);
262 default:
263 NN_RET_CHECK_FAIL() << "Unsupported tensor type for comparison";
264 }
265 }
266
267 } // namespace comparisons
268
// Register all six comparison operations. They share validate() and prepare();
// only the comparison functor used at execution time differs.
NN_REGISTER_OPERATION(LESS, "LESS", comparisons::validate, comparisons::prepare,
                      comparisons::executeLess);
NN_REGISTER_OPERATION(LESS_EQUAL, "LESS_EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeLessEqual);
NN_REGISTER_OPERATION(EQUAL, "EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeEqual);
NN_REGISTER_OPERATION(NOT_EQUAL, "NOT_EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeNotEqual);
NN_REGISTER_OPERATION(GREATER_EQUAL, "GREATER_EQUAL", comparisons::validate, comparisons::prepare,
                      comparisons::executeGreaterEqual);
NN_REGISTER_OPERATION(GREATER, "GREATER", comparisons::validate, comparisons::prepare,
                      comparisons::executeGreater);
281
282 } // namespace nn
283 } // namespace android
284