1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "crop_builder.h"
17
18 #include "mindir.h"
19 #include "ops_registry.h"
20
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
// Number of input tensors the Crop operation expects; checked by CheckIOIndex() in Build().
static constexpr int INPUT_NUM = 2;
// Number of output tensors the Crop operation expects; checked by CheckIOIndex() in Build().
static constexpr int OUTPUT_NUM = 1;
// Maximum number of parameter tensors accepted; checked by CheckParamIndex() in Build().
// Matches the two parameter setters in this file (SetAxis / SetOffset).
static constexpr int PARAM_MAX_NUM = 2;
// Element count a parameter tensor must have to be treated as a scalar.
static constexpr int SCALAR_LENGTH = 1;
// Name assigned to m_name once Build() succeeds.
static const std::string OP_NAME = "Crop";
29
CropBuilder()30 CropBuilder::CropBuilder() {}
31
~CropBuilder()32 CropBuilder::~CropBuilder() {}
33
SetAxis(const std::shared_ptr<NNTensor> & tensor)34 OH_NN_ReturnCode CropBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
35 {
36 if (tensor->GetDataType() != OH_NN_INT64) {
37 LOGE("[Crop] The axis should be type OH_NN_INT64.");
38 return OH_NN_INVALID_PARAMETER;
39 }
40
41 if (tensor->GetElementCount() != SCALAR_LENGTH) {
42 LOGE("[Crop] The axis should be scalar.");
43 return OH_NN_INVALID_PARAMETER;
44 }
45
46 void* buffer = tensor->GetBuffer();
47 if (buffer == nullptr) {
48 LOGE("[Crop] Tensor buffer is nullptr.");
49 return OH_NN_INVALID_PARAMETER;
50 }
51 m_axis = *static_cast<const int64_t*>(buffer);
52
53 return OH_NN_SUCCESS;
54 }
55
SetOffset(const std::shared_ptr<NNTensor> & tensor)56 OH_NN_ReturnCode CropBuilder::SetOffset(const std::shared_ptr<NNTensor>& tensor)
57 {
58 if (tensor->GetDataType() != OH_NN_INT64) {
59 LOGE("[Crop] The offset should be type OH_NN_INT64.");
60 return OH_NN_INVALID_PARAMETER;
61 }
62
63 m_offset.clear();
64
65 void* buffer = tensor->GetBuffer();
66 if (buffer == nullptr) {
67 LOGE("[Crop] Tensor buffer is nullptr.");
68 return OH_NN_INVALID_PARAMETER;
69 }
70
71 int64_t* pOffset = static_cast<int64_t*>(buffer);
72
73 uint32_t elementCount = tensor->GetElementCount();
74 for (uint32_t i = 0; i < elementCount; ++i) {
75 m_offset.emplace_back(*pOffset);
76 ++pOffset;
77 }
78 return OH_NN_SUCCESS;
79 }
80
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)81 OH_NN_ReturnCode CropBuilder::Build(const std::vector<uint32_t>& paramsIndex,
82 const std::vector<uint32_t>& inputsIndex,
83 const std::vector<uint32_t>& outputsIndex,
84 const std::vector<std::shared_ptr<NNTensor>>& allTensors)
85 {
86 if (m_isBuild) {
87 LOGE("[Crop] Build failed, the Crop operation has been build. cannot build again.");
88 return OH_NN_OPERATION_FORBIDDEN;
89 }
90
91 auto ret = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
92 if (ret != OH_NN_SUCCESS) {
93 LOGE("[Crop] Build failed, passed invalid input or output index.");
94 return ret;
95 }
96
97 ret = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
98 if (ret != OH_NN_SUCCESS) {
99 LOGE("[Crop] Build failed, passed invalid param index.");
100 return ret;
101 }
102
103 m_inputsIndex = inputsIndex;
104 m_outputsIndex = outputsIndex;
105
106 for (int i : paramsIndex) {
107 std::shared_ptr<NNTensor> tensor = allTensors[i];
108 tensor->IdentifyOpParameter();
109 if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
110 ret = (this->*(m_paramMap[tensor->GetType()]))(tensor);
111 } else {
112 LOGE("[Crop] Build failed, param invalid, type=%d", tensor->GetType());
113 return OH_NN_INVALID_PARAMETER;
114 }
115
116 if (ret != OH_NN_SUCCESS) {
117 LOGE("[Crop] Build failed, passed invalid param.");
118 return ret;
119 }
120 }
121
122 m_name = OP_NAME;
123 m_isBuild = true;
124 return OH_NN_SUCCESS;
125 }
126
GetPrimitive()127 LiteGraphPrimitvePtr CropBuilder::GetPrimitive()
128 {
129 if (!m_isBuild) {
130 LOGE("[Crop] GetPrimitive failed, cannot get primitive before call build.");
131 return {nullptr, DestroyLiteGraphPrimitive};
132 }
133
134 void* primitive = mindspore::lite::MindIR_Crop_CreatePrimitive(m_axis, m_offset);
135 LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
136 return graphPrimitivePtr;
137 }
138
// Associates CropBuilder with the OH_NN_OPS_CROP operation type via the REGISTER_OPS macro
// from ops_registry.h (presumably a static registration into the ops registry — confirm
// against the macro definition).
REGISTER_OPS(CropBuilder, OH_NN_OPS_CROP);
140 } // namespace Ops
141 } // namespace NeuralNetworkRuntime
142 } // namespace OHOS
143