1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "fullconnection_builder.h"
17 
18 #include "transform.h"
19 #include "validation.h"
20 
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
// Number of output tensors a FullConnection operation must declare.
static constexpr int OUTPUT_NUM {1};
// Upper bound on the number of parameter tensors (hasBias, useAxis, activation, axis).
static constexpr int PARAM_MAX_NUM {4};
// Element count required for every scalar parameter tensor.
static constexpr int SCALAR_LENGTH {1};
// Operator name recorded in m_name once Build() succeeds.
static const std::string OP_NAME {"FullConnection"};
28 
FullConnectionBuilder()29 FullConnectionBuilder::FullConnectionBuilder() {}
30 
~FullConnectionBuilder()31 FullConnectionBuilder::~FullConnectionBuilder() {}
32 
SetFullConnectionInput(const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)33 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionInput(const std::vector<uint32_t>& inputsIndex,
34                                                                const std::vector<uint32_t>& outputsIndex,
35                                                                const std::vector<std::shared_ptr<NNTensor>>& allTensors)
36 {
37     if (outputsIndex.size() != OUTPUT_NUM) {
38         LOGE("[FullConnection] SetFullConnectionInput failed, the index of outputs don't equal to %d.", OUTPUT_NUM);
39         return OH_NN_INVALID_PARAMETER;
40     }
41 
42     size_t allTensorsSize = allTensors.size();
43     bool isOverTensorSize = std::any_of(inputsIndex.begin(), inputsIndex.end(), [allTensorsSize](uint32_t index) {
44         return index >= allTensorsSize;
45     });
46     if (isOverTensorSize) {
47         LOGE("[FullConnection] SetFullConnectionInput failed, the index of inputs is out of range.");
48         return OH_NN_INVALID_PARAMETER;
49     }
50 
51     m_inputsIndex = inputsIndex;
52     m_outputsIndex = outputsIndex;
53 
54     return OH_NN_SUCCESS;
55 }
56 
SetHasBias(const std::shared_ptr<NNTensor> & tensor)57 OH_NN_ReturnCode FullConnectionBuilder::SetHasBias(const std::shared_ptr<NNTensor>& tensor)
58 {
59     if (tensor->GetDataType() != OH_NN_BOOL) {
60         LOGE("[FullConnection] The hasBias should be type OH_NN_BOOL.");
61         return OH_NN_INVALID_PARAMETER;
62     }
63 
64     if (tensor->GetElementCount() != SCALAR_LENGTH) {
65         LOGE("[FullConnection] The hasBias should be scalar.");
66         return OH_NN_INVALID_PARAMETER;
67     }
68 
69     void* buffer = tensor->GetBuffer();
70     if (buffer == nullptr) {
71         LOGE("[FullConnection] Tensor buffer is nullptr.");
72         return OH_NN_INVALID_PARAMETER;
73     }
74     m_hasBias = *(static_cast<bool*>(buffer));
75 
76     return OH_NN_SUCCESS;
77 }
78 
SetUseAxis(const std::shared_ptr<NNTensor> & tensor)79 OH_NN_ReturnCode FullConnectionBuilder::SetUseAxis(const std::shared_ptr<NNTensor>& tensor)
80 {
81     if (tensor->GetDataType() != OH_NN_BOOL) {
82         LOGE("[FullConnection] The useAxis should be type OH_NN_BOOL.");
83         return OH_NN_INVALID_PARAMETER;
84     }
85 
86     if (tensor->GetElementCount() != SCALAR_LENGTH) {
87         LOGE("[FullConnection] The useAxis should be scalar.");
88         return OH_NN_INVALID_PARAMETER;
89     }
90 
91     void* buffer = tensor->GetBuffer();
92     if (buffer == nullptr) {
93         LOGE("[FullConnection] Tensor buffer is nullptr.");
94         return OH_NN_INVALID_PARAMETER;
95     }
96 
97     bool useAxis = *(static_cast<bool*>(buffer));
98     if (m_axisIsSet && !useAxis) {
99         LOGE("[FullConnection] m_useAxis is not allowed to be set to false when m_axis is already set.");
100         return OH_NN_INVALID_PARAMETER;
101     }
102 
103     m_useAxis = useAxis;
104     m_useAxisIsSet = true;
105     return OH_NN_SUCCESS;
106 }
107 
SetFullConnectionActivation(const std::shared_ptr<NNTensor> & tensor)108 OH_NN_ReturnCode FullConnectionBuilder::SetFullConnectionActivation(const std::shared_ptr<NNTensor>& tensor)
109 {
110     tensor->IdentifyOpParameter();
111     // Set Activation
112     if (tensor->GetElementCount() != SCALAR_LENGTH) {
113         LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation shoule be a scalar");
114         return OH_NN_INVALID_PARAMETER;
115     }
116 
117     if (tensor->GetDataType() != OH_NN_INT8) {
118         LOGE("[FullConnection] SetFullConnectionActivation failed, the Activation should have type OH_NN_INT8.");
119         return OH_NN_INVALID_PARAMETER;
120     }
121 
122     void* buffer = tensor->GetBuffer();
123     if (buffer == nullptr) {
124         LOGE("[FullConnection] SetFullConnectionActivation GetBuffer return nullptr");
125         return OH_NN_INVALID_PARAMETER;
126     }
127 
128     int8_t* pFuseData = static_cast<int8_t*>(tensor->GetBuffer());
129     if (!OHOS::NeuralNetworkRuntime::Validation::ValidateFuseType(static_cast<OH_NN_FuseType>(*pFuseData))) {
130         LOGE("[FullConnection] SetFullConnectionActivation failed, activation input is invalid.");
131         return OH_NN_INVALID_PARAMETER;
132     }
133     m_activationType = NNToMS::TransfromFusionType((OH_NN_FuseType)(*pFuseData));
134 
135     return OH_NN_SUCCESS;
136 }
137 
SetAxis(const std::shared_ptr<NNTensor> & tensor)138 OH_NN_ReturnCode FullConnectionBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
139 {
140     tensor->IdentifyOpParameter();
141 
142     if (tensor->GetElementCount() != SCALAR_LENGTH) {
143         LOGE("[FullConnection] SetAxis failed, the axis shoule be a scalar");
144         return OH_NN_INVALID_PARAMETER;
145     }
146 
147     if (tensor->GetDataType() != OH_NN_INT64) {
148         LOGE("[FullConnection] SetAxis failed, the Axis should be type OH_NN_INT64.");
149         return OH_NN_INVALID_PARAMETER;
150     }
151 
152     void* buffer = tensor->GetBuffer();
153     if (buffer == nullptr) {
154         LOGE("[FullConnection] SetAxis GetBuffer return nullptr");
155         return OH_NN_INVALID_PARAMETER;
156     }
157 
158     if (m_useAxisIsSet && !m_useAxis) {
159         LOGE("[FullConnection] m_useAxis has been set to false, axis is not allowed.");
160         return OH_NN_INVALID_PARAMETER;
161     }
162 
163     m_axis = *static_cast<int64_t*>(buffer);
164     m_useAxis = true;
165     m_axisIsSet = true;
166     return OH_NN_SUCCESS;
167 }
168 
169 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)170 OH_NN_ReturnCode FullConnectionBuilder::Build(const std::vector<uint32_t>& paramsIndex,
171                                               const std::vector<uint32_t>& inputsIndex,
172                                               const std::vector<uint32_t>& outputsIndex,
173                                               const std::vector<std::shared_ptr<NNTensor>>& allTensors)
174 {
175     if (m_isBuild) {
176         LOGE("[FullConnection] Build failed, operation has been build, cannot build again.");
177         return OH_NN_OPERATION_FORBIDDEN;
178     }
179 
180     OH_NN_ReturnCode returnCode = SetFullConnectionInput(inputsIndex, outputsIndex, allTensors);
181     if (returnCode != OH_NN_SUCCESS) {
182         LOGE("[FullConnection] Build failed, SetFullConnectionInput failed.");
183         return returnCode;
184     }
185 
186     returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
187     if (returnCode != OH_NN_SUCCESS) {
188         LOGE("[FullConnection] Build failed, passed invalid param index.");
189         return returnCode;
190     }
191 
192     for (int i : paramsIndex) {
193         std::shared_ptr<NNTensor> tensor = allTensors[i]; // 参数 tensor
194         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
195             returnCode = (this->*(m_paramMap[tensor->GetType()]))(tensor);
196         } else {
197             LOGE("[FullConnection] Build failed, param invalid, type=%d", tensor->GetType());
198             return OH_NN_INVALID_PARAMETER;
199         }
200 
201         if (returnCode != OH_NN_SUCCESS) {
202             LOGE("[FullConnection] Build failed, passed invalid param.");
203             return returnCode;
204         }
205     }
206 
207     // The quantization type of the first output determinies that of the operator.
208     SetQuantType(outputsIndex, allTensors);
209 
210     m_isBuild = true;
211     m_name = OP_NAME;
212     return OH_NN_SUCCESS;
213 }
214 
GetPrimitive()215 LiteGraphPrimitvePtr FullConnectionBuilder::GetPrimitive()
216 {
217     if (!m_isBuild) {
218         LOGE("[FullConnection] GetPrimitive failed, cannot get primitive before call build.");
219         return {nullptr, DestroyLiteGraphPrimitive};
220     }
221 
222     void* primitive = mindspore::lite::MindIR_FullConnection_CreatePrimitive(m_hasBias, m_useAxis,
223         m_axis, m_activationType);
224     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
225     return graphPrimitivePtr;
226 }
227 
// Associates FullConnectionBuilder with the OH_NN_OPS_FULL_CONNECTION operation type.
// NOTE(review): REGISTER_OPS is defined elsewhere; presumably it registers a factory with the ops registry.
REGISTER_OPS(FullConnectionBuilder, OH_NN_OPS_FULL_CONNECTION);
229 } // namespace Ops
230 } // namespace NeuralNetworkRuntime
231 } // namespace OHOS