1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "clip_builder.h"
17
18 namespace OHOS {
19 namespace NeuralNetworkRuntime {
20 namespace Ops {
// Expected operand counts handed to CheckIOIndex: a Clip node has one input and one output.
static constexpr int INPUT_NUM = 1;
static constexpr int OUTPUT_NUM = 1;
// Upper bound on parameter tensors handed to CheckParamIndex (the max/min bounds set via SetMax / SetMin).
static constexpr int PARAM_MAX_NUM = 2;
// Element count a max/min parameter tensor must have, i.e. it must be a scalar.
static constexpr int SCALAR_LENGTH = 1;
// Name recorded in m_name once Build succeeds.
static const std::string OP_NAME = "Clip";
26
ClipBuilder()27 ClipBuilder::ClipBuilder() {}
28
~ClipBuilder()29 ClipBuilder::~ClipBuilder() {}
30
SetMax(const std::shared_ptr<NNTensor> & tensor)31 OH_NN_ReturnCode ClipBuilder::SetMax(const std::shared_ptr<NNTensor>& tensor)
32 {
33 if (tensor->GetDataType() != OH_NN_FLOAT32) {
34 LOGE("[Clip] The max should be type OH_NN_FLOAT32.");
35 return OH_NN_INVALID_PARAMETER;
36 }
37
38 if (tensor->GetElementCount() != SCALAR_LENGTH) {
39 LOGE("[Clip] The max should be scalar.");
40 return OH_NN_INVALID_PARAMETER;
41 }
42
43 void* buffer = tensor->GetBuffer();
44 if (buffer == nullptr) {
45 LOGE("[Clip] Tensor buffer is nullptr.");
46 return OH_NN_INVALID_PARAMETER;
47 }
48 m_max = *(static_cast<const float*>(buffer));
49
50 return OH_NN_SUCCESS;
51 }
52
SetMin(const std::shared_ptr<NNTensor> & tensor)53 OH_NN_ReturnCode ClipBuilder::SetMin(const std::shared_ptr<NNTensor>& tensor)
54 {
55 if (tensor->GetDataType() != OH_NN_FLOAT32) {
56 LOGE("[Clip] The min should be type OH_NN_FLOAT32.");
57 return OH_NN_INVALID_PARAMETER;
58 }
59
60 if (tensor->GetElementCount() != SCALAR_LENGTH) {
61 LOGE("[Clip] The min should be scalar.");
62 return OH_NN_INVALID_PARAMETER;
63 }
64
65 void* buffer = tensor->GetBuffer();
66 if (buffer == nullptr) {
67 LOGE("[Clip] Tensor buffer is nullptr.");
68 return OH_NN_INVALID_PARAMETER;
69 }
70 m_min = *(static_cast<const float*>(buffer));
71
72 return OH_NN_SUCCESS;
73 }
74
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)75 OH_NN_ReturnCode ClipBuilder::Build(const std::vector<uint32_t>& paramsIndex,
76 const std::vector<uint32_t>& inputsIndex,
77 const std::vector<uint32_t>& outputsIndex,
78 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
79 {
80 if (m_isBuild) {
81 LOGE("[Clip] Build failed, the clip operation has been build. cannot build again.");
82 return OH_NN_OPERATION_FORBIDDEN;
83 }
84
85 auto ret = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
86 if (ret != OH_NN_SUCCESS) {
87 LOGE("[Clip] Build failed, passed invalid input or output index.");
88 return ret;
89 }
90
91 m_inputsIndex = inputsIndex;
92 m_outputsIndex = outputsIndex;
93
94 ret = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
95 if (ret != OH_NN_SUCCESS) {
96 LOGE("[Clip] Build failed, passed invalid param index.");
97 return ret;
98 }
99
100 for (int i : paramsIndex) {
101 std::shared_ptr<NNTensor> tensor = allTensors[i];
102 tensor->IdentifyOpParameter();
103 if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
104 ret = (this->*(m_paramMap[tensor->GetType()]))(tensor);
105 } else {
106 LOGE("[Clip] Build failed, param invalid, type=%d", tensor->GetType());
107 return OH_NN_INVALID_PARAMETER;
108 }
109
110 if (ret != OH_NN_SUCCESS) {
111 LOGE("[Clip] Build failed, passed invalid param.");
112 return ret;
113 }
114 }
115
116 m_name = OP_NAME;
117 m_isBuild = true;
118 return OH_NN_SUCCESS;
119 }
120
GetPrimitive()121 LiteGraphPrimitvePtr ClipBuilder::GetPrimitive()
122 {
123 if (!m_isBuild) {
124 LOGE("[Clip] GetPrimitive failed, cannot get primitive before call build.");
125 return {nullptr, DestroyLiteGraphPrimitive};
126 }
127
128 void* primitive = mindspore::lite::MindIR_Clip_CreatePrimitive(m_max, m_min);
129 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
130 return graphPrimitivePtr;
131 }
132
// Associates ClipBuilder with the OH_NN_OPS_CLIP operation type.
// NOTE(review): presumably this registers a factory in the ops registry so Clip nodes
// are built by this class; confirm against the REGISTER_OPS macro definition.
REGISTER_OPS(ClipBuilder, OH_NN_OPS_CLIP);
134 } // namespace Ops
135 } // namespace NeuralNetworkRuntime
136 } // namespace OHOS