1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include <cmath>
20 
21 #include "OperationResolver.h"
22 #include "OperationsUtils.h"
23 #include "Tracing.h"
24 
25 namespace android {
26 namespace nn {
27 namespace neg {
28 
// Operation name; used for registration and in the error messages below.
constexpr char kOperationName[] = "NEG";

// NEG has a single input operand, located at index 0 ...
constexpr uint32_t kNumInputs{1};
constexpr uint32_t kInputTensor{0};

// ... and a single output operand, also at index 0.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
36 
37 namespace {
38 
39 template <typename T>
compute(const T * input,const Shape & shape,T * output)40 inline bool compute(const T* input, const Shape& shape, T* output) {
41     const auto size = getNumberOfElements(shape);
42     for (uint32_t i = 0; i < size; ++i) {
43         output[i] = -input[i];
44     }
45     return true;
46 }
47 
48 }  // namespace
49 
// Validates the operand signature of a NEG operation.
// Checks that:
//   - there are exactly kNumInputs inputs and kNumOutputs outputs;
//   - the input operand type is TENSOR_FLOAT16, TENSOR_FLOAT32 or TENSOR_INT32;
//   - validateInputTypes / validateOutputTypes accept {inputType} for the
//     input and the output respectively, i.e. the output is checked against
//     the same type as the input.
// Returns Version::ANDROID_Q on success (presumably the earliest version that
// supports NEG with these types; TODO confirm). Failures are reported through
// the NN_RET_CHECK* macros.
Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    // Only these tensor types have a compute() instantiation in execute().
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32)
            << "Unsupported tensor type for operation " << kOperationName;
    NN_RET_CHECK(validateInputTypes(context, {inputType}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return Version::ANDROID_Q;
}
61 
// Prepares a NEG operation for execution by deriving the output shape.
// Reads the input and current output shapes from `context` and lets SetShape
// fill `output` from `input` (NEG is element-wise, so presumably the output
// simply takes the input's dimensions; confirm against SetShape). Then
// publishes the result via setOutputShape.
// Returns false via NN_RET_CHECK if SetShape fails; otherwise returns the
// result of setOutputShape.
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK(SetShape(input, &output));
    return context->setOutputShape(kOutputTensor, output);
}
68 
execute(IOperationExecutionContext * context)69 bool execute(IOperationExecutionContext* context) {
70     switch (context->getInputType(kInputTensor)) {
71         case OperandType::TENSOR_FLOAT16:
72             return compute(context->getInputBuffer<_Float16>(kInputTensor),
73                            context->getInputShape(kInputTensor),
74                            context->getOutputBuffer<_Float16>(kOutputTensor));
75         case OperandType::TENSOR_FLOAT32:
76             return compute(context->getInputBuffer<float>(kInputTensor),
77                            context->getInputShape(kInputTensor),
78                            context->getOutputBuffer<float>(kOutputTensor));
79         case OperandType::TENSOR_INT32:
80             return compute(context->getInputBuffer<int32_t>(kInputTensor),
81                            context->getInputShape(kInputTensor),
82                            context->getOutputBuffer<int32_t>(kOutputTensor));
83         default:
84             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
85     }
86 }
87 
88 }  // namespace neg
89 
// Registers the NEG operation under neg::kOperationName, wiring up the
// neg::validate, neg::prepare and neg::execute callbacks defined above.
NN_REGISTER_OPERATION(NEG, neg::kOperationName, neg::validate, neg::prepare, neg::execute);
91 
92 }  // namespace nn
93 }  // namespace android
94