1 /*
2 * Copyright (c) 2022 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "fullconnection_builder.h"
17
18 #include "transform.h"
19 #include "validation.h"
20
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
// Number of output tensors a FullConnection operation must have.
static constexpr int OUTPUT_NUM{1};
// Maximum number of parameter tensors accepted by Build() (passed to CheckParamIndex).
static constexpr int PARAM_MAX_NUM{4};
// Element count a scalar parameter tensor must have.
static constexpr int SCALAR_LENGTH{1};
// Operator name assigned to m_name once Build() succeeds.
static const std::string OP_NAME{"FullConnection"};
28
FullConnectionBuilder()29 FullConnectionBuilder::FullConnectionBuilder() {}
30
~FullConnectionBuilder()31 FullConnectionBuilder::~FullConnectionBuilder() {}
32
SetFullConnectionInput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)33 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionInput(const std::vector<uint32_t>& inputsIndex,
34 const std::vector<uint32_t>& outputsIndex,
35 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
36 {
37 if (outputsIndex.size() != OUTPUT_NUM) {
38 LOGE("[FullConnection] SetFullConnectionInput failed, the index of outputs don't equal to %d.", OUTPUT_NUM);
39 return OH_NN_INVALID_PARAMETER;
40 }
41
42 size_t allTensorsSize = allTensors.size();
43 bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorsSize](uint32_t index) {
44 return index >= allTensorsSize;
45 });
46 if (isOverTensorSize) {
47 LOGE("[FullConnection] SetFullConnectionInput failed, the index of inputs is out of range.");
48 return OH_NN_INVALID_PARAMETER;
49 }
50
51 m_inputsIndex = inputsIndex;
52 m_outputsIndex = outputsIndex;
53
54 return OH_NN_SUCCESS;
55 }
56
SetHasBias(const std::shared_ptr<NNTensor> & tensor)57 OH_NN_ReturnCode FullConnectionBuilder::SetHasBias(const std::shared_ptr<NNTensor>& tensor)
58 {
59 if (tensor->GetDataType() != OH_NN_BOOL) {
60 LOGE("[FullConnection] The hasBias should be type OH_NN_BOOL.");
61 return OH_NN_INVALID_PARAMETER;
62 }
63
64 if (tensor->GetElementCount() != SCALAR_LENGTH) {
65 LOGE("[FullConnection] The hasBias should be scalar.");
66 return OH_NN_INVALID_PARAMETER;
67 }
68
69 void* buffer = tensor->GetBuffer();
70 if (buffer == nullptr) {
71 LOGE("[FullConnection] Tensor buffer is nullptr.");
72 return OH_NN_INVALID_PARAMETER;
73 }
74 m_hasBias = *(static_cast<bool*>(buffer));
75
76 return OH_NN_SUCCESS;
77 }
78
SetUseAxis(const std::shared_ptr<NNTensor> & tensor)79 OH_NN_ReturnCode FullConnectionBuilder::SetUseAxis(const std::shared_ptr<NNTensor>& tensor)
80 {
81 if (tensor->GetDataType() != OH_NN_BOOL) {
82 LOGE("[FullConnection] The useAxis should be type OH_NN_BOOL.");
83 return OH_NN_INVALID_PARAMETER;
84 }
85
86 if (tensor->GetElementCount() != SCALAR_LENGTH) {
87 LOGE("[FullConnection] The useAxis should be scalar.");
88 return OH_NN_INVALID_PARAMETER;
89 }
90
91 void* buffer = tensor->GetBuffer();
92 if (buffer == nullptr) {
93 LOGE("[FullConnection] Tensor buffer is nullptr.");
94 return OH_NN_INVALID_PARAMETER;
95 }
96
97 bool useAxis = *(static_cast<bool*>(buffer));
98 if (m_axisIsSet && !useAxis) {
99 LOGE("[FullConnection] m_useAxis is not allowed to be set to false when m_axis is already set.");
100 return OH_NN_INVALID_PARAMETER;
101 }
102
103 m_useAxis = useAxis;
104 m_useAxisIsSet = true;
105 return OH_NN_SUCCESS;
106 }
107
SetFullConnectionActivation(const std::shared_ptr<NNTensor> & tensor)108 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionActivation(const std::shared_ptr<NNTensor>& tensor)
109 {
110 tensor->IdentifyOpParameter();
111 // Set Activation
112 if (tensor->GetElementCount() != SCALAR_LENGTH) {
113 LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation shoule be a scalar");
114 return OH_NN_INVALID_PARAMETER;
115 }
116
117 if (tensor->GetDataType() != OH_NN_INT8) {
118 LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation should have type OH_NN_INT8.");
119 return OH_NN_INVALID_PARAMETER;
120 }
121
122 void* buffer = tensor->GetBuffer();
123 if (buffer == nullptr) {
124 LOGE("[FullConnection] SetFullConnectionActivation GetBuffer return nullptr");
125 return OH_NN_INVALID_PARAMETER;
126 }
127
128 int8_t* pFuseData = static_cast<int8_t*>(tensor->GetBuffer());
129 if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*pFuseData))) {
130 LOGE("[FullConnection] SetFullConnectionActivation failed, activation input is invalid.");
131 return OH_NN_INVALID_PARAMETER;
132 }
133 m_activationType = NNToMS::TransfromFusionType((OH_NN_FuseType)(*pFuseData));
134
135 return OH_NN_SUCCESS;
136 }
137
SetAxis(const std::shared_ptr<NNTensor> & tensor)138 OH_NN_ReturnCode FullConnectionBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
139 {
140 tensor->IdentifyOpParameter();
141
142 if (tensor->GetElementCount() != SCALAR_LENGTH) {
143 LOGE("[FullConnection] SetAxis failed, the axis shoule be a scalar");
144 return OH_NN_INVALID_PARAMETER;
145 }
146
147 if (tensor->GetDataType() != OH_NN_INT64) {
148 LOGE("[FullConnection] SetAxis failed, the Axis should be type OH_NN_INT64.");
149 return OH_NN_INVALID_PARAMETER;
150 }
151
152 void* buffer = tensor->GetBuffer();
153 if (buffer == nullptr) {
154 LOGE("[FullConnection] SetAxis GetBuffer return nullptr");
155 return OH_NN_INVALID_PARAMETER;
156 }
157
158 if (m_useAxisIsSet && !m_useAxis) {
159 LOGE("[FullConnection] m_useAxis has been set to false, axis is not allowed.");
160 return OH_NN_INVALID_PARAMETER;
161 }
162
163 m_axis = *static_cast<int64_t*>(buffer);
164 m_useAxis = true;
165 m_axisIsSet = true;
166 return OH_NN_SUCCESS;
167 }
168
169
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)170 OH_NN_ReturnCode FullConnectionBuilder::Build(const std::vector<uint32_t>& paramsIndex,
171 const std::vector<uint32_t>& inputsIndex,
172 const std::vector<uint32_t>& outputsIndex,
173 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
174 {
175 if (m_isBuild) {
176 LOGE("[FullConnection] Build failed, operation has been build, cannot build again.");
177 return OH_NN_OPERATION_FORBIDDEN;
178 }
179
180 OH_NN_ReturnCode returnCode = SetFullConnectionInput(inputsIndex, outputsIndex, allTensors);
181 if (returnCode != OH_NN_SUCCESS) {
182 LOGE("[FullConnection] Build failed, SetFullConnectionInput failed.");
183 return returnCode;
184 }
185
186 returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
187 if (returnCode != OH_NN_SUCCESS) {
188 LOGE("[FullConnection] Build failed, passed invalid param index.");
189 return returnCode;
190 }
191
192 for (int i : paramsIndex) {
193 std::shared_ptr<NNTensor> tensor = allTensors[i]; // 参数 tensor
194 if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
195 returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
196 } else {
197 LOGE("[FullConnection] Build failed, param invalid, type=%d", tensor->GetType());
198 return OH_NN_INVALID_PARAMETER;
199 }
200
201 if (returnCode != OH_NN_SUCCESS) {
202 LOGE("[FullConnection] Build failed, passed invalid param.");
203 return returnCode;
204 }
205 }
206
207 // The quantization type of the first output determinies that of the operator.
208 SetQuantType(outputsIndex, allTensors);
209
210 m_isBuild = true;
211 m_name = OP_NAME;
212 return OH_NN_SUCCESS;
213 }
214
GetPrimitive()215 LiteGraphPrimitvePtr FullConnectionBuilder::GetPrimitive()
216 {
217 if (!m_isBuild) {
218 LOGE("[FullConnection] GetPrimitive failed, cannot get primitive before call build.");
219 return {nullptr, DestroyLiteGraphPrimitive};
220 }
221
222 void* primitive = mindspore::lite::MindIR_FullConnection_CreatePrimitive(m_hasBias, m_useAxis,
223 m_axis, m_activationType);
224 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
225 return graphPrimitivePtr;
226 }
227
228 REGISTER_OPS(FullConnectionBuilder, OH_NN_OPS_FULL_CONNECTION);
229 } // namespace Ops
230 } // namespace NeuralNetworkRuntime
231 } // namespace OHOS