/*
 * Copyright (c) 2022 Huawei Device Co., Ltd.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "argmax_builder.h"

#include "mindir.h"  // MindIR_ArgMaxFusion_CreatePrimitive

namespace OHOS {
namespace NeuralNetworkRuntime {
namespace Ops {
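// ArgMax takes one input tensor and produces one output tensor; at most four
// parameter tensors (axis, topK, keepDims and outMaxValue) may be supplied.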
static const int INPUT_NUM = 1;
static const int OUTPUT_NUM = 1;
static const int PARAM_MAX_NUM = 4;
static const std::string OP_NAME = "ArgMax";

ArgMaxBuilder::ArgMaxBuilder() {}

ArgMaxBuilder::~ArgMaxBuilder() {}

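/**
 * Parses the axis parameter, i.e. the dimension along which the index of the
 * maximum value is computed. Expects a scalar tensor of type OH_NN_INT64.
 */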
OH_NN_ReturnCode ArgMaxBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
{
    tensor->IdentifyOpParameter();

    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[ArgMax] SetAxis failed, the axis should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[ArgMax] SetAxis GetBuffer return nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }

    m_axis = *(static_cast<int64_t*>(buffer));
    return OH_NN_SUCCESS;
}

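/**
 * Parses the topK parameter, i.e. how many of the largest values are reported.
 * Expects a scalar tensor of type OH_NN_INT64.
 */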
OH_NN_ReturnCode ArgMaxBuilder::SetTopK(const std::shared_ptr<NNTensor>& tensor)
{
    tensor->IdentifyOpParameter();

    if (tensor->GetDataType() != OH_NN_INT64) {
        LOGE("[ArgMax] SetTopK failed, the topK should be type OH_NN_INT64.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[ArgMax] SetTopK GetBuffer return nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }

    m_topK = *(static_cast<int64_t*>(buffer));
    return OH_NN_SUCCESS;
}

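/**
 * Parses the keepDims parameter: when true, the reduced axis is kept in the
 * output shape with length 1. Expects a scalar tensor of type OH_NN_BOOL.
 */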
OH_NN_ReturnCode ArgMaxBuilder::SetKeepdims(const std::shared_ptr<NNTensor>& tensor)
{
    tensor->IdentifyOpParameter();

    if (tensor->GetDataType() != OH_NN_BOOL) {
        LOGE("[ArgMax] SetKeepdims failed, the keep_dims should be type OH_NN_BOOL.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[ArgMax] SetKeepdims GetBuffer return nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_keepDims = *(static_cast<bool*>(buffer));

    return OH_NN_SUCCESS;
}

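/**
 * Parses the outMaxValue parameter: when true, the operator also outputs the
 * maximum values along with their indices. Expects a scalar tensor of type OH_NN_BOOL.
 */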
OH_NN_ReturnCode ArgMaxBuilder::SetOutMaxValue(const std::shared_ptr<NNTensor>& tensor)
{
    tensor->IdentifyOpParameter();

    if (tensor->GetDataType() != OH_NN_BOOL) {
        LOGE("[ArgMax] SetOutMaxValue failed, the outMaxValue should be type OH_NN_BOOL.");
        return OH_NN_INVALID_PARAMETER;
    }

    void* buffer = tensor->GetBuffer();
    if (buffer == nullptr) {
        LOGE("[ArgMax] SetOutMaxValue GetBuffer return nullptr.");
        return OH_NN_INVALID_PARAMETER;
    }
    m_outMaxValue = *(static_cast<bool*>(buffer));

    return OH_NN_SUCCESS;
}

/**
 * Build method.
 * 1. Check and store the input and output indices of the op.
 * 2. Parse the parameter tensors (axis, topK, keepDims, outMaxValue).
 * 3. Record the quantization type and operator name.
 */
OH_NN_ReturnCode ArgMaxBuilder::Build(const std::vector<uint32_t>& paramsIndex,
    const std::vector<uint32_t>& inputsIndex, const std::vector<uint32_t>& outputsIndex,
    const std::vector<std::shared_ptr<NNTensor>>& allTensors)
{
    if (m_isBuild) {
        LOGE("[ArgMax] Build failed, build operation has been completed, cannot build again.");
        return OH_NN_OPERATION_FORBIDDEN;
    }

    OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
    if (returnCode != OH_NN_SUCCESS) {
        LOGE("[ArgMax] Build failed, passed invalid input or output index.");
        return returnCode;
    }
    m_inputsIndex = inputsIndex;
    m_outputsIndex = outputsIndex;

    returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
    if (returnCode != OH_NN_SUCCESS) {
        LOGE("[ArgMax] Build failed, passed invalid param index.");
        return returnCode;
    }

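    // Dispatch every parameter tensor to its setter (SetAxis, SetTopK, SetKeepdims
    // or SetOutMaxValue) through m_paramMap, which is keyed by the tensor's parameter type.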
    for (int i : paramsIndex) {
        const std::shared_ptr<NNTensor> tensor = allTensors[i];
        if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
            returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
        } else {
            LOGE("[ArgMax] Build failed, param invalid, type=%d", tensor->GetType());
            return OH_NN_INVALID_PARAMETER;
        }

        if (returnCode != OH_NN_SUCCESS) {
            LOGE("[ArgMax] Build failed, passed invalid param.");
            return returnCode;
        }
    }

    // The quantization type of the first output determines that of the operator.
    SetQuantType(outputsIndex, allTensors);

    m_name = OP_NAME;
    m_isBuild = true;
    return OH_NN_SUCCESS;
}

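/**
 * Wraps the parsed attributes into a MindIR ArgMaxFusion primitive.
 * Returns a null primitive if Build() has not been called successfully.
 */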
LiteGraphPrimitvePtr ArgMaxBuilder::GetPrimitive()
{
    if (!m_isBuild) {
        LOGE("[ArgMax] GetPrimitive failed, cannot get primitive before calling Build.");
        return {nullptr, DestroyLiteGraphPrimitive};
    }

    void* primitive = mindspore::lite::MindIR_ArgMaxFusion_CreatePrimitive(m_axis, m_topK, m_keepDims, m_outMaxValue);
    LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
    return graphPrimitivePtr;
}
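// Registers ArgMaxBuilder as the builder for the OH_NN_OPS_ARG_MAX operation type.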
REGISTER_OPS(ArgMaxBuilder, OH_NN_OPS_ARG_MAX);
} // namespace Ops
} // namespace NeuralNetworkRuntime
} // namespace OHOS