1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <android-base/logging.h>
18 #include <nnapi/OperandTypes.h>
19 #include <nnapi/OperationTypes.h>
20 #include <nnapi/SharedMemory.h>
21 #include <nnapi/Types.h>
22 
23 #include <algorithm>
24 #include <iterator>
25 #include <limits>
26 #include <memory>
27 #include <utility>
28 #include <vector>
29 
30 #include "TestHarness.h"
31 
32 namespace android::nn::test {
33 namespace {
34 
35 using ::test_helper::TestBuffer;
36 using ::test_helper::TestModel;
37 using ::test_helper::TestOperand;
38 using ::test_helper::TestOperandLifeTime;
39 using ::test_helper::TestOperandType;
40 using ::test_helper::TestOperation;
41 using ::test_helper::TestSubgraph;
42 
createOperand(const TestOperand & operand,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)43 Operand createOperand(const TestOperand& operand, Model::OperandValues* operandValues,
44                       ConstantMemoryBuilder* memoryBuilder) {
45     CHECK(operandValues != nullptr);
46     CHECK(memoryBuilder != nullptr);
47 
48     const OperandType type = static_cast<OperandType>(operand.type);
49     const Operand::LifeTime lifetime = static_cast<Operand::LifeTime>(operand.lifetime);
50 
51     DataLocation location;
52     switch (operand.lifetime) {
53         case TestOperandLifeTime::TEMPORARY_VARIABLE:
54         case TestOperandLifeTime::SUBGRAPH_INPUT:
55         case TestOperandLifeTime::SUBGRAPH_OUTPUT:
56         case TestOperandLifeTime::NO_VALUE:
57             break;
58         case TestOperandLifeTime::CONSTANT_COPY:
59             location = operandValues->append(operand.data.get<uint8_t>(), operand.data.size());
60             break;
61         case TestOperandLifeTime::CONSTANT_REFERENCE:
62             location = memoryBuilder->append(operand.data.get<void>(), operand.data.size());
63             break;
64         case TestOperandLifeTime::SUBGRAPH:
65             location = {.offset = *operand.data.get<uint32_t>()};
66             break;
67     }
68 
69     Operand::ExtraParams extraParams;
70     if (operand.type == TestOperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
71         extraParams =
72                 Operand::SymmPerChannelQuantParams{.scales = operand.channelQuant.scales,
73                                                    .channelDim = operand.channelQuant.channelDim};
74     }
75 
76     return {
77             .type = type,
78             .dimensions = operand.dimensions,
79             .scale = operand.scale,
80             .zeroPoint = operand.zeroPoint,
81             .lifetime = lifetime,
82             .location = location,
83             .extraParams = std::move(extraParams),
84     };
85 }
86 
createSubgraph(const TestSubgraph & testSubgraph,Model::OperandValues * operandValues,ConstantMemoryBuilder * memoryBuilder)87 Model::Subgraph createSubgraph(const TestSubgraph& testSubgraph,
88                                Model::OperandValues* operandValues,
89                                ConstantMemoryBuilder* memoryBuilder) {
90     // Operands.
91     std::vector<Operand> operands;
92     operands.reserve(testSubgraph.operands.size());
93     std::transform(testSubgraph.operands.begin(), testSubgraph.operands.end(),
94                    std::back_inserter(operands),
95                    [operandValues, memoryBuilder](const TestOperand& operand) {
96                        return createOperand(operand, operandValues, memoryBuilder);
97                    });
98 
99     // Operations.
100     std::vector<Operation> operations;
101     operations.reserve(testSubgraph.operations.size());
102     std::transform(testSubgraph.operations.begin(), testSubgraph.operations.end(),
103                    std::back_inserter(operations), [](const TestOperation& op) -> Operation {
104                        return {.type = static_cast<OperationType>(op.type),
105                                .inputs = op.inputs,
106                                .outputs = op.outputs};
107                    });
108 
109     return {.operands = std::move(operands),
110             .operations = std::move(operations),
111             .inputIndexes = testSubgraph.inputIndexes,
112             .outputIndexes = testSubgraph.outputIndexes};
113 }
114 
115 }  // namespace
116 
createModel(const TestModel & testModel)117 Model createModel(const TestModel& testModel) {
118     Model::OperandValues operandValues;
119     ConstantMemoryBuilder memoryBuilder(0);
120 
121     Model::Subgraph mainSubgraph = createSubgraph(testModel.main, &operandValues, &memoryBuilder);
122     std::vector<Model::Subgraph> refSubgraphs;
123     refSubgraphs.reserve(testModel.referenced.size());
124     std::transform(testModel.referenced.begin(), testModel.referenced.end(),
125                    std::back_inserter(refSubgraphs),
126                    [&operandValues, &memoryBuilder](const TestSubgraph& testSubgraph) {
127                        return createSubgraph(testSubgraph, &operandValues, &memoryBuilder);
128                    });
129 
130     // Shared memory.
131     std::vector<SharedMemory> pools;
132     if (!memoryBuilder.empty()) {
133         pools.push_back(memoryBuilder.finish().value());
134     }
135 
136     return {.main = std::move(mainSubgraph),
137             .referenced = std::move(refSubgraphs),
138             .operandValues = std::move(operandValues),
139             .pools = std::move(pools),
140             .relaxComputationFloat32toFloat16 = testModel.isRelaxed};
141 }
142 
createRequest(const TestModel & testModel)143 Request createRequest(const TestModel& testModel) {
144     constexpr uint32_t kInputPoolIndex = 0;
145     constexpr uint32_t kOutputPoolIndex = 1;
146 
147     // Model inputs.
148     std::vector<Request::Argument> inputs;
149     inputs.reserve(testModel.main.inputIndexes.size());
150     ConstantMemoryBuilder inputBuilder(kInputPoolIndex);
151     for (uint32_t operandIndex : testModel.main.inputIndexes) {
152         const auto& op = testModel.main.operands[operandIndex];
153         Request::Argument requestArgument;
154         if (op.data.size() == 0) {
155             // Omitted input.
156             requestArgument = {.lifetime = Request::Argument::LifeTime::NO_VALUE};
157         } else {
158             const DataLocation location = inputBuilder.append(op.data.get<void>(), op.data.size());
159             requestArgument = {.lifetime = Request::Argument::LifeTime::POOL,
160                                .location = location,
161                                .dimensions = op.dimensions};
162         }
163         inputs.push_back(std::move(requestArgument));
164     }
165 
166     // Model outputs.
167     std::vector<Request::Argument> outputs;
168     outputs.reserve(testModel.main.outputIndexes.size());
169     MutableMemoryBuilder outputBuilder(kOutputPoolIndex);
170     for (uint32_t operandIndex : testModel.main.outputIndexes) {
171         const auto& op = testModel.main.operands[operandIndex];
172 
173         // In the case of zero-sized output, we should at least provide a one-byte buffer.
174         // This is because zero-sized tensors are only supported internally to the driver, or
175         // reported in output shapes. It is illegal for the client to pre-specify a zero-sized
176         // tensor as model output. Otherwise, we will have two semantic conflicts:
177         // - "Zero dimension" conflicts with "unspecified dimension".
178         // - "Omitted operand buffer" conflicts with "zero-sized operand buffer".
179         size_t bufferSize = std::max<size_t>(op.data.size(), 1);
180 
181         const DataLocation location = outputBuilder.append(bufferSize);
182         outputs.push_back({.lifetime = Request::Argument::LifeTime::POOL,
183                            .location = location,
184                            .dimensions = op.dimensions});
185     }
186 
187     // Model pools.
188     auto inputMemory = inputBuilder.finish().value();
189     auto outputMemory = outputBuilder.finish().value();
190     std::vector<Request::MemoryPool> pools = {std::move(inputMemory), std::move(outputMemory)};
191 
192     return {.inputs = std::move(inputs), .outputs = std::move(outputs), .pools = std::move(pools)};
193 }
194 
195 }  // namespace android::nn::test
196