1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <algorithm>
20 #include <cfloat>
21 #include <cmath>
22 #include <vector>
23
24 #include "OperationResolver.h"
25 #include "OperationsUtils.h"
26 #include "Tracing.h"
27
28 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
29 #include <tensorflow/lite/kernels/internal/common.h>
30
31 #include "CpuOperationUtils.h"
32 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
33
34 namespace android {
35 namespace nn {
36 namespace roi_align {
37
constexpr char kOperationName[] = "ROI_ALIGN";

// Input operand indices (see the NNAPI ROI_ALIGN operation specification).
constexpr uint32_t kNumInputs = 10;
constexpr uint32_t kInputTensor = 0;           // 4-D feature map (NHWC or NCHW)
constexpr uint32_t kRoiTensor = 1;             // [numRois, 4] boxes: (x1, y1, x2, y2)
constexpr uint32_t kBatchSplitTensor = 2;      // per-ROI batch index into the feature map
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
// NOTE(review): "Salar" is a long-standing typo for "Scalar"; kept as-is because a
// rename would have to touch every use site in this file at once.
constexpr uint32_t kHeightStrideSalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
constexpr uint32_t kHeightSamplingRatioScalar = 7;  // 0 selects an adaptive count
constexpr uint32_t kWidthSamplingRatioScalar = 8;   // 0 selects an adaptive count
constexpr uint32_t kLayoutScalar = 9;               // true = NCHW, false = NHWC

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
54
55 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
56 namespace {
57
// ROI Align, floating-point path, NHWC layout (as used by Mask R-CNN): each region
// of interest is divided into an outHeight x outWidth grid of cells, and every cell
// is filled with the average of bilinearly-interpolated sampling points taken from
// the input feature map.
//
// inputData/inputShape:   feature map, [numBatches, inHeight, inWidth, inDepth].
// roiData/roiShape:       [numRois, 4] boxes laid out as (x1, y1, x2, y2); the
//                         coordinates are mapped onto the feature map by multiplying
//                         with 1/widthStride and 1/heightStride below.
// batchSplitData:         per-ROI index selecting which batch each box samples from.
// heightSamplingRatio/widthSamplingRatio:
//                         sampling points per output cell in each dimension;
//                         values <= 0 select an adaptive count (ceil of cell size).
// outputData/outputShape: [numRois, outHeight, outWidth, inDepth].
//
// Returns false (via NN_RET_CHECK) if any ROI entry is malformed.
template <typename T_Input, typename T_Roi>
inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                         const Shape& roiShape, const int32_t* batchSplitData,
                         const Shape& batchSplitShape, float heightStride, float widthStride,
                         int32_t heightSamplingRatio, int32_t widthSamplingRatio,
                         T_Input* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("RoiAlign");

    const uint32_t kRoiDim = 4;
    // Scales mapping ROI coordinates (in image units) onto feature-map coordinates.
    const T_Roi heightScale = 1.0f / heightStride;
    const T_Roi widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
        // Check for malformed data
        // 1. invalid batch id
        // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        // (batchId is unsigned here, so the >= 0 check below is vacuous after the
        // cast; it documents the spec's requirement on the signed source value.)
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(roiInfo[0] >= 0);
        NN_RET_CHECK(roiInfo[1] >= 0);
        NN_RET_CHECK(roiInfo[2] >= 0);
        NN_RET_CHECK(roiInfo[3] >= 0);
        NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
        NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);

        // Map the box corners onto feature-map coordinates.
        T_Roi wRoiStart = roiInfo[0] * widthScale;
        T_Roi hRoiStart = roiInfo[1] * heightScale;
        T_Roi wRoiEnd = roiInfo[2] * widthScale;
        T_Roi hRoiEnd = roiInfo[3] * heightScale;

        // Clamp the box to a minimum 1x1 size so every output cell samples something.
        T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f);
        T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f);
        T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
        T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);

        // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
        uint32_t wSamplingRatio = widthSamplingRatio > 0 ? widthSamplingRatio
                                                         : std::ceil(static_cast<float>(wStepSize));
        uint32_t hSamplingRatio = heightSamplingRatio > 0
                                          ? heightSamplingRatio
                                          : std::ceil(static_cast<float>(hStepSize));
        int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
        T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio);
        T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio);

        // Start of the feature-map slice belonging to this ROI's batch.
        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                T_Roi wStart = wStepSize * j + wRoiStart;
                T_Roi wEnd = wStepSize * (j + 1) + wRoiStart;
                T_Roi hStart = hStepSize * i + hRoiStart;
                T_Roi hEnd = hStepSize * (i + 1) + hRoiStart;

                // initialize output to zero
                for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0;

                // calculate the sum of the sampling points
                for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
                    for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
                        // Sample at the center of each (yInd, xInd) sub-bin.
                        T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd;
                        T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd;

                        // bilinear interpolation of point (x,y)
                        // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
                        uint32_t x1 = std::floor(static_cast<float>(x));
                        uint32_t y1 = std::floor(static_cast<float>(y));
                        uint32_t x2 = x1 + 1, y2 = y1 + 1;
                        T_Roi dx1 = x - static_cast<T_Roi>(x1);
                        T_Roi dy1 = y - static_cast<T_Roi>(y1);

                        // dealing with out of bound samples
                        if (x1 >= inWidth - 1) {
                            x1 = x2 = inWidth - 1;
                            dx1 = 0;
                        }
                        if (y1 >= inHeight - 1) {
                            y1 = y2 = inHeight - 1;
                            dy1 = 0;
                        }

                        // Bilinear weights and flat NHWC offsets of the four
                        // neighboring pixels around (x, y).
                        T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
                        T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
                        uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
                                              y1 * inWidth * inDepth + x2 * inDepth,
                                              y2 * inWidth * inDepth + x1 * inDepth,
                                              y2 * inWidth * inDepth + x2 * inDepth};

                        for (uint32_t k = 0; k < inDepth; k++) {
                            T_Input interpolation = 0;
                            for (uint32_t c = 0; c < 4; c++) {
                                interpolation += ws[c] * batchBase[offsets[c] + k];
                            }
                            outPtr[k] += interpolation;
                        }
                    }
                }

                // take average
                for (uint32_t k = 0; k < inDepth; k++)
                    outPtr[k] /= static_cast<T_Input>(numSamplingPoints);
                outPtr += inDepth;
            }
        }
    }
    return true;
}
182
// ROI Align, quantized path, NHWC layout. Mirrors roiAlignNhwc() above but keeps the
// accumulation in 32-bit integer arithmetic:
//   - ROI coordinates arrive as TENSOR_QUANT16_ASYMM values with scale 0.125
//     (enforced in prepare()), hence the "* 0.125f" dequantization below.
//   - The four bilinear weights are quantized with scale 1/255 (wScale) and summed
//     against zero-point-adjusted input values.
//   - The combined rescale factor (inputScale * wScale / outputScale /
//     numSamplingPoints) is folded into one quantized multiplier, so the averaging
//     over sampling points happens as part of the final rescale.
// Returns false on malformed ROI data or if the multiplier cannot be quantized.
template <typename T_Input>
inline bool roiAlignQuantNhwc(const T_Input* inputData, const Shape& inputShape,
                              const uint16_t* roiData, const Shape& roiShape,
                              const int32_t* batchSplitData, const Shape& batchSplitShape,
                              float heightStride, float widthStride, int32_t heightSamplingRatio,
                              int32_t widthSamplingRatio, T_Input* outputData,
                              const Shape& outputShape) {
    NNTRACE_TRANS("RoiAlignQuant8");

    // Quantization scale applied to the bilinear interpolation weights.
    constexpr float wScale = 1.0f / 255.0f;
    constexpr uint32_t kRoiDim = 4;
    const float heightScale = 1.0f / heightStride;
    const float widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
        // Dequantize the box corners (quant16 scale 0.125) and map them onto
        // feature-map coordinates.
        float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f;
        float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f;
        float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f;
        float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f;

        // Check for malformed data
        // 1. invalid batch id
        // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        // (batchId is unsigned here, so the >= 0 check is vacuous after the cast;
        // it documents the spec's requirement on the signed source value.)
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(wRoiStart <= inWidth);
        NN_RET_CHECK(hRoiStart <= inHeight);
        NN_RET_CHECK(wRoiEnd <= inWidth);
        NN_RET_CHECK(hRoiEnd <= inHeight);
        NN_RET_CHECK_LE(wRoiStart, wRoiEnd);
        NN_RET_CHECK_LE(hRoiStart, hRoiEnd);

        // Clamp the box to a minimum 1x1 size so every output cell samples something.
        float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f);
        float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f);
        float wStepSize = roiWidth / static_cast<float>(outWidth);
        float hStepSize = roiHeight / static_cast<float>(outHeight);

        // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
        uint32_t wSamplingRatio =
                widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize);
        uint32_t hSamplingRatio =
                heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize);
        int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
        float wBinSize = wStepSize / static_cast<float>(wSamplingRatio);
        float hBinSize = hStepSize / static_cast<float>(hSamplingRatio);

        // Fold input scale, weight scale, output scale, and the sampling-point
        // average into one fixed-point multiplier for the final rescale.
        float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints;
        int32_t outputMultiplier = 0;
        int32_t outputShift = 0;
        if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) {
            return false;
        }

        // Start of the feature-map slice belonging to this ROI's batch.
        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                float wStart = wStepSize * j + wRoiStart;
                float wEnd = wStepSize * (j + 1) + wRoiStart;
                float hStart = hStepSize * i + hRoiStart;
                float hEnd = hStepSize * (i + 1) + hRoiStart;

                // 32-bit integer accumulator per channel for this output cell.
                std::vector<int32_t> outTemp(inDepth, 0);
                // calculate the sum of the sampling points
                for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
                    for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
                        // Sample at the center of each (yInd, xInd) sub-bin.
                        float y = hStart + hBinSize / 2 + hBinSize * yInd;
                        float x = wStart + wBinSize / 2 + wBinSize * xInd;

                        // bilinear interpolation of point (x,y)
                        // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
                        uint32_t x1 = std::floor(x), y1 = std::floor(y);
                        uint32_t x2 = x1 + 1, y2 = y1 + 1;
                        float dx1 = x - static_cast<float>(x1);
                        float dy1 = y - static_cast<float>(y1);

                        // dealing with out of bound samples
                        if (x1 >= inWidth - 1) {
                            x1 = x2 = inWidth - 1;
                            dx1 = 0;
                        }
                        if (y1 >= inHeight - 1) {
                            y1 = y2 = inHeight - 1;
                            dy1 = 0;
                        }

                        // Bilinear weights and flat NHWC offsets of the four
                        // neighboring pixels around (x, y).
                        float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
                        float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
                        uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
                                              y1 * inWidth * inDepth + x2 * inDepth,
                                              y2 * inWidth * inDepth + x1 * inDepth,
                                              y2 * inWidth * inDepth + x2 * inDepth};

                        for (uint32_t k = 0; k < inDepth; k++) {
                            int32_t interpolation = 0;
                            for (uint32_t c = 0; c < 4; c++) {
                                // Quantize the weight to 1/255 steps and accumulate
                                // against the zero-point-adjusted input value.
                                int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale));
                                interpolation +=
                                        wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) -
                                                  inputShape.offset);
                            }
                            outTemp[k] += interpolation;
                        }
                    }
                }

                // take average and cast to output quantization
                // NOTE(review): outputShift is negated to match the shift-sign
                // convention of tflite::MultiplyByQuantizedMultiplier — verify
                // against that helper's contract if changing this.
                for (uint32_t k = 0; k < inDepth; k++) {
                    int32_t raw_out = tflite::MultiplyByQuantizedMultiplier(
                                              outTemp[k], outputMultiplier, -outputShift) +
                                      outputShape.offset;
                    outPtr[k] = saturateCast<T_Input>(raw_out);
                }
                outPtr += inDepth;
            }
        }
    }
    return true;
}
315
316 template <typename T_Input, typename T_Roi>
roiAlign(const T_Input * inputData,const Shape & inputShape,const T_Roi * roiData,const Shape & roiShape,const int32_t * batchSplitData,const Shape & batchSplitShape,float heightStride,float widthStride,int32_t heightSamplingRatio,int32_t widthSamplingRatio,bool useNchw,T_Input * outputData,const Shape & outputShape)317 inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
318 const Shape& roiShape, const int32_t* batchSplitData,
319 const Shape& batchSplitShape, float heightStride, float widthStride,
320 int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw,
321 T_Input* outputData, const Shape& outputShape) {
322 InputWithLayout<T_Input> input(useNchw);
323 OutputWithLayout<T_Input> output(useNchw);
324 NN_RET_CHECK(input.initialize(inputData, inputShape));
325 NN_RET_CHECK(output.initialize(outputData, outputShape));
326 if constexpr (std::is_same_v<T_Roi, uint16_t> &&
327 (std::is_same_v<T_Input, uint8_t> || std::is_same_v<T_Input, int8_t>)) {
328 NN_RET_CHECK(roiAlignQuantNhwc<T_Input>(
329 input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape, batchSplitData,
330 batchSplitShape, heightStride, widthStride, heightSamplingRatio, widthSamplingRatio,
331 output.getNhwcBuffer(), output.getNhwcShape()));
332 } else {
333 NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
334 batchSplitData, batchSplitShape, heightStride, widthStride,
335 heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(),
336 output.getNhwcShape()));
337 }
338 NN_RET_CHECK(output.commit());
339 return true;
340 }
341
342 } // namespace
343 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
344
validate(const IOperationValidationContext * context)345 Result<Version> validate(const IOperationValidationContext* context) {
346 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
347 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
348 std::vector<OperandType> inExpectedTypes;
349 auto inputType = context->getInputType(kInputTensor);
350 if (inputType == OperandType::TENSOR_FLOAT32) {
351 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
352 OperandType::TENSOR_INT32, OperandType::INT32,
353 OperandType::INT32, OperandType::FLOAT32,
354 OperandType::FLOAT32, OperandType::INT32,
355 OperandType::INT32, OperandType::BOOL};
356 } else if (inputType == OperandType::TENSOR_FLOAT16) {
357 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
358 OperandType::TENSOR_INT32, OperandType::INT32,
359 OperandType::INT32, OperandType::FLOAT16,
360 OperandType::FLOAT16, OperandType::INT32,
361 OperandType::INT32, OperandType::BOOL};
362 } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM ||
363 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
364 inExpectedTypes = {inputType,
365 OperandType::TENSOR_QUANT16_ASYMM,
366 OperandType::TENSOR_INT32,
367 OperandType::INT32,
368 OperandType::INT32,
369 OperandType::FLOAT32,
370 OperandType::FLOAT32,
371 OperandType::INT32,
372 OperandType::INT32,
373 OperandType::BOOL};
374 } else {
375 return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
376 }
377 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
378 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
379 if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
380 return Version::ANDROID_R;
381 } else {
382 return Version::ANDROID_Q;
383 }
384 }
385
386 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
prepare(IOperationExecutionContext * context)387 bool prepare(IOperationExecutionContext* context) {
388 bool useNchw = context->getInputValue<bool>(kLayoutScalar);
389 Shape input = context->getInputShape(kInputTensor);
390 Shape roiShape = context->getInputShape(kRoiTensor);
391 Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
392 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
393 NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);
394
395 uint32_t numBatches = getSizeOfDimension(input, 0);
396 uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
397 uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
398 uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
399 uint32_t numRois = getSizeOfDimension(roiShape, 0);
400 // Every dimension must be positive except for numRois.
401 NN_RET_CHECK_GT(numBatches, 0);
402 NN_RET_CHECK_GT(inHeight, 0);
403 NN_RET_CHECK_GT(inWidth, 0);
404 NN_RET_CHECK_GT(inDepth, 0);
405 NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
406 NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);
407
408 int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
409 int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
410 int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar);
411 int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar);
412 float heightScale, widthScale;
413 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
414 heightScale = context->getInputValue<_Float16>(kHeightStrideSalar);
415 widthScale = context->getInputValue<_Float16>(kWidthStrideScalar);
416 } else {
417 heightScale = context->getInputValue<float>(kHeightStrideSalar);
418 widthScale = context->getInputValue<float>(kWidthStrideScalar);
419 }
420 NN_RET_CHECK_GT(outputHeight, 0);
421 NN_RET_CHECK_GT(outputWidth, 0);
422 NN_RET_CHECK_GT(heightScale, 0);
423 NN_RET_CHECK_GT(widthScale, 0);
424 // Sampling ratio can set to 0 for adaptive value.
425 NN_RET_CHECK_GE(heightSamplingRatio, 0);
426 NN_RET_CHECK_GE(widthSamplingRatio, 0);
427
428 if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
429 NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
430 NN_RET_CHECK_EQ(roiShape.offset, 0);
431 }
432
433 Shape output = context->getOutputShape(kOutputTensor);
434 output.type = input.type;
435 if (useNchw) {
436 output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
437 static_cast<uint32_t>(outputWidth)};
438 } else {
439 output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
440 static_cast<uint32_t>(outputWidth), inDepth};
441 }
442 return context->setOutputShape(kOutputTensor, output);
443 }
444
execute(IOperationExecutionContext * context)445 bool execute(IOperationExecutionContext* context) {
446 // Bypass execution in the case of zero-sized input.
447 if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true;
448 switch (context->getInputType(kInputTensor)) {
449 case OperandType::TENSOR_FLOAT16:
450 return roiAlign(context->getInputBuffer<_Float16>(kInputTensor),
451 context->getInputShape(kInputTensor),
452 context->getInputBuffer<_Float16>(kRoiTensor),
453 context->getInputShape(kRoiTensor),
454 context->getInputBuffer<int32_t>(kBatchSplitTensor),
455 context->getInputShape(kBatchSplitTensor),
456 context->getInputValue<_Float16>(kHeightStrideSalar),
457 context->getInputValue<_Float16>(kWidthStrideScalar),
458 context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
459 context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
460 context->getInputValue<bool>(kLayoutScalar),
461 context->getOutputBuffer<_Float16>(kOutputTensor),
462 context->getOutputShape(kOutputTensor));
463 case OperandType::TENSOR_FLOAT32:
464 return roiAlign(context->getInputBuffer<float>(kInputTensor),
465 context->getInputShape(kInputTensor),
466 context->getInputBuffer<float>(kRoiTensor),
467 context->getInputShape(kRoiTensor),
468 context->getInputBuffer<int32_t>(kBatchSplitTensor),
469 context->getInputShape(kBatchSplitTensor),
470 context->getInputValue<float>(kHeightStrideSalar),
471 context->getInputValue<float>(kWidthStrideScalar),
472 context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
473 context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
474 context->getInputValue<bool>(kLayoutScalar),
475 context->getOutputBuffer<float>(kOutputTensor),
476 context->getOutputShape(kOutputTensor));
477 case OperandType::TENSOR_QUANT8_ASYMM:
478 return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor),
479 context->getInputShape(kInputTensor),
480 context->getInputBuffer<uint16_t>(kRoiTensor),
481 context->getInputShape(kRoiTensor),
482 context->getInputBuffer<int32_t>(kBatchSplitTensor),
483 context->getInputShape(kBatchSplitTensor),
484 context->getInputValue<float>(kHeightStrideSalar),
485 context->getInputValue<float>(kWidthStrideScalar),
486 context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
487 context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
488 context->getInputValue<bool>(kLayoutScalar),
489 context->getOutputBuffer<uint8_t>(kOutputTensor),
490 context->getOutputShape(kOutputTensor));
491 case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
492 return roiAlign(context->getInputBuffer<int8_t>(kInputTensor),
493 context->getInputShape(kInputTensor),
494 context->getInputBuffer<uint16_t>(kRoiTensor),
495 context->getInputShape(kRoiTensor),
496 context->getInputBuffer<int32_t>(kBatchSplitTensor),
497 context->getInputShape(kBatchSplitTensor),
498 context->getInputValue<float>(kHeightStrideSalar),
499 context->getInputValue<float>(kWidthStrideScalar),
500 context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
501 context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
502 context->getInputValue<bool>(kLayoutScalar),
503 context->getOutputBuffer<int8_t>(kOutputTensor),
504 context->getOutputShape(kOutputTensor));
505 default:
506 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
507 }
508 }
509 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
510
511 } // namespace roi_align
512
513 NN_REGISTER_OPERATION(ROI_ALIGN, roi_align::kOperationName, roi_align::validate, roi_align::prepare,
514 roi_align::execute, .allowZeroSizedInput = true);
515
516 } // namespace nn
517 } // namespace android
518