1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "SampleDriver"
18
19 #include "SampleDriver.h"
20
21 #include <CpuExecutor.h>
22 #include <ExecutionBurstServer.h>
23 #include <HalBufferTracker.h>
24 #include <HalInterfaces.h>
25 #include <Tracing.h>
26 #include <ValidateHal.h>
27 #include <android-base/logging.h>
28 #include <android-base/properties.h>
29 #include <hidl/LegacySupport.h>
30 #include <nnapi/Types.h>
31 #include <nnapi/hal/1.3/Conversions.h>
32
33 #include <algorithm>
34 #include <chrono>
35 #include <map>
36 #include <memory>
37 #include <optional>
38 #include <set>
39 #include <thread>
40 #include <tuple>
41 #include <utility>
42 #include <vector>
43
44 #include "SampleDriverUtils.h"
45
46 namespace android {
47 namespace nn {
48 namespace sample_driver {
49
50 namespace {
51
microsecondsDuration(TimePoint end,TimePoint start)52 uint64_t microsecondsDuration(TimePoint end, TimePoint start) {
53 using Microseconds = std::chrono::duration<uint64_t, std::micro>;
54 return std::chrono::duration_cast<Microseconds>(end - start).count();
55 };
56
57 } // namespace
58
// Sentinel timing value meaning "no timing information available":
// UINT64_MAX in both fields, reported when timing was not measured.
static const V1_2::Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
60
getCapabilities(getCapabilities_cb cb)61 hardware::Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
62 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
63 "SampleDriver::getCapabilities");
64 return getCapabilities_1_3(
65 [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
66 // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
67 cb(convertToV1_0(error), convertToV1_0(capabilities));
68 });
69 }
70
getCapabilities_1_1(getCapabilities_1_1_cb cb)71 hardware::Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
72 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
73 "SampleDriver::getCapabilities_1_1");
74 return getCapabilities_1_3(
75 [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
76 // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
77 cb(convertToV1_0(error), convertToV1_1(capabilities));
78 });
79 }
80
getCapabilities_1_2(getCapabilities_1_2_cb cb)81 hardware::Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
82 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
83 "SampleDriver::getCapabilities_1_2");
84 return getCapabilities_1_3(
85 [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
86 // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
87 cb(convertToV1_0(error), convertToV1_2(capabilities));
88 });
89 }
90
getVersionString(getVersionString_cb cb)91 hardware::Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
92 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
93 "SampleDriver::getVersionString");
94 cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
95 return hardware::Void();
96 }
97
getType(getType_cb cb)98 hardware::Return<void> SampleDriver::getType(getType_cb cb) {
99 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
100 cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
101 return hardware::Void();
102 }
103
getSupportedExtensions(getSupportedExtensions_cb cb)104 hardware::Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
105 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
106 "SampleDriver::getSupportedExtensions");
107 cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
108 return hardware::Void();
109 }
110
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)111 hardware::Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
112 getSupportedOperations_cb cb) {
113 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
114 "SampleDriver::getSupportedOperations");
115 if (!validateModel(model)) {
116 VLOG(DRIVER) << "getSupportedOperations";
117 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
118 return hardware::Void();
119 }
120 return getSupportedOperations_1_3(
121 convertToV1_3(model),
122 [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
123 cb(convertToV1_0(status), supported);
124 });
125 }
126
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)127 hardware::Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
128 getSupportedOperations_1_1_cb cb) {
129 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
130 "SampleDriver::getSupportedOperations_1_1");
131 if (!validateModel(model)) {
132 VLOG(DRIVER) << "getSupportedOperations_1_1";
133 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
134 return hardware::Void();
135 }
136 return getSupportedOperations_1_3(
137 convertToV1_3(model),
138 [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
139 cb(convertToV1_0(status), supported);
140 });
141 }
142
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)143 hardware::Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
144 getSupportedOperations_1_2_cb cb) {
145 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
146 "SampleDriver::getSupportedOperations_1_2");
147 if (!validateModel(model)) {
148 VLOG(DRIVER) << "getSupportedOperations_1_2";
149 cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
150 return hardware::Void();
151 }
152 return getSupportedOperations_1_3(
153 convertToV1_3(model),
154 [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
155 cb(convertToV1_0(status), supported);
156 });
157 }
158
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)159 hardware::Return<void> SampleDriver::getNumberOfCacheFilesNeeded(
160 getNumberOfCacheFilesNeeded_cb cb) {
161 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
162 "SampleDriver::getNumberOfCacheFilesNeeded");
163 // Set both numbers to be 0 for cache not supported.
164 cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
165 return hardware::Void();
166 }
167
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)168 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
169 const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
170 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
171 const V1_3::ErrorStatus status =
172 prepareModelBase(model, this, V1_1::ExecutionPreference::FAST_SINGLE_ANSWER,
173 kDefaultPriority13, {}, callback);
174 return convertToV1_0(status);
175 }
176
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)177 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
178 const V1_1::Model& model, V1_1::ExecutionPreference preference,
179 const sp<V1_0::IPreparedModelCallback>& callback) {
180 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
181 const V1_3::ErrorStatus status =
182 prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
183 return convertToV1_0(status);
184 }
185
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)186 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
187 const V1_2::Model& model, V1_1::ExecutionPreference preference,
188 const hardware::hidl_vec<hardware::hidl_handle>&,
189 const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
190 const sp<V1_2::IPreparedModelCallback>& callback) {
191 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
192 const V1_3::ErrorStatus status =
193 prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
194 return convertToV1_0(status);
195 }
196
prepareModel_1_3(const V1_3::Model & model,V1_1::ExecutionPreference preference,V1_3::Priority priority,const V1_3::OptionalTimePoint & deadline,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)197 hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
198 const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
199 const V1_3::OptionalTimePoint& deadline, const hardware::hidl_vec<hardware::hidl_handle>&,
200 const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
201 const sp<V1_3::IPreparedModelCallback>& callback) {
202 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
203 return prepareModelBase(model, this, preference, priority, deadline, callback);
204 }
205
// Compilation caching is unsupported, so preparing from cache always fails:
// the callback is notified with GENERAL_FAILURE and the same status returned.
hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
        const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_2::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache");
    constexpr auto kStatus = V1_3::ErrorStatus::GENERAL_FAILURE;
    notify(callback, kStatus, nullptr);
    return convertToV1_0(kStatus);
}
215
prepareModelFromCache_1_3(const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)216 hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
217 const V1_3::OptionalTimePoint& /*deadline*/,
218 const hardware::hidl_vec<hardware::hidl_handle>&,
219 const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
220 const sp<V1_3::IPreparedModelCallback>& callback) {
221 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
222 "SampleDriver::prepareModelFromCache_1_3");
223 notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
224 return V1_3::ErrorStatus::GENERAL_FAILURE;
225 }
226
getStatus()227 hardware::Return<V1_0::DeviceStatus> SampleDriver::getStatus() {
228 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
229 VLOG(DRIVER) << "getStatus()";
230 return V1_0::DeviceStatus::AVAILABLE;
231 }
232
233 // Safely downcast an IPreparedModel object to SamplePreparedModel.
234 // This function will return nullptr if the IPreparedModel object is not originated from the sample
235 // driver process.
castToSamplePreparedModel(const sp<V1_3::IPreparedModel> & preparedModel)236 static const SamplePreparedModel* castToSamplePreparedModel(
237 const sp<V1_3::IPreparedModel>& preparedModel) {
238 if (preparedModel->isRemote()) {
239 return nullptr;
240 } else {
241 // This static_cast is safe because SamplePreparedModel is the only class that implements
242 // the IPreparedModel interface in the sample driver process.
243 return static_cast<const SamplePreparedModel*>(preparedModel.get());
244 }
245 }
246
// Allocates a driver-managed buffer (IBuffer) described by `desc` for use with
// the given prepared models in the given input/output roles.
// On success, invokes cb with NONE, the new SampleBuffer, and a non-zero token
// identifying the buffer; on failure, invokes cb with an error status, nullptr,
// and the invalid token 0.
hardware::Return<void> SampleDriver::allocate(
        const V1_3::BufferDesc& desc,
        const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
        const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
        const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
    constexpr uint32_t kInvalidBufferToken = 0;

    VLOG(DRIVER) << "SampleDriver::allocate";
    std::set<HalPreparedModelRole> roles;
    V1_3::Operand operand;
    // Resolve each IPreparedModel back to its local model. Remote prepared
    // models are rejected because their model data is not accessible here.
    auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
        const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
        if (samplePreparedModel == nullptr) {
            LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
            return nullptr;
        }
        return samplePreparedModel->getModel();
    };
    if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
                            &operand)) {
        LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    if (isExtensionOperandType(operand.type)) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
    // A size of 0 indicates a dynamic shape, which is unsupported here.
    uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
    VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
                 << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
    if (size == 0) {
        LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    auto bufferWrapper =
            HalManagedBuffer::create(size, std::move(roles), uncheckedConvert(operand));
    if (bufferWrapper == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // Register the buffer with the tracker, which issues the token that
    // clients later use to reference this buffer in Request memory pools.
    auto token = mHalBufferTracker->add(bufferWrapper);
    if (token == nullptr) {
        LOG(ERROR) << "SampleDriver::allocate -- HalBufferTracker returned invalid token.";
        cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
        return hardware::Void();
    }

    // Read the token value before `token` is moved into the SampleBuffer.
    const uint32_t tokenValue = token->get();
    sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
    VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
    cb(V1_3::ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
    return hardware::Void();
}
309
run()310 int SampleDriver::run() {
311 android::hardware::configureRpcThreadpool(4, true);
312 if (registerAsService(mName) != android::OK) {
313 LOG(ERROR) << "Could not register service";
314 return 1;
315 }
316 android::hardware::joinRpcThreadpool();
317 LOG(ERROR) << "Service exited!";
318 return 1;
319 }
320
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)321 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
322 CHECK(srcPool.getBuffer() != nullptr);
323 CHECK(dstPool.getBuffer() != nullptr);
324 CHECK(srcPool.getSize() == dstPool.getSize());
325 std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
326 dstPool.flush();
327 }
328
// Copies the contents of this driver-managed buffer into client-supplied
// memory `dst`. Fails if the destination cannot be mapped or fails the
// size validation.
hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyTo(const hardware::hidl_memory& dst) {
    const auto dstPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(dst));
    if (!dstPool.has_value()) {
        LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
        return V1_3::ErrorStatus::GENERAL_FAILURE;
    }
    if (const auto status = convertToV1_3(kBuffer->validateCopyTo(dstPool->getSize()));
        status != V1_3::ErrorStatus::NONE) {
        return status;
    }
    copyRunTimePoolInfos(kBuffer->createRunTimePoolInfo(), dstPool.value());
    return V1_3::ErrorStatus::NONE;
}
344
copyFromInternal(const hardware::hidl_memory & src,const hardware::hidl_vec<uint32_t> & dimensions,const std::shared_ptr<HalManagedBuffer> & bufferWrapper)345 static V1_3::ErrorStatus copyFromInternal(const hardware::hidl_memory& src,
346 const hardware::hidl_vec<uint32_t>& dimensions,
347 const std::shared_ptr<HalManagedBuffer>& bufferWrapper) {
348 CHECK(bufferWrapper != nullptr);
349 const auto srcPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(src));
350 if (!srcPool.has_value()) {
351 LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
352 return V1_3::ErrorStatus::GENERAL_FAILURE;
353 }
354 const V1_3::ErrorStatus validationStatus =
355 convertToV1_3(bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize()));
356 if (validationStatus != V1_3::ErrorStatus::NONE) {
357 return validationStatus;
358 }
359 const auto dstPool = bufferWrapper->createRunTimePoolInfo();
360 copyRunTimePoolInfos(srcPool.value(), dstPool);
361 return V1_3::ErrorStatus::NONE;
362 }
363
// Copies client memory `src` into this driver-managed buffer. On success the
// buffer's dimensions are updated and it is marked initialized; on failure
// the buffer contents are unspecified, so it is marked uninitialized.
hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyFrom(
        const hardware::hidl_memory& src, const hardware::hidl_vec<uint32_t>& dimensions) {
    const auto status = copyFromInternal(src, dimensions, kBuffer);
    const bool success = (status == V1_3::ErrorStatus::NONE);
    if (success) {
        kBuffer->updateDimensions(dimensions);
    }
    kBuffer->setInitialized(success);
    return status;
}
375
// Maps the model's memory pools into mPoolInfos for use during execution.
// Returns false if any pool cannot be mapped.
bool SamplePreparedModel::initialize() {
    return setRunTimePoolInfosFromCanonicalMemories(&mPoolInfos, uncheckedConvert(mModel.pools));
}
379
// Maps every memory pool of `request` into a RunTimePoolInfo usable by
// CpuExecutor:
// - hidlMemory pools are mapped directly from the shared memory, and
// - token pools refer to driver-managed buffers previously handed out by
//   SampleDriver::allocate; they are resolved via the driver's
//   HalBufferTracker and validated against the request.
// Returns {status, requestPoolInfos, bufferWrappers}; on error both vectors
// are empty. bufferWrappers runs parallel to requestPoolInfos: the entry is
// nullptr for a hidlMemory pool and holds the managed buffer for a token pool
// (so callers can update buffer metadata after execution).
static std::tuple<V1_3::ErrorStatus, std::vector<RunTimePoolInfo>,
                  std::vector<std::shared_ptr<HalManagedBuffer>>>
createRunTimePoolInfos(const V1_3::Request& request, const SampleDriver& driver,
                       const SamplePreparedModel* preparedModel) {
    std::vector<RunTimePoolInfo> requestPoolInfos;
    std::vector<std::shared_ptr<HalManagedBuffer>> bufferWrappers;
    requestPoolInfos.reserve(request.pools.size());
    bufferWrappers.reserve(request.pools.size());
    for (uint32_t i = 0; i < request.pools.size(); i++) {
        auto& pool = request.pools[i];
        switch (pool.getDiscriminator()) {
            case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory: {
                auto buffer =
                        RunTimePoolInfo::createFromMemory(uncheckedConvert(pool.hidlMemory()));
                if (!buffer.has_value()) {
                    LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
                    return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, {}};
                }
                requestPoolInfos.push_back(std::move(*buffer));
                // No driver-managed buffer backs a plain hidlMemory pool.
                bufferWrappers.push_back(nullptr);
            } break;
            case V1_3::Request::MemoryPool::hidl_discriminator::token: {
                auto bufferWrapper = driver.getHalBufferTracker()->get(pool.token());
                if (bufferWrapper == nullptr) {
                    // Unknown token: the client referenced a buffer this
                    // driver never allocated (or one already freed).
                    return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, {}};
                }
                // Verify the managed buffer may be used as pool i of this
                // request.
                const auto validationStatus = convertToV1_3(bufferWrapper->validateRequest(
                        i, uncheckedConvert(request), preparedModel));
                if (validationStatus != V1_3::ErrorStatus::NONE) {
                    return {validationStatus, {}, {}};
                }
                requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
                bufferWrappers.push_back(std::move(bufferWrapper));
            } break;
        }
    }
    return {V1_3::ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
}
418
// Updates the metadata of driver-managed (token) output buffers after an
// execution, based on the execution status and the reported output shapes.
// On NONE: update the dimensions of every token-backed output first, and only
// once all updates succeed, mark those buffers initialized (the two separate
// loops avoid marking any buffer initialized if a later update fails).
// On OUTPUT_INSUFFICIENT_SIZE: a device memory reporting insufficient size
// means its dimensions were incorrectly specified, which is converted to
// GENERAL_FAILURE.
// Returns GENERAL_FAILURE on any metadata failure, NONE otherwise.
static V1_3::ErrorStatus updateDeviceMemories(
        V1_3::ErrorStatus status, const V1_3::Request& request,
        const std::vector<std::shared_ptr<HalManagedBuffer>>& bufferWrappers,
        const hardware::hidl_vec<V1_2::OutputShape>& outputShapes) {
    if (status == V1_3::ErrorStatus::NONE) {
        // Pass 1: update dimensions of all token-backed outputs.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
        // Pass 2: only after every update succeeded, mark buffers initialized.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    } else if (status == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        // If CpuExecutor reports OUTPUT_INSUFFCIENT_SIZE on a device memory, this is because the
        // dimensions of the device memory are incorrectly specified. The driver should return
        // GENERAL_FAILURE instead in this case.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!outputShapes[i].isSufficient) {
                    LOG(ERROR) << "Invalid dimensions for output " << i
                               << ": actual shape = " << toString(outputShapes[i].dimensions);
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
    }
    return V1_3::ErrorStatus::NONE;
}
458
// Runs one execution on the CPU executor and reports the result through
// `callback`. Invoked on a detached thread by executeBase(). `driverStart` is
// the timestamp taken when the driver first received the request, so the
// reported driver time spans validation plus execution.
template <typename T_IExecutionCallback>
void asyncExecute(const V1_3::Request& request, V1_2::MeasureTiming measure, TimePoint driverStart,
                  const V1_3::Model& model, const SampleDriver& driver,
                  const SamplePreparedModel* preparedModel,
                  const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& deadline,
                  const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                  const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                 "SampleDriver::asyncExecute");

    // Map the request's memory pools (shared memory and driver-managed
    // buffers) before running.
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, poolStatus, {}, kNoTiming);
        return;
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::asyncExecute");
    CpuExecutor executor = driver.getExecutor();
    // Propagate the optional loop timeout and deadline to the executor.
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    TimePoint driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, updateStatus, {}, kNoTiming);
        return;
    }

    // Timing is only reported for a fully successful, measured execution;
    // otherwise the kNoTiming sentinel is used.
    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
        notify(callback, executionStatus, outputShapes, timing);
    } else {
        notify(callback, executionStatus, outputShapes, kNoTiming);
    }
}
514
// Common implementation for the asynchronous execute*() entry points.
// Validates the callback, request, and deadline, then launches a detached
// thread that performs the execution via asyncExecute() and reports the
// outcome through `callback`. The returned status reflects only the
// validation/launch phase; the actual execution result arrives via callback.
template <typename T_IExecutionCallback>
V1_3::ErrorStatus executeBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                              const V1_3::Model& model, const SampleDriver& driver,
                              const SamplePreparedModel* preparedModel,
                              const std::vector<RunTimePoolInfo>& poolInfos,
                              const V1_3::OptionalTimePoint& halDeadline,
                              const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                              const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Capture the driver entry time first so measured driver time includes
    // validation overhead.
    TimePoint driverStart;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (callback.get() == nullptr) {
        LOG(ERROR) << "invalid callback passed to executeBase";
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    if (!validateRequest(request, model)) {
        notify(callback, V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        // The deadline was already missed: report it through the callback,
        // but the launch itself is considered successful.
        notify(callback, V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
        return V1_3::ErrorStatus::NONE;
    }

    // This thread is intentionally detached because the sample driver service
    // is expected to live forever.
    // NOTE(review): model, driver, and poolInfos are captured by reference in
    // a detached thread; this relies on the prepared model and driver
    // outliving the execution -- confirm that lifetime guarantee holds for
    // all callers.
    std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
                 deadline, loopTimeoutDuration, callback] {
        asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
                     deadline, loopTimeoutDuration, callback);
    }).detach();

    return V1_3::ErrorStatus::NONE;
}
553
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)554 hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
555 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
556 const V1_3::ErrorStatus status =
557 executeBase(convertToV1_3(request), V1_2::MeasureTiming::NO, mModel, *mDriver, this,
558 mPoolInfos, {}, {}, callback);
559 return convertToV1_0(status);
560 }
561
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)562 hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
563 const V1_0::Request& request, V1_2::MeasureTiming measure,
564 const sp<V1_2::IExecutionCallback>& callback) {
565 const V1_3::ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver,
566 this, mPoolInfos, {}, {}, callback);
567 return convertToV1_0(status);
568 }
569
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)570 hardware::Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
571 const V1_3::Request& request, V1_2::MeasureTiming measure,
572 const V1_3::OptionalTimePoint& deadline,
573 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
574 const sp<V1_3::IExecutionCallback>& callback) {
575 return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
576 loopTimeoutDuration, callback);
577 }
578
// Common implementation for the synchronous execute*() entry points.
// Validates the request and deadline, maps the request's memory pools, runs
// the execution on the CPU executor on the calling thread, updates device
// memory metadata, and returns {status, outputShapes, timing}. Timing is only
// populated for a fully successful, measured execution; otherwise kNoTiming
// is returned.
static std::tuple<V1_3::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing>
executeSynchronouslyBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                         const V1_3::Model& model, const SampleDriver& driver,
                         const SamplePreparedModel* preparedModel,
                         const std::vector<RunTimePoolInfo>& poolInfos,
                         const V1_3::OptionalTimePoint& halDeadline,
                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SampleDriver::executeSynchronouslyBase");
    VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Capture the driver entry time first so measured driver time includes
    // validation overhead.
    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (!validateRequest(request, model)) {
        return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    }
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        return {V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SampleDriver::executeSynchronouslyBase");
    // Map the request's memory pools (shared memory and driver-managed
    // buffers) before running.
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        return {poolStatus, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::executeSynchronouslyBase");
    CpuExecutor executor = driver.getExecutor();
    // Propagate the optional loop timeout and deadline to the executor.
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        return {updateStatus, {}, kNoTiming};
    }

    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
        return {executionStatus, std::move(outputShapes), timing};
    }
    return {executionStatus, std::move(outputShapes), kNoTiming};
}
644
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)645 hardware::Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
646 V1_2::MeasureTiming measure,
647 executeSynchronously_cb cb) {
648 auto [status, outputShapes, timing] = executeSynchronouslyBase(
649 convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
650 cb(convertToV1_0(status), std::move(outputShapes), timing);
651 return hardware::Void();
652 }
653
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)654 hardware::Return<void> SamplePreparedModel::executeSynchronously_1_3(
655 const V1_3::Request& request, V1_2::MeasureTiming measure,
656 const V1_3::OptionalTimePoint& deadline,
657 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
658 auto [status, outputShapes, timing] = executeSynchronouslyBase(
659 request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
660 cb(status, std::move(outputShapes), timing);
661 return hardware::Void();
662 }
663
664 // The sample driver will finish the execution and then return.
executeFenced(const V1_3::Request & request,const hardware::hidl_vec<hardware::hidl_handle> & waitFor,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & halDeadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const V1_3::OptionalTimeoutDuration & duration,executeFenced_cb cb)665 hardware::Return<void> SamplePreparedModel::executeFenced(
666 const V1_3::Request& request, const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
667 V1_2::MeasureTiming measure, const V1_3::OptionalTimePoint& halDeadline,
668 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
669 const V1_3::OptionalTimeoutDuration& duration, executeFenced_cb cb) {
670 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
671 "SamplePreparedModel::executeFenced");
672 VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";
673
674 TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
675 if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();
676
677 if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
678 cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
679 return hardware::Void();
680 }
681 const auto deadline = convert(halDeadline).value();
682 if (hasDeadlinePassed(deadline)) {
683 cb(V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, hardware::hidl_handle(nullptr), nullptr);
684 return hardware::Void();
685 }
686
687 // Wait for the dependent events to signal
688 for (const auto& fenceHandle : waitFor) {
689 if (!fenceHandle.getNativeHandle()) {
690 cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
691 return hardware::Void();
692 }
693 int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
694 if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
695 LOG(ERROR) << "syncWait failed";
696 cb(V1_3::ErrorStatus::GENERAL_FAILURE, hardware::hidl_handle(nullptr), nullptr);
697 return hardware::Void();
698 }
699 }
700
701 // Update deadline if the timeout duration is closer than the deadline.
702 auto closestDeadline = deadline;
703 if (duration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
704 const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
705 if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
706 closestDeadline = timeoutDurationDeadline;
707 }
708 }
709
710 TimePoint driverStartAfterFence;
711 if (measure == V1_2::MeasureTiming::YES) driverStartAfterFence = Clock::now();
712
713 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
714 "SamplePreparedModel::executeFenced");
715 const auto [poolStatus, requestPoolInfos, bufferWrappers] =
716 createRunTimePoolInfos(request, *mDriver, this);
717 if (poolStatus != V1_3::ErrorStatus::NONE) {
718 cb(poolStatus, hardware::hidl_handle(nullptr), nullptr);
719 return hardware::Void();
720 }
721
722 NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
723 "SamplePreparedModel::executeFenced");
724 CpuExecutor executor = mDriver->getExecutor();
725 if (loopTimeoutDuration.getDiscriminator() !=
726 V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
727 executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
728 }
729 if (closestDeadline.has_value()) {
730 executor.setDeadline(*closestDeadline);
731 }
732 if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
733 int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(request), mPoolInfos,
734 requestPoolInfos);
735 if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
736 VLOG(DRIVER) << "executor.run returned " << n;
737 V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
738 if (executionStatus != V1_3::ErrorStatus::NONE) {
739 cb(executionStatus, hardware::hidl_handle(nullptr), nullptr);
740 return hardware::Void();
741 }
742
743 // Set output memories to the initialized state.
744 if (executionStatus == V1_3::ErrorStatus::NONE) {
745 for (const auto& output : request.outputs) {
746 const uint32_t poolIndex = output.location.poolIndex;
747 const auto& pool = request.pools[poolIndex];
748 if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
749 bufferWrappers[poolIndex]->setInitialized(true);
750 }
751 }
752 }
753
754 V1_2::Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
755 V1_2::Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
756 if (measure == V1_2::MeasureTiming::YES) {
757 driverEnd = Clock::now();
758 timingSinceLaunch = {
759 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
760 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
761 timingAfterFence = {
762 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
763 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
764 VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
765 VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
766 }
767 sp<SampleFencedExecutionCallback> fencedExecutionCallback =
768 new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
769 cb(executionStatus, hardware::hidl_handle(nullptr), fencedExecutionCallback);
770 return hardware::Void();
771 }
772
773 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
774 // the mapping until either (1) the memory is freed in the runtime, or (2) the
775 // burst object is destroyed. This allows for subsequent executions operating on
776 // pools that have been used before to reuse the mapping instead of mapping and
777 // unmapping the memory on each execution.
778 class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
779 public:
BurstExecutorWithCache(const V1_3::Model & model,const SampleDriver * driver,const std::vector<RunTimePoolInfo> & poolInfos)780 BurstExecutorWithCache(const V1_3::Model& model, const SampleDriver* driver,
781 const std::vector<RunTimePoolInfo>& poolInfos)
782 : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}
783
isCacheEntryPresent(int32_t slot) const784 bool isCacheEntryPresent(int32_t slot) const override {
785 const auto it = mMemoryCache.find(slot);
786 return (it != mMemoryCache.end()) && it->second.has_value();
787 }
788
addCacheEntry(const hardware::hidl_memory & memory,int32_t slot)789 void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override {
790 mMemoryCache[slot] = RunTimePoolInfo::createFromMemory(uncheckedConvert(memory));
791 }
792
removeCacheEntry(int32_t slot)793 void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }
794
execute(const V1_0::Request & request,const std::vector<int32_t> & slots,V1_2::MeasureTiming measure)795 std::tuple<V1_0::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing> execute(
796 const V1_0::Request& request, const std::vector<int32_t>& slots,
797 V1_2::MeasureTiming measure) override {
798 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
799 "BurstExecutorWithCache::execute");
800
801 TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
802 if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();
803
804 // ensure all relevant pools are valid
805 if (!std::all_of(slots.begin(), slots.end(),
806 [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
807 return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
808 }
809
810 // finish the request object (for validation)
811 hardware::hidl_vec<V1_3::Request::MemoryPool> pools(slots.size());
812 std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
813 V1_3::Request::MemoryPool pool;
814 pool.hidlMemory(convertToV1_0(mMemoryCache[slot]->getMemory()));
815 return pool;
816 });
817 V1_3::Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
818 fullRequest.pools = std::move(pools);
819
820 // validate request object against the model
821 if (!validateRequest(fullRequest, mModel)) {
822 return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
823 }
824
825 // select relevant entries from cache
826 std::vector<RunTimePoolInfo> requestPoolInfos;
827 requestPoolInfos.reserve(slots.size());
828 std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
829 [this](int32_t slot) { return *mMemoryCache[slot]; });
830
831 // execution
832 // Configuring the loop timeout duration is not supported. This is OK
833 // because burst does not support HAL 1.3 and hence does not support
834 // WHILE loops.
835 CpuExecutor executor = mDriver->getExecutor();
836 if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
837 int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(fullRequest),
838 mModelPoolInfos, requestPoolInfos);
839 if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
840 VLOG(DRIVER) << "executor.run returned " << n;
841 V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToHalErrorStatus(n));
842 hardware::hidl_vec<V1_2::OutputShape> outputShapes =
843 convertToV1_2(executor.getOutputShapes());
844 if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
845 driverEnd = Clock::now();
846 V1_2::Timing timing = {
847 .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
848 .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
849 VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
850 return std::make_tuple(executionStatus, outputShapes, timing);
851 } else {
852 return std::make_tuple(executionStatus, outputShapes, kNoTiming);
853 }
854 }
855
856 private:
857 const V1_3::Model mModel;
858 const SampleDriver* const mDriver;
859 const std::vector<RunTimePoolInfo> mModelPoolInfos;
860 std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache; // cached requestPoolInfos
861 };
862
863 // This is the amount of time the ExecutionBurstServer should spend polling the
864 // FMQ to see if it has data available before it should fall back to waiting on
865 // the futex.
// Returns how long the ExecutionBurstServer should poll the FMQ before falling
// back to waiting on the futex. On debuggable builds the window can be tuned
// via the debug.nn.sample-driver-burst-polling-window system property.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    constexpr int32_t kMinPollingTimeWindow = 0;
    return std::chrono::microseconds{
            base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                 kDefaultPollingTimeWindow, kMinPollingTimeWindow)};
#else
    return std::chrono::microseconds{kDefaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
878
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)879 hardware::Return<void> SamplePreparedModel::configureExecutionBurst(
880 const sp<V1_2::IBurstCallback>& callback,
881 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
882 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
883 configureExecutionBurst_cb cb) {
884 NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
885 "SampleDriver::configureExecutionBurst");
886
887 const bool preferPowerOverLatency = (kPreference == V1_1::ExecutionPreference::LOW_POWER);
888 const auto pollingTimeWindow =
889 (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
890
891 // Alternatively, the burst could be configured via:
892 // const sp<V1_2::IBurstContext> burst =
893 // ExecutionBurstServer::create(callback, requestChannel,
894 // resultChannel, this,
895 // pollingTimeWindow);
896 //
897 // However, this alternative representation does not include a memory map
898 // caching optimization, and adds overhead.
899 const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
900 std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
901 const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
902 callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
903
904 if (burst == nullptr) {
905 cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
906 } else {
907 cb(V1_0::ErrorStatus::NONE, burst);
908 }
909
910 return hardware::Void();
911 }
912
913 } // namespace sample_driver
914 } // namespace nn
915 } // namespace android
916