1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
#include <algorithm>

#include "OperationResolver.h"
#include "OperationsUtils.h"
21 
22 namespace android {
23 namespace nn {
24 namespace fill_op {
25 
// FILL operand layout.
// Inputs: [0] the 1-D int32 shape tensor, [1] the scalar fill value.
constexpr uint32_t kNumInputs{2};
constexpr uint32_t kDimsTensor{0};
constexpr uint32_t kValueScalar{1};

// Outputs: [0] the tensor that receives the fill value in every element.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
32 
33 namespace {
34 
35 template <typename T>
executeTyped(IOperationExecutionContext * context)36 bool executeTyped(IOperationExecutionContext* context) {
37     T* output = context->getOutputBuffer<T>(kOutputTensor);
38     const int numElements = getNumberOfElements(context->getOutputShape(kOutputTensor));
39     const T value = context->getInputValue<T>(kValueScalar);
40     for (int i = 0; i < numElements; ++i) {
41         output[i] = value;
42     }
43     return true;
44 }
45 
getValueType(OperandType outputType,OperandType * valueType)46 bool getValueType(OperandType outputType, OperandType* valueType) {
47     switch (outputType) {
48         case OperandType::TENSOR_FLOAT16:
49             *valueType = OperandType::FLOAT16;
50             return true;
51         case OperandType::TENSOR_FLOAT32:
52             *valueType = OperandType::FLOAT32;
53             return true;
54         case OperandType::TENSOR_INT32:
55             *valueType = OperandType::INT32;
56             return true;
57         default:
58             NN_RET_CHECK_FAIL() << "Unsupported value type for fill op: " << outputType;
59     }
60 }
61 
62 }  // namespace
63 
validate(const IOperationValidationContext * context)64 Result<Version> validate(const IOperationValidationContext* context) {
65     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
66     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
67     // Check output type first because input value type is dependent on the
68     // output type.
69     OperandType outputType = context->getOutputType(kOutputTensor);
70     NN_RET_CHECK(outputType == OperandType::TENSOR_FLOAT16 ||
71                  outputType == OperandType::TENSOR_FLOAT32 ||
72                  outputType == OperandType::TENSOR_INT32)
73             << "Unsupported output type for fill op: " << outputType;
74     NN_RET_CHECK(validateOutputTypes(context, {outputType}));
75 
76     OperandType valueType;
77     NN_RET_CHECK(getValueType(outputType, &valueType));
78     NN_RET_CHECK(validateInputTypes(context, {OperandType::TENSOR_INT32, valueType}));
79 
80     return Version::ANDROID_R;
81 }
82 
prepare(IOperationExecutionContext * context)83 bool prepare(IOperationExecutionContext* context) {
84     Shape dimsShape = context->getInputShape(kDimsTensor);
85     NN_RET_CHECK_EQ(getNumberOfDimensions(dimsShape), 1);
86 
87     Shape outputShape = context->getOutputShape(kOutputTensor);
88     outputShape.dimensions.resize(dimsShape.dimensions[0]);
89     const int32_t* dims = context->getInputBuffer<int32_t>(kDimsTensor);
90     for (int i = 0; i < dimsShape.dimensions[0]; ++i) {
91         outputShape.dimensions[i] = dims[i];
92     }
93     return context->setOutputShape(kOutputTensor, outputShape);
94 }
95 
execute(IOperationExecutionContext * context)96 bool execute(IOperationExecutionContext* context) {
97     switch (context->getInputType(kValueScalar)) {
98         case OperandType::FLOAT16:
99             return executeTyped<_Float16>(context);
100         case OperandType::FLOAT32:
101             return executeTyped<float>(context);
102         case OperandType::INT32:
103             return executeTyped<int32_t>(context);
104         default:
105             NN_RET_CHECK_FAIL() << "Unsupported value type for fill op.";
106     }
107 }
108 
109 }  // namespace fill_op
110 
// Registers the FILL operation (name "FILL") with its validate, prepare and
// execute callbacks from the fill_op namespace.
NN_REGISTER_OPERATION(FILL, "FILL", fill_op::validate, fill_op::prepare, fill_op::execute);
112 
113 }  // namespace nn
114 }  // namespace android
115