1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <cmath>
20
21 #include "OperationResolver.h"
22 #include "OperationsUtils.h"
23 #include "Tracing.h"
24
25 namespace android {
26 namespace nn {
27 namespace neg {
28
// Name under which this operation is registered and reported in diagnostics.
constexpr char kOperationName[] = "NEG";

// Operand layout: exactly one input tensor and one output tensor, each at
// index 0 of its respective operand list.
constexpr uint32_t kNumInputs{1};
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kInputTensor{0};
constexpr uint32_t kOutputTensor{0};
36
37 namespace {
38
39 template <typename T>
compute(const T * input,const Shape & shape,T * output)40 inline bool compute(const T* input, const Shape& shape, T* output) {
41 const auto size = getNumberOfElements(shape);
42 for (uint32_t i = 0; i < size; ++i) {
43 output[i] = -input[i];
44 }
45 return true;
46 }
47
48 } // namespace
49
validate(const IOperationValidationContext * context)50 Result<Version> validate(const IOperationValidationContext* context) {
51 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
52 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
53 OperandType inputType = context->getInputType(kInputTensor);
54 NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
55 inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_INT32)
56 << "Unsupported tensor type for operation " << kOperationName;
57 NN_RET_CHECK(validateInputTypes(context, {inputType}));
58 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
59 return Version::ANDROID_Q;
60 }
61
prepare(IOperationExecutionContext * context)62 bool prepare(IOperationExecutionContext* context) {
63 Shape input = context->getInputShape(kInputTensor);
64 Shape output = context->getOutputShape(kOutputTensor);
65 NN_RET_CHECK(SetShape(input, &output));
66 return context->setOutputShape(kOutputTensor, output);
67 }
68
execute(IOperationExecutionContext * context)69 bool execute(IOperationExecutionContext* context) {
70 switch (context->getInputType(kInputTensor)) {
71 case OperandType::TENSOR_FLOAT16:
72 return compute(context->getInputBuffer<_Float16>(kInputTensor),
73 context->getInputShape(kInputTensor),
74 context->getOutputBuffer<_Float16>(kOutputTensor));
75 case OperandType::TENSOR_FLOAT32:
76 return compute(context->getInputBuffer<float>(kInputTensor),
77 context->getInputShape(kInputTensor),
78 context->getOutputBuffer<float>(kOutputTensor));
79 case OperandType::TENSOR_INT32:
80 return compute(context->getInputBuffer<int32_t>(kInputTensor),
81 context->getInputShape(kInputTensor),
82 context->getOutputBuffer<int32_t>(kOutputTensor));
83 default:
84 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
85 }
86 }
87
88 } // namespace neg
89
// Hooks the NEG operation up under the name neg::kOperationName with the
// neg::validate, neg::prepare and neg::execute callbacks defined in this file.
// NOTE(review): the macro comes from OperationResolver.h — presumably it adds
// an entry to the operation resolver's registry; confirm against its definition.
90 NN_REGISTER_OPERATION(NEG, neg::kOperationName, neg::validate, neg::prepare, neg::execute);
91
92 } // namespace nn
93 } // namespace android
94