1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "CanonicalDevice.h"
18
19 #include <Tracing.h>
20 #include <android-base/logging.h>
21 #include <nnapi/IBuffer.h>
22 #include <nnapi/IDevice.h>
23 #include <nnapi/IPreparedModel.h>
24 #include <nnapi/OperandTypes.h>
25 #include <nnapi/Result.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28
29 #include <algorithm>
30 #include <any>
31 #include <functional>
32 #include <iterator>
33 #include <memory>
34 #include <optional>
35 #include <set>
36 #include <string>
37 #include <utility>
38 #include <vector>
39
40 #include "CanonicalBuffer.h"
41 #include "CanonicalPreparedModel.h"
42
43 namespace android::nn::sample {
44 namespace {
45
makeCapabilities()46 Capabilities makeCapabilities() {
47 constexpr float kPerf = 1.0f;
48 const Capabilities::PerformanceInfo kPerfInfo = {.execTime = kPerf, .powerUsage = kPerf};
49
50 constexpr OperandType kOperandsTypes[] = {
51 OperandType::FLOAT32,
52 OperandType::INT32,
53 OperandType::UINT32,
54 OperandType::TENSOR_FLOAT32,
55 OperandType::TENSOR_INT32,
56 OperandType::TENSOR_QUANT8_ASYMM,
57 OperandType::BOOL,
58 OperandType::TENSOR_QUANT16_SYMM,
59 OperandType::TENSOR_FLOAT16,
60 OperandType::TENSOR_BOOL8,
61 OperandType::FLOAT16,
62 OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
63 OperandType::TENSOR_QUANT16_ASYMM,
64 OperandType::TENSOR_QUANT8_SYMM,
65 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
66 };
67
68 std::vector<Capabilities::OperandPerformance> operandPerformance;
69 operandPerformance.reserve(std::size(kOperandsTypes));
70 std::transform(std::begin(kOperandsTypes), std::end(kOperandsTypes),
71 std::back_inserter(operandPerformance), [kPerfInfo](OperandType op) {
72 return Capabilities::OperandPerformance{.type = op, .info = kPerfInfo};
73 });
74 auto table =
75 Capabilities::OperandPerformanceTable::create(std::move(operandPerformance)).value();
76
77 return {.relaxedFloat32toFloat16PerformanceScalar = kPerfInfo,
78 .relaxedFloat32toFloat16PerformanceTensor = kPerfInfo,
79 .operandPerformance = std::move(table),
80 .ifPerformance = kPerfInfo,
81 .whilePerformance = kPerfInfo};
82 }
83
toString(const Dimensions & dimensions)84 std::string toString(const Dimensions& dimensions) {
85 std::ostringstream oss;
86 oss << "[";
87 for (size_t i = 0; i < dimensions.size(); ++i) {
88 if (i != 0) oss << ", ";
89 oss << dimensions[i];
90 }
91 oss << "]";
92 return oss.str();
93 }
94
95 } // namespace
96
Device(std::string name,const IOperationResolver * operationResolver)97 Device::Device(std::string name, const IOperationResolver* operationResolver)
98 : kName(std::move(name)), kOperationResolver(*operationResolver) {
99 CHECK(operationResolver != nullptr);
100 initVLogMask();
101 }
102
getName() const103 const std::string& Device::getName() const {
104 return kName;
105 }
106
getVersionString() const107 const std::string& Device::getVersionString() const {
108 static const std::string kVersionString = "JUST_AN_EXAMPLE";
109 return kVersionString;
110 }
111
getFeatureLevel() const112 Version Device::getFeatureLevel() const {
113 return Version::ANDROID_S;
114 }
115
getType() const116 DeviceType Device::getType() const {
117 return DeviceType::CPU;
118 }
119
getSupportedExtensions() const120 const std::vector<Extension>& Device::getSupportedExtensions() const {
121 static const std::vector<Extension> kExtensions = {/* No extensions. */};
122 return kExtensions;
123 }
124
getCapabilities() const125 const Capabilities& Device::getCapabilities() const {
126 static const Capabilities kCapabilities = makeCapabilities();
127 return kCapabilities;
128 }
129
getNumberOfCacheFilesNeeded() const130 std::pair<uint32_t, uint32_t> Device::getNumberOfCacheFilesNeeded() const {
131 return std::make_pair(/*numModelCache=*/0, /*numDataCache=*/0);
132 }
133
wait() const134 GeneralResult<void> Device::wait() const {
135 return {};
136 }
137
getSupportedOperations(const Model & model) const138 GeneralResult<std::vector<bool>> Device::getSupportedOperations(const Model& model) const {
139 VLOG(DRIVER) << "sample::Device::getSupportedOperations";
140
141 // Validate arguments.
142 if (const auto result = validate(model); !result.ok()) {
143 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << result.error();
144 }
145
146 // Mark all operations except extension operations as supported.
147 std::vector<bool> supported;
148 supported.reserve(model.main.operations.size());
149 std::transform(model.main.operations.begin(), model.main.operations.end(),
150 std::back_inserter(supported), [](const Operation& operation) {
151 return !isExtensionOperationType(operation.type) &&
152 operation.type != OperationType::OEM_OPERATION;
153 });
154
155 return supported;
156 }
157
prepareModel(const Model & model,ExecutionPreference preference,Priority priority,OptionalTimePoint deadline,const std::vector<SharedHandle> &,const std::vector<SharedHandle> &,const CacheToken &) const158 GeneralResult<SharedPreparedModel> Device::prepareModel(
159 const Model& model, ExecutionPreference preference, Priority priority,
160 OptionalTimePoint deadline, const std::vector<SharedHandle>& /*modelCache*/,
161 const std::vector<SharedHandle>& /*dataCache*/, const CacheToken& /*token*/) const {
162 if (VLOG_IS_ON(DRIVER)) {
163 VLOG(DRIVER) << "sample::Device::prepareModel";
164 logModelToInfo(model);
165 }
166
167 // Validate arguments.
168 if (const auto result = validate(model); !result.ok()) {
169 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
170 }
171 if (const auto result = validate(preference); !result.ok()) {
172 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
173 << "Invalid ExecutionPreference: " << result.error();
174 }
175 if (const auto result = validate(priority); !result.ok()) {
176 return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
177 }
178
179 // Check if deadline has passed.
180 if (hasDeadlinePassed(deadline)) {
181 return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
182 }
183
184 std::vector<RunTimePoolInfo> poolInfos;
185 if (!setRunTimePoolInfosFromCanonicalMemories(&poolInfos, model.pools)) {
186 return NN_ERROR() << "setRunTimePoolInfosFromCanonicalMemories failed";
187 }
188
189 // Create the prepared model.
190 return std::make_shared<const PreparedModel>(model, preference, priority, &kOperationResolver,
191 kBufferTracker, std::move(poolInfos));
192 }
193
// Prepares a model from previously written cache files. This driver does not
// support compilation caching (getNumberOfCacheFilesNeeded returns 0/0), so
// this entry point always fails with GENERAL_FAILURE.
GeneralResult<SharedPreparedModel> Device::prepareModelFromCache(
        OptionalTimePoint /*deadline*/, const std::vector<SharedHandle>& /*modelCache*/,
        const std::vector<SharedHandle>& /*dataCache*/, const CacheToken& /*token*/) const {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "sample::Device::prepareModelFromCache");
    return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
           << "prepareModelFromCache not supported on sample::Device::prepareModelFromCache("
           << kName << ")";
}
203
// Allocates a driver-managed buffer described by `desc`, usable with the
// given prepared models in the listed input/output roles. Fails for extension
// operand types and for operands whose byte size cannot be computed (i.e.
// dynamic shapes, where nonExtensionOperandSizeOfData returns 0).
GeneralResult<SharedBuffer> Device::allocate(const BufferDesc& desc,
                                             const std::vector<SharedPreparedModel>& preparedModels,
                                             const std::vector<BufferRole>& inputRoles,
                                             const std::vector<BufferRole>& outputRoles) const {
    VLOG(DRIVER) << "sample::Device::allocate";
    std::set<PreparedModelRole> roles;
    Operand operand;
    // Recovers the canonical Model behind a prepared model; returns nullptr
    // for prepared models that did not come from this driver.
    auto getModel = [](const SharedPreparedModel& preparedModel) -> const Model* {
        std::any resource = preparedModel->getUnderlyingResource();
        const Model** maybeModel = std::any_cast<const Model*>(&resource);
        if (maybeModel == nullptr) {
            LOG(ERROR) << "sample::Device::allocate -- unknown remote IPreparedModel.";
            return nullptr;
        }
        return *maybeModel;
    };
    // Validation also populates `roles` and the representative `operand` the
    // buffer will hold.
    if (const auto result = validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles,
                                               getModel, &roles, &operand);
        !result.ok()) {
        return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
               << "sample::Device::allocate -- validation failed: " << result.error();
    }

    if (isExtensionOperandType(operand.type)) {
        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
               << "sample::Device::allocate -- does not support extension type.";
    }

    // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
    // A size of 0 indicates the operand's shape is not fully specified.
    uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
    VLOG(DRIVER) << "sample::Device::allocate -- type = " << operand.type
                 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
    if (size == 0) {
        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
               << "sample::Device::allocate -- does not support dynamic output shape.";
    }

    auto bufferWrapper = ManagedBuffer::create(size, std::move(roles), operand);
    if (bufferWrapper == nullptr) {
        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
               << "sample::Device::allocate -- not enough memory.";
    }

    // Register the buffer with the device-wide tracker so executions can
    // reference it by token.
    auto token = kBufferTracker->add(bufferWrapper);
    if (token == nullptr) {
        return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
               << "sample::Device::allocate -- BufferTracker returned invalid token.";
    }

    auto sampleBuffer = std::make_shared<const Buffer>(std::move(bufferWrapper), std::move(token));
    VLOG(DRIVER) << "sample::Device::allocate -- successfully allocates the requested memory";
    return sampleBuffer;
}
257
258 } // namespace android::nn::sample
259