1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "transpose_builder.h"
17 
18 #include "mindir.h"
19 
20 namespace OHOS {
21 namespace NeuralNetworkRuntime {
22 namespace Ops {
23 static const int INPUT_NUM = 2;
24 static const int OUTPUT_NUM = 1;
25 static const int PARAM_NUM = 0;
26 static const std::string OP_NAME = "Transpose";
27 
TransposeBuilder()28 TransposeBuilder::TransposeBuilder() {}
29 
~TransposeBuilder()30 TransposeBuilder::~TransposeBuilder() {}
31 
32 /**
33  * Build method.
34  * 1.set attr of ops.
35  * 2.set inputIndex of ops.
36  * 3.set outputIndex of ops.
37  */
Build(const std::vector<uint32_t> & paramsIndex,const std::vector<uint32_t> & inputsIndex,const std::vector<uint32_t> & outputsIndex,const std::vector<std::shared_ptr<NNTensor>> & allTensors)38 OH_NN_ReturnCode TransposeBuilder::Build(const std::vector<uint32_t>& paramsIndex,
39                                          const std::vector<uint32_t>& inputsIndex,
40                                          const std::vector<uint32_t>& outputsIndex,
41                                          const std::vector<std::shared_ptr<NNTensor>>& allTensors)
42 {
43     if (m_isBuild) {
44         LOGE("[TransposeBuilder] Transpose operation has been build, cannot build again.");
45         return OH_NN_OPERATION_FORBIDDEN;
46     }
47 
48     OH_NN_ReturnCode returnCode = CheckIOIndex(inputsIndex, outputsIndex, allTensors, INPUT_NUM, OUTPUT_NUM);
49     if (returnCode != OH_NN_SUCCESS) {
50         LOGE("[TransposeBuilder] Passed invalid input or output index.");
51         return returnCode;
52     }
53 
54     m_inputsIndex = inputsIndex;
55     m_outputsIndex = outputsIndex;
56 
57     returnCode = CheckParamIndex(paramsIndex, allTensors, PARAM_NUM);
58     if (returnCode != OH_NN_SUCCESS) {
59         LOGE("[TransposeBuilder] Passed invalid param index.");
60         return returnCode;
61     }
62 
63     m_isBuild = true;
64     m_name = OP_NAME;
65     return OH_NN_SUCCESS;
66 }
67 
GetPrimitive()68 LiteGraphPrimitvePtr TransposeBuilder::GetPrimitive()
69 {
70     if (!m_isBuild) {
71         LOGE("[TransposeBuilder] Cannot get primitive before call build.");
72         return {nullptr, DestroyLiteGraphPrimitive};
73     }
74 
75     auto primitive = mindspore::lite::MindIR_Transpose_CreatePrimitive();
76     if (primitive == nullptr) {
77         LOGE("[TransposeBuilder] MindIR_Transpose_CreatePrimitive failed.");
78         return {nullptr, DestroyLiteGraphPrimitive};
79     }
80 
81     LiteGraphPrimitvePtr  graphPrimitivePtr(primitive, DestroyLiteGraphPrimitive);
82     return graphPrimitivePtr;
83 }
84 
85 REGISTER_OPS(TransposeBuilder, OH_NN_OPS_TRANSPOSE);
86 } // namespace Ops
87 } // namespace NeuralNetworkRuntime
88 } // namespace OHOS
89