1 /*
2 * Copyright (C) 2020 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <android-base/logging.h>
18 #include <nnapi/OperandTypes.h>
19 #include <nnapi/OperationTypes.h>
20 #include <nnapi/SharedMemory.h>
21 #include <nnapi/Types.h>
22
23 #include <algorithm>
24 #include <iterator>
25 #include <limits>
26 #include <memory>
27 #include <utility>
28 #include <vector>
29
30 #include "TestHarness.h"
31
32 namespace android::nn::test {
33 namespace {
34
35 using ::test_helper::TestBuffer;
36 using ::test_helper::TestModel;
37 using ::test_helper::TestOperand;
38 using ::test_helper::TestOperandLifeTime;
39 using ::test_helper::TestOperandType;
40 using ::test_helper::TestOperation;
41 using ::test_helper::TestSubgraph;
42
createOperand(const TestOperand & operand,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)43 Operand createOperand(const TestOperand& operand, Model::OperandValues* operandValues,
44 ConstantMemoryBuilder* memoryBuilder) {
45 CHECK(operandValues != nullptr);
46 CHECK(memoryBuilder != nullptr);
47
48 const OperandType type = static_cast<OperandType>(operand.type);
49 const Operand::LifeTime lifetime = static_cast<Operand::LifeTime>(operand.lifetime);
50
51 DataLocation location;
52 switch (operand.lifetime) {
53 case TestOperandLifeTime::TEMPORARY_VARIABLE:
54 case TestOperandLifeTime::SUBGRAPH_INPUT:
55 case TestOperandLifeTime::SUBGRAPH_OUTPUT:
56 case TestOperandLifeTime::NO_VALUE:
57 break;
58 case TestOperandLifeTime::CONSTANT_COPY:
59 location = operandValues->append(operand.data.get<uint8_t>(), operand.data.size());
60 break;
61 case TestOperandLifeTime::CONSTANT_REFERENCE:
62 location = memoryBuilder->append(operand.data.get<void>(), operand.data.size());
63 break;
64 case TestOperandLifeTime::SUBGRAPH:
65 location = {.offset = *operand.data.get<uint32_t>()};
66 break;
67 }
68
69 Operand::ExtraParams extraParams;
70 if (operand.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
71 extraParams =
72 Operand::SymmPerChannelQuantParams{.scales = operand.channelQuant.scales,
73 .channelDim = operand.channelQuant.channelDim};
74 }
75
76 return {
77 .type = type,
78 .dimensions = operand.dimensions,
79 .scale = operand.scale,
80 .zeroPoint = operand.zeroPoint,
81 .lifetime = lifetime,
82 .location = location,
83 .extraParams = std::move(extraParams),
84 };
85 }
86
createSubgraph(const TestSubgraph & testSubgraph,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)87 Model::Subgraph createSubgraph(const TestSubgraph& testSubgraph,
88 Model::OperandValues* operandValues,
89 ConstantMemoryBuilder* memoryBuilder) {
90 // Operands.
91 std::vector<Operand> operands;
92 operands.reserve(testSubgraph.operands.size());
93 std::transform(testSubgraph.operands.begin(), testSubgraph.operands.end(),
94 std::back_inserter(operands),
95 [operandValues, memoryBuilder](const TestOperand& operand) {
96 return createOperand(operand, operandValues, memoryBuilder);
97 });
98
99 // Operations.
100 std::vector<Operation> operations;
101 operations.reserve(testSubgraph.operations.size());
102 std::transform(testSubgraph.operations.begin(), testSubgraph.operations.end(),
103 std::back_inserter(operations), [](const TestOperation& op) -> Operation {
104 return {.type = static_cast<OperationType>(op.type),
105 .inputs = op.inputs,
106 .outputs = op.outputs};
107 });
108
109 return {.operands = std::move(operands),
110 .operations = std::move(operations),
111 .inputIndexes = testSubgraph.inputIndexes,
112 .outputIndexes = testSubgraph.outputIndexes};
113 }
114
115 } // namespace
116
createModel(const TestModel & testModel)117 Model createModel(const TestModel& testModel) {
118 Model::OperandValues operandValues;
119 ConstantMemoryBuilder memoryBuilder(0);
120
121 Model::Subgraph mainSubgraph = createSubgraph(testModel.main, &operandValues, &memoryBuilder);
122 std::vector<Model::Subgraph> refSubgraphs;
123 refSubgraphs.reserve(testModel.referenced.size());
124 std::transform(testModel.referenced.begin(), testModel.referenced.end(),
125 std::back_inserter(refSubgraphs),
126 [&operandValues, &memoryBuilder](const TestSubgraph& testSubgraph) {
127 return createSubgraph(testSubgraph, &operandValues, &memoryBuilder);
128 });
129
130 // Shared memory.
131 std::vector<SharedMemory> pools;
132 if (!memoryBuilder.empty()) {
133 pools.push_back(memoryBuilder.finish().value());
134 }
135
136 return {.main = std::move(mainSubgraph),
137 .referenced = std::move(refSubgraphs),
138 .operandValues = std::move(operandValues),
139 .pools = std::move(pools),
140 .relaxComputationFloat32toFloat16 = testModel.isRelaxed};
141 }
142
createRequest(const TestModel & testModel)143 Request createRequest(const TestModel& testModel) {
144 constexpr uint32_t kInputPoolIndex = 0;
145 constexpr uint32_t kOutputPoolIndex = 1;
146
147 // Model inputs.
148 std::vector<Request::Argument> inputs;
149 inputs.reserve(testModel.main.inputIndexes.size());
150 ConstantMemoryBuilder inputBuilder(kInputPoolIndex);
151 for (uint32_t operandIndex : testModel.main.inputIndexes) {
152 const auto& op = testModel.main.operands[operandIndex];
153 Request::Argument requestArgument;
154 if (op.data.size() == 0) {
155 // Omitted input.
156 requestArgument = {.lifetime = Request::Argument::LifeTime::NO_VALUE};
157 } else {
158 const DataLocation location = inputBuilder.append(op.data.get<void>(), op.data.size());
159 requestArgument = {.lifetime = Request::Argument::LifeTime::POOL,
160 .location = location,
161 .dimensions = op.dimensions};
162 }
163 inputs.push_back(std::move(requestArgument));
164 }
165
166 // Model outputs.
167 std::vector<Request::Argument> outputs;
168 outputs.reserve(testModel.main.outputIndexes.size());
169 MutableMemoryBuilder outputBuilder(kOutputPoolIndex);
170 for (uint32_t operandIndex : testModel.main.outputIndexes) {
171 const auto& op = testModel.main.operands[operandIndex];
172
173 // In the case of zero-sized output, we should at least provide a one-byte buffer.
174 // This is because zero-sized tensors are only supported internally to the driver, or
175 // reported in output shapes. It is illegal for the client to pre-specify a zero-sized
176 // tensor as model output. Otherwise, we will have two semantic conflicts:
177 // - "Zero dimension" conflicts with "unspecified dimension".
178 // - "Omitted operand buffer" conflicts with "zero-sized operand buffer".
179 size_t bufferSize = std::max<size_t>(op.data.size(), 1);
180
181 const DataLocation location = outputBuilder.append(bufferSize);
182 outputs.push_back({.lifetime = Request::Argument::LifeTime::POOL,
183 .location = location,
184 .dimensions = op.dimensions});
185 }
186
187 // Model pools.
188 auto inputMemory = inputBuilder.finish().value();
189 auto outputMemory = outputBuilder.finish().value();
190 std::vector<Request::MemoryPool> pools = {std::move(inputMemory), std::move(outputMemory)};
191
192 return {.inputs = std::move(inputs), .outputs = std::move(outputs), .pools = std::move(pools)};
193 }
194
195 } // namespace android::nn::test
196