/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include "CpuOperationUtils.h"
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace roi_pooling {

constexpr char kOperationName[] = "ROI_POOLING";

constexpr uint32_t kNumInputs = 8;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kRoiTensor = 1;
constexpr uint32_t kBatchSplitTensor = 2;
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
constexpr uint32_t kHeightStrideScalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
constexpr uint32_t kLayoutScalar = 7;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

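// Max-pools each region of interest from an NHWC input feature map into a fixed
// outHeight x outWidth grid, writing one pooled value per channel for every output cell.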
template <typename T_Input, typename T_Roi>
inline bool roiPoolingNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                           const Shape& roiShape, const int32_t* batchSplitData,
                           const Shape& batchSplitShape, float heightStride, float widthStride,
                           T_Input* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("RoiPooling");

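    // Each ROI is described by 4 values: [x1, y1, x2, y2].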
    const uint32_t kRoiDim = 4;
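    // ROI coordinates are given in the coordinate space of the original image; dividing by
    // the stride maps them onto the feature map grid.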
    const T_Roi heightScale = 1.0f / heightStride;
    const T_Roi widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = batchSplitData[roiIndex];
        // Check for malformed data
        // 1. Invalid batch id
        // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(roiInfo[0] >= 0);
        NN_RET_CHECK(roiInfo[1] >= 0);
        NN_RET_CHECK(roiInfo[2] >= 0);
        NN_RET_CHECK(roiInfo[3] >= 0);
        NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
        NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);

        int32_t wRoiStart = std::round(static_cast<float>(roiInfo[0] * widthScale));
        int32_t hRoiStart = std::round(static_cast<float>(roiInfo[1] * heightScale));
        int32_t wRoiEnd = std::round(static_cast<float>(roiInfo[2] * widthScale));
        int32_t hRoiEnd = std::round(static_cast<float>(roiInfo[3] * heightScale));

        // ROIs with width/height < 1 are considered malformed and are forced to be 1
        T_Roi roiWidth = static_cast<T_Roi>(std::max(wRoiEnd - wRoiStart + 1, 1));
        T_Roi roiHeight = static_cast<T_Roi>(std::max(hRoiEnd - hRoiStart + 1, 1));
        T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
        T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);

        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                // Take floor on start, ceil on end, start included, end excluded,
                // i.e. [start, end). end is guaranteed to be larger than start by at least 1.
                uint32_t wStart = std::floor(static_cast<float>(wStepSize * j + wRoiStart));
                uint32_t wEnd = std::ceil(static_cast<float>(wStepSize * (j + 1) + wRoiStart));
                uint32_t hStart = std::floor(static_cast<float>(hStepSize * i + hRoiStart));
                uint32_t hEnd = std::ceil(static_cast<float>(hStepSize * (i + 1) + hRoiStart));

                wStart = std::min(wStart, inWidth);
                wEnd = std::min(wEnd, inWidth);
                hStart = std::min(hStart, inHeight);
                hEnd = std::min(hEnd, inHeight);

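                // Max-pool over the [hStart, hEnd) x [wStart, wEnd) window independently for
                // each channel.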
                for (uint32_t k = 0; k < inDepth; k++) {
                    T_Input maxValue = static_cast<T_Input>(inputShape.offset);
                    bool first = true;
                    for (uint32_t h = hStart; h < hEnd; h++) {
                        for (uint32_t w = wStart; w < wEnd; w++) {
                            T_Input inputValue =
                                    batchBase[h * inWidth * inDepth + w * inDepth + k];
                            if (first || inputValue > maxValue) {
                                maxValue = inputValue;
                                first = false;
                            }
                        }
                    }
                    outPtr[k] = maxValue;
                }
                outPtr += inDepth;
            }
        }
    }
    return true;
}

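// Layout-aware wrapper: converts an NCHW input to NHWC if requested, runs the NHWC
// implementation above, and commits the result back in the requested layout.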
template <typename T_Input, typename T_Roi>
inline bool roiPooling(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                       const Shape& roiShape, const int32_t* batchSplitData,
                       const Shape& batchSplitShape, float heightStride, float widthStride,
                       bool useNchw, T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(roiPoolingNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                                batchSplitData, batchSplitShape, heightStride, widthStride,
                                output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}

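// Quantized variant: the ROI tensor arrives as TENSOR_QUANT16_ASYMM, so its coordinates are
// dequantized to float before pooling.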
template <>
inline bool roiPooling<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
                                          const uint16_t* roiData, const Shape& roiShape,
                                          const int32_t* batchSplitData,
                                          const Shape& batchSplitShape, float heightStride,
                                          float widthStride, bool useNchw, uint8_t* outputData,
                                          const Shape& outputShape) {
    std::vector<float> roi_float32(getNumberOfElements(roiShape));
    convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
    NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
                            batchSplitShape, heightStride, widthStride, useNchw, outputData,
                            outputShape));
    return true;
}

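// Same dequantization path for the signed quantized input type.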
template <>
inline bool roiPooling<int8_t, uint16_t>(const int8_t* inputData, const Shape& inputShape,
                                         const uint16_t* roiData, const Shape& roiShape,
                                         const int32_t* batchSplitData,
                                         const Shape& batchSplitShape, float heightStride,
                                         float widthStride, bool useNchw, int8_t* outputData,
                                         const Shape& outputShape) {
    std::vector<float> roi_float32(getNumberOfElements(roiShape));
    convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
    NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
                            batchSplitShape, heightStride, widthStride, useNchw, outputData,
                            outputShape));
    return true;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

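// Checks operand counts and types and reports the minimum HAL version required for the
// given input type combination.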
Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    std::vector<OperandType> inExpectedTypes;
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
               inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        inExpectedTypes = {inputType,
                           OperandType::TENSOR_QUANT16_ASYMM,
                           OperandType::TENSOR_INT32,
                           OperandType::INT32,
                           OperandType::INT32,
                           OperandType::FLOAT32,
                           OperandType::FLOAT32,
                           OperandType::BOOL};
    } else {
        return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        return Version::ANDROID_R;
    } else {
        return Version::ANDROID_Q;
    }
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
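// Validates the input, ROI, and batch-split shapes together with the scalar parameters, then
// derives the output shape: [numRois, outHeight, outWidth, depth] for NHWC or
// [numRois, depth, outHeight, outWidth] for NCHW.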
bool prepare(IOperationExecutionContext* context) {
    bool useNchw = context->getInputValue<bool>(kLayoutScalar);
    Shape input = context->getInputShape(kInputTensor);
    Shape roiShape = context->getInputShape(kRoiTensor);
    Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);

    uint32_t numBatches = getSizeOfDimension(input, 0);
    uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
    NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);

    auto outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
    auto outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
    float heightStride, widthStride;
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        heightStride = context->getInputValue<_Float16>(kHeightStrideScalar);
        widthStride = context->getInputValue<_Float16>(kWidthStrideScalar);
    } else {
        heightStride = context->getInputValue<float>(kHeightStrideScalar);
        widthStride = context->getInputValue<float>(kWidthStrideScalar);
    }
    NN_RET_CHECK_GT(outputHeight, 0);
    NN_RET_CHECK_GT(outputWidth, 0);
    NN_RET_CHECK_GT(heightStride, 0);
    NN_RET_CHECK_GT(widthStride, 0);

    if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
        NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
        NN_RET_CHECK_EQ(roiShape.offset, 0);
    }

    Shape output = input;
    if (useNchw) {
        output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth)};
    } else {
        output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth), inDepth};
    }
    return context->setOutputShape(kOutputTensor, output);
}

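// Dispatches to the typed roiPooling implementation matching the input tensor type.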
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return roiPooling(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<_Float16>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<_Float16>(kHeightStrideScalar),
                              context->getInputValue<_Float16>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return roiPooling(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<float>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return roiPooling(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint16_t>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return roiPooling(context->getInputBuffer<int8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint16_t>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<int8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace roi_pooling

NN_REGISTER_OPERATION(ROI_POOLING, roi_pooling::kOperationName, roi_pooling::validate,
                      roi_pooling::prepare, roi_pooling::execute);

}  // namespace nn
}  // namespace android