1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "argmax_builder.h"
17 
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Arity checked by CheckIOIndex in Build: exactly one input and one output tensor.
static constexpr int INPUT_NUM = 1;
static constexpr int OUTPUT_NUM = 1;
// Upper bound on parameter tensors passed to CheckParamIndex in Build
// (the setters below cover axis, topK, keepDims and outMaxValue).
static constexpr int PARAM_MAX_NUM = 4;
// Operator name stored into m_name once Build succeeds.
static const std::string OP_NAME{"ArgMax"};
25 
ArgMaxBuilder()26 ArgMaxBuilder::ArgMaxBuilder() {}
27 
~ArgMaxBuilder()28 ArgMaxBuilder::~ArgMaxBuilder() {}
29 
SetAxis(const std::shared_ptr<NNTensor> & tensor)30 OH_NN_ReturnCode ArgMaxBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
31 {
32     tensor->IdentifyOpParameter();
33 
34     if (tensor->GetDataType() != OH_NN_INT64) {
35         LOGE("[ArgMax] SetAxis failed, the axis should be type OH_NN_INT64.");
36         return OH_NN_INVALID_PARAMETER;
37     }
38 
39     void* buffer = tensor->GetBuffer();
40     if (buffer == nullptr) {
41         LOGE("[ArgMax] SetAxis GetBuffer return nullptr.");
42         return OH_NN_INVALID_PARAMETER;
43     }
44 
45     m_axis = *(static_cast<int64_t*>(buffer));
46     return OH_NN_SUCCESS;
47 }
48 
SetTopK(const std::shared_ptr<NNTensor> & tensor)49 OH_NN_ReturnCode ArgMaxBuilder::SetTopK(const std::shared_ptr<NNTensor>& tensor)
50 {
51     tensor->IdentifyOpParameter();
52 
53     if (tensor->GetDataType() != OH_NN_INT64) {
54         LOGE("[ArgMax] SetTopK failed, the topK should be type OH_NN_INT64.");
55         return OH_NN_INVALID_PARAMETER;
56     }
57 
58     void* buffer = tensor->GetBuffer();
59     if (buffer == nullptr) {
60         LOGE("[ArgMax] SetTopK GetBuffer return nullptr.");
61         return OH_NN_INVALID_PARAMETER;
62     }
63 
64     m_topK = *(static_cast<int64_t*>(buffer));
65     return OH_NN_SUCCESS;
66 }
67 
SetKeepdims(const std::shared_ptr<NNTensor> & tensor)68 OH_NN_ReturnCode ArgMaxBuilder::SetKeepdims(const std::shared_ptr<NNTensor>& tensor)
69 {
70     tensor->IdentifyOpParameter();
71 
72     if (tensor->GetDataType() != OH_NN_BOOL) {
73         LOGE("[ArgMax] SetKeepdims failed, the keep_dims should be type OH_NN_BOOL.");
74         return OH_NN_INVALID_PARAMETER;
75     }
76 
77     void* buffer = tensor->GetBuffer();
78     if (buffer == nullptr) {
79         LOGE("[ArgMax] SetKeepdims GetBuffer return nullptr.");
80         return OH_NN_INVALID_PARAMETER;
81     }
82     m_keepDims = *(static_cast<bool*>(buffer));
83 
84     return OH_NN_SUCCESS;
85 }
86 
SetOutMaxValue(const std::shared_ptr<NNTensor> & tensor)87 OH_NN_ReturnCode ArgMaxBuilder::SetOutMaxValue(const std::shared_ptr<NNTensor>& tensor)
88 {
89     tensor->IdentifyOpParameter();
90 
91     if (tensor->GetDataType() != OH_NN_BOOL) {
92         LOGE("[ArgMax] SetOutMaxValue failed, the outMaxValue should be type OH_NN_BOOL.");
93         return OH_NN_INVALID_PARAMETER;
94     }
95 
96     void* buffer = tensor->GetBuffer();
97     if (buffer == nullptr) {
98         LOGE("[ArgMax] SetOutMaxValue GetBuffer return nullptr.");
99         return OH_NN_INVALID_PARAMETER;
100     }
101     m_outMaxValue = *(static_cast<bool*>(buffer));
102 
103     return OH_NN_SUCCESS;
104 }
105 
106 /**
107  * Build method.
108  * 1.build primitive of ops.
109  * 2.build inputIndex of ops.
110  * 3.build outputIndex of ops.
111  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)112 OH_NN_ReturnCode ArgMaxBuilder::Build(const std::vector<uint32_t>& paramsIndex,
113     const std::vector<uint32_t>& inputsIndex, const std::vector<uint32_t>& outputsIndex,
114     const std::vector<std::shared_ptr<NNTensor>>& allTensors)
115 {
116     if (m_isBuild) {
117         LOGE("[ArgMax] Build failed, build operation has been completed, cannot build again.");
118         return OH_NN_OPERATION_FORBIDDEN;
119     }
120 
121     OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
122     if (returnCode != OH_NN_SUCCESS) {
123         LOGE("[ArgMax] Build failed, passed invalid input or output index.");
124         return returnCode;
125     }
126     m_inputsIndex = inputsIndex;
127     m_outputsIndex = outputsIndex;
128 
129     returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
130     if (returnCode != OH_NN_SUCCESS) {
131         LOGE("[ArgMax] Build failed, passed invalid param index.");
132         return returnCode;
133     }
134 
135     for (int i : paramsIndex) {
136         const std::shared_ptr<NNTensor> tensor = allTensors[i];
137         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
138             returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
139         } else {
140             LOGE("[ArgMax] Build failed, param invalid, type=%d", tensor->GetType());
141             return OH_NN_INVALID_PARAMETER;
142         }
143 
144         if (returnCode != OH_NN_SUCCESS) {
145             LOGE("[ArgMax] Build failed, passed invalid param.");
146             return returnCode;
147         }
148     }
149 
150     // The quantization type of the first output determinies that of the operator.
151     SetQuantType(outputsIndex, allTensors);
152 
153     m_name = OP_NAME;
154     m_isBuild = true;
155     return OH_NN_SUCCESS;
156 }
157 
GetPrimitive()158 LiteGraphPrimitvePtr ArgMaxBuilder::GetPrimitive()
159 {
160     if (!m_isBuild) {
161         LOGE("[ArgMax] GetPrimitive failed, cannot get primitive before call build.");
162         return {nullptr, DestroyLiteGraphPrimitive};
163     }
164 
165     void* primitive = mindspore::lite::MindIR_ArgMaxFusion_CreatePrimitive(m_axis, m_topK, m_keepDims, m_outMaxValue);
166     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
167     return graphPrimitivePtr;
168 }
// Registers ArgMaxBuilder under OH_NN_OPS_ARG_MAX; presumably this is how the runtime
// locates the builder for that op type (REGISTER_OPS is defined outside this file — confirm).
REGISTER_OPS(ArgMaxBuilder, OH_NN_OPS_ARG_MAX);
170 } // namespace Ops
171 } // namespace NeuralNetworkRuntime
172 } // namespace OHOS
173