1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "crop_builder.h"
17 
18 #include "mindir.h"
19 #include "ops_registry.h"
20 
21 namespace OHOS {
22 namespace NeuralNetworkRuntime {
23 namespace Ops {
24 static const int INPUT_NUM = 2;
25 static const int OUTPUT_NUM = 1;
26 static const int PARAM_MAX_NUM = 2;
27 static const int SCALAR_LENGTH = 1;
28 static const std::string OP_NAME = "Crop";
29 
CropBuilder()30 CropBuilder::CropBuilder() {}
31 
~CropBuilder()32 CropBuilder::~CropBuilder() {}
33 
SetAxis(const std::shared_ptr<NNTensor> & tensor)34 OH_NN_ReturnCode CropBuilder::SetAxis(const std::shared_ptr<NNTensor>& tensor)
35 {
36     if (tensor->GetDataType() != OH_NN_INT64) {
37         LOGE("[Crop] The axis should be type OH_NN_INT64.");
38         return OH_NN_INVALID_PARAMETER;
39     }
40 
41     if (tensor->GetElementCount() != SCALAR_LENGTH) {
42         LOGE("[Crop] The axis should be scalar.");
43         return OH_NN_INVALID_PARAMETER;
44     }
45 
46     void* buffer = tensor->GetBuffer();
47     if (buffer == nullptr) {
48         LOGE("[Crop] Tensor buffer is nullptr.");
49         return OH_NN_INVALID_PARAMETER;
50     }
51     m_axis = *static_cast<const int64_t*>(buffer);
52 
53     return OH_NN_SUCCESS;
54 }
55 
SetOffset(const std::shared_ptr<NNTensor> & tensor)56 OH_NN_ReturnCode CropBuilder::SetOffset(const std::shared_ptr<NNTensor>& tensor)
57 {
58     if (tensor->GetDataType() != OH_NN_INT64) {
59         LOGE("[Crop] The offset should be type OH_NN_INT64.");
60         return OH_NN_INVALID_PARAMETER;
61     }
62 
63     m_offset.clear();
64 
65     void* buffer = tensor->GetBuffer();
66     if (buffer == nullptr) {
67         LOGE("[Crop] Tensor buffer is nullptr.");
68         return OH_NN_INVALID_PARAMETER;
69     }
70 
71     int64_t* pOffset = static_cast<int64_t*>(buffer);
72 
73     uint32_t elementCount = tensor->GetElementCount();
74     for (uint32_t i = 0; i < elementCount; ++i) {
75         m_offset.emplace_back(*pOffset);
76         ++pOffset;
77     }
78     return OH_NN_SUCCESS;
79 }
80 
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)81 OH_NN_ReturnCode CropBuilder::Build(const std::vector<uint32_t>& paramsIndex,
82                                     const std::vector<uint32_t>& inputsIndex,
83                                     const std::vector<uint32_t>& outputsIndex,
84                                     const std::vector<std::shared_ptr<NNTensor>>& allTensors)
85 {
86     if (m_isBuild) {
87         LOGE("[Crop] Build failed, the Crop operation has been build. cannot build again.");
88         return OH_NN_OPERATION_FORBIDDEN;
89     }
90 
91     auto ret = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
92     if (ret != OH_NN_SUCCESS) {
93         LOGE("[Crop] Build failed, passed invalid input or output index.");
94         return ret;
95     }
96 
97     ret = CheckParamIndex(paramsIndex, allTensors, PARAM_MAX_NUM);
98     if (ret != OH_NN_SUCCESS) {
99         LOGE("[Crop] Build failed, passed invalid param index.");
100         return ret;
101     }
102 
103     m_inputsIndex = inputsIndex;
104     m_outputsIndex = outputsIndex;
105 
106     for (int i : paramsIndex) {
107         std::shared_ptr<NNTensor> tensor = allTensors[i];
108         tensor->IdentifyOpParameter();
109         if (m_paramMap.find(tensor->GetType()) != m_paramMap.end()) {
110             ret = (this->*(m_paramMap[tensor->GetType()]))(tensor);
111         } else {
112             LOGE("[Crop] Build failed, param invalid, type=%d", tensor->GetType());
113             return OH_NN_INVALID_PARAMETER;
114         }
115 
116         if (ret != OH_NN_SUCCESS) {
117             LOGE("[Crop] Build failed, passed invalid param.");
118             return ret;
119         }
120     }
121 
122     m_name = OP_NAME;
123     m_isBuild = true;
124     return OH_NN_SUCCESS;
125 }
126 
GetPrimitive()127 LiteGraphPrimitvePtr CropBuilder::GetPrimitive()
128 {
129     if (!m_isBuild) {
130         LOGE("[Crop] GetPrimitive failed, cannot get primitive before call build.");
131         return {nullptr, DestroyLiteGraphPrimitive};
132     }
133 
134     void* primitive = mindspore::lite::MindIR_Crop_CreatePrimitive(m_axis, m_offset);
135     LiteGraphPrimitvePtr graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive) ;
136     return graphPrimitivePtr;
137 }
138 
139 REGISTER_OPS(CropBuilder, OH_NN_OPS_CROP);
140 } // namespace Ops
141 } // namespace NeuralNetworkRuntime
142 } // namespace OHOS
143