1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "SampleDriver"
18 
19 #include "SampleDriver.h"
20 
21 #include <CpuExecutor.h>
22 #include <ExecutionBurstServer.h>
23 #include <HalBufferTracker.h>
24 #include <HalInterfaces.h>
25 #include <Tracing.h>
26 #include <ValidateHal.h>
27 #include <android-base/logging.h>
28 #include <android-base/properties.h>
29 #include <hidl/LegacySupport.h>
30 #include <nnapi/Types.h>
31 #include <nnapi/hal/1.3/Conversions.h>
32 
33 #include <algorithm>
34 #include <chrono>
35 #include <map>
36 #include <memory>
37 #include <optional>
38 #include <set>
39 #include <thread>
40 #include <tuple>
41 #include <utility>
42 #include <vector>
43 
44 #include "SampleDriverUtils.h"
45 
46 namespace android {
47 namespace nn {
48 namespace sample_driver {
49 
50 namespace {
51 
microsecondsDuration(TimePoint end,TimePoint start)52 uint64_t microsecondsDuration(TimePoint end, TimePoint start) {
53     using Microseconds = std::chrono::duration<uint64_t, std::micro>;
54     return std::chrono::duration_cast<Microseconds>(end - start).count();
55 };
56 
57 }  // namespace
58 
59 static const V1_2::Timing kNoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
60 
getCapabilities(getCapabilities_cb cb)61 hardware::Return<void> SampleDriver::getCapabilities(getCapabilities_cb cb) {
62     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
63                  "SampleDriver::getCapabilities");
64     return getCapabilities_1_3(
65             [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
66                 // TODO(dgross): Do we need to check compliantWithV1_0(capabilities)?
67                 cb(convertToV1_0(error), convertToV1_0(capabilities));
68             });
69 }
70 
getCapabilities_1_1(getCapabilities_1_1_cb cb)71 hardware::Return<void> SampleDriver::getCapabilities_1_1(getCapabilities_1_1_cb cb) {
72     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
73                  "SampleDriver::getCapabilities_1_1");
74     return getCapabilities_1_3(
75             [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
76                 // TODO(dgross): Do we need to check compliantWithV1_1(capabilities)?
77                 cb(convertToV1_0(error), convertToV1_1(capabilities));
78             });
79 }
80 
getCapabilities_1_2(getCapabilities_1_2_cb cb)81 hardware::Return<void> SampleDriver::getCapabilities_1_2(getCapabilities_1_2_cb cb) {
82     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
83                  "SampleDriver::getCapabilities_1_2");
84     return getCapabilities_1_3(
85             [&](V1_3::ErrorStatus error, const V1_3::Capabilities& capabilities) {
86                 // TODO(dgross): Do we need to check compliantWithV1_2(capabilities)?
87                 cb(convertToV1_0(error), convertToV1_2(capabilities));
88             });
89 }
90 
getVersionString(getVersionString_cb cb)91 hardware::Return<void> SampleDriver::getVersionString(getVersionString_cb cb) {
92     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
93                  "SampleDriver::getVersionString");
94     cb(V1_0::ErrorStatus::NONE, "JUST_AN_EXAMPLE");
95     return hardware::Void();
96 }
97 
getType(getType_cb cb)98 hardware::Return<void> SampleDriver::getType(getType_cb cb) {
99     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION, "SampleDriver::getType");
100     cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU);
101     return hardware::Void();
102 }
103 
getSupportedExtensions(getSupportedExtensions_cb cb)104 hardware::Return<void> SampleDriver::getSupportedExtensions(getSupportedExtensions_cb cb) {
105     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
106                  "SampleDriver::getSupportedExtensions");
107     cb(V1_0::ErrorStatus::NONE, {/* No extensions. */});
108     return hardware::Void();
109 }
110 
getSupportedOperations(const V1_0::Model & model,getSupportedOperations_cb cb)111 hardware::Return<void> SampleDriver::getSupportedOperations(const V1_0::Model& model,
112                                                             getSupportedOperations_cb cb) {
113     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
114                  "SampleDriver::getSupportedOperations");
115     if (!validateModel(model)) {
116         VLOG(DRIVER) << "getSupportedOperations";
117         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
118         return hardware::Void();
119     }
120     return getSupportedOperations_1_3(
121             convertToV1_3(model),
122             [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
123                 cb(convertToV1_0(status), supported);
124             });
125 }
126 
getSupportedOperations_1_1(const V1_1::Model & model,getSupportedOperations_1_1_cb cb)127 hardware::Return<void> SampleDriver::getSupportedOperations_1_1(const V1_1::Model& model,
128                                                                 getSupportedOperations_1_1_cb cb) {
129     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
130                  "SampleDriver::getSupportedOperations_1_1");
131     if (!validateModel(model)) {
132         VLOG(DRIVER) << "getSupportedOperations_1_1";
133         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
134         return hardware::Void();
135     }
136     return getSupportedOperations_1_3(
137             convertToV1_3(model),
138             [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
139                 cb(convertToV1_0(status), supported);
140             });
141 }
142 
getSupportedOperations_1_2(const V1_2::Model & model,getSupportedOperations_1_2_cb cb)143 hardware::Return<void> SampleDriver::getSupportedOperations_1_2(const V1_2::Model& model,
144                                                                 getSupportedOperations_1_2_cb cb) {
145     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
146                  "SampleDriver::getSupportedOperations_1_2");
147     if (!validateModel(model)) {
148         VLOG(DRIVER) << "getSupportedOperations_1_2";
149         cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
150         return hardware::Void();
151     }
152     return getSupportedOperations_1_3(
153             convertToV1_3(model),
154             [&](V1_3::ErrorStatus status, const hardware::hidl_vec<bool>& supported) {
155                 cb(convertToV1_0(status), supported);
156             });
157 }
158 
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)159 hardware::Return<void> SampleDriver::getNumberOfCacheFilesNeeded(
160         getNumberOfCacheFilesNeeded_cb cb) {
161     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INITIALIZATION,
162                  "SampleDriver::getNumberOfCacheFilesNeeded");
163     // Set both numbers to be 0 for cache not supported.
164     cb(V1_0::ErrorStatus::NONE, /*numModelCache=*/0, /*numDataCache=*/0);
165     return hardware::Void();
166 }
167 
prepareModel(const V1_0::Model & model,const sp<V1_0::IPreparedModelCallback> & callback)168 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel(
169         const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) {
170     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel");
171     const V1_3::ErrorStatus status =
172             prepareModelBase(model, this, V1_1::ExecutionPreference::FAST_SINGLE_ANSWER,
173                              kDefaultPriority13, {}, callback);
174     return convertToV1_0(status);
175 }
176 
prepareModel_1_1(const V1_1::Model & model,V1_1::ExecutionPreference preference,const sp<V1_0::IPreparedModelCallback> & callback)177 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_1(
178         const V1_1::Model& model, V1_1::ExecutionPreference preference,
179         const sp<V1_0::IPreparedModelCallback>& callback) {
180     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_1");
181     const V1_3::ErrorStatus status =
182             prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
183     return convertToV1_0(status);
184 }
185 
prepareModel_1_2(const V1_2::Model & model,V1_1::ExecutionPreference preference,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_2::IPreparedModelCallback> & callback)186 hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModel_1_2(
187         const V1_2::Model& model, V1_1::ExecutionPreference preference,
188         const hardware::hidl_vec<hardware::hidl_handle>&,
189         const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
190         const sp<V1_2::IPreparedModelCallback>& callback) {
191     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_2");
192     const V1_3::ErrorStatus status =
193             prepareModelBase(model, this, preference, kDefaultPriority13, {}, callback);
194     return convertToV1_0(status);
195 }
196 
prepareModel_1_3(const V1_3::Model & model,V1_1::ExecutionPreference preference,V1_3::Priority priority,const V1_3::OptionalTimePoint & deadline,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)197 hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModel_1_3(
198         const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
199         const V1_3::OptionalTimePoint& deadline, const hardware::hidl_vec<hardware::hidl_handle>&,
200         const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
201         const sp<V1_3::IPreparedModelCallback>& callback) {
202     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION, "SampleDriver::prepareModel_1_3");
203     return prepareModelBase(model, this, preference, priority, deadline, callback);
204 }
205 
// Compilation caching is unsupported: report failure both through the
// callback and through the synchronous return value.
hardware::Return<V1_0::ErrorStatus> SampleDriver::prepareModelFromCache(
        const hardware::hidl_vec<hardware::hidl_handle>&,
        const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
        const sp<V1_2::IPreparedModelCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
                 "SampleDriver::prepareModelFromCache");
    notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
    return V1_0::ErrorStatus::GENERAL_FAILURE;
}
215 
prepareModelFromCache_1_3(const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> &,const hardware::hidl_vec<hardware::hidl_handle> &,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)216 hardware::Return<V1_3::ErrorStatus> SampleDriver::prepareModelFromCache_1_3(
217         const V1_3::OptionalTimePoint& /*deadline*/,
218         const hardware::hidl_vec<hardware::hidl_handle>&,
219         const hardware::hidl_vec<hardware::hidl_handle>&, const HalCacheToken&,
220         const sp<V1_3::IPreparedModelCallback>& callback) {
221     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_COMPILATION,
222                  "SampleDriver::prepareModelFromCache_1_3");
223     notify(callback, V1_3::ErrorStatus::GENERAL_FAILURE, nullptr);
224     return V1_3::ErrorStatus::GENERAL_FAILURE;
225 }
226 
getStatus()227 hardware::Return<V1_0::DeviceStatus> SampleDriver::getStatus() {
228     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_UNSPECIFIED, "SampleDriver::getStatus");
229     VLOG(DRIVER) << "getStatus()";
230     return V1_0::DeviceStatus::AVAILABLE;
231 }
232 
233 // Safely downcast an IPreparedModel object to SamplePreparedModel.
234 // This function will return nullptr if the IPreparedModel object is not originated from the sample
235 // driver process.
castToSamplePreparedModel(const sp<V1_3::IPreparedModel> & preparedModel)236 static const SamplePreparedModel* castToSamplePreparedModel(
237         const sp<V1_3::IPreparedModel>& preparedModel) {
238     if (preparedModel->isRemote()) {
239         return nullptr;
240     } else {
241         // This static_cast is safe because SamplePreparedModel is the only class that implements
242         // the IPreparedModel interface in the sample driver process.
243         return static_cast<const SamplePreparedModel*>(preparedModel.get());
244     }
245 }
246 
allocate(const V1_3::BufferDesc & desc,const hardware::hidl_vec<sp<V1_3::IPreparedModel>> & preparedModels,const hardware::hidl_vec<V1_3::BufferRole> & inputRoles,const hardware::hidl_vec<V1_3::BufferRole> & outputRoles,allocate_cb cb)247 hardware::Return<void> SampleDriver::allocate(
248         const V1_3::BufferDesc& desc,
249         const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
250         const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
251         const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) {
252     constexpr uint32_t kInvalidBufferToken = 0;
253 
254     VLOG(DRIVER) << "SampleDriver::allocate";
255     std::set<HalPreparedModelRole> roles;
256     V1_3::Operand operand;
257     auto getModel = [](const sp<V1_3::IPreparedModel>& preparedModel) -> const V1_3::Model* {
258         const auto* samplePreparedModel = castToSamplePreparedModel(preparedModel);
259         if (samplePreparedModel == nullptr) {
260             LOG(ERROR) << "SampleDriver::allocate -- unknown remote IPreparedModel.";
261             return nullptr;
262         }
263         return samplePreparedModel->getModel();
264     };
265     if (!validateMemoryDesc(desc, preparedModels, inputRoles, outputRoles, getModel, &roles,
266                             &operand)) {
267         LOG(ERROR) << "SampleDriver::allocate -- validation failed.";
268         cb(V1_3::ErrorStatus::INVALID_ARGUMENT, nullptr, kInvalidBufferToken);
269         return hardware::Void();
270     }
271 
272     if (isExtensionOperandType(operand.type)) {
273         LOG(ERROR) << "SampleDriver::allocate -- does not support extension type.";
274         cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
275         return hardware::Void();
276     }
277 
278     // TODO(xusongw): Support allocating buffers with unknown dimensions or rank.
279     uint32_t size = nonExtensionOperandSizeOfData(operand.type, operand.dimensions);
280     VLOG(DRIVER) << "SampleDriver::allocate -- type = " << toString(operand.type)
281                  << ", dimensions = " << toString(operand.dimensions) << ", size = " << size;
282     if (size == 0) {
283         LOG(ERROR) << "SampleDriver::allocate -- does not support dynamic output shape.";
284         cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
285         return hardware::Void();
286     }
287 
288     auto bufferWrapper =
289             HalManagedBuffer::create(size, std::move(roles), uncheckedConvert(operand));
290     if (bufferWrapper == nullptr) {
291         LOG(ERROR) << "SampleDriver::allocate -- not enough memory.";
292         cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
293         return hardware::Void();
294     }
295 
296     auto token = mHalBufferTracker->add(bufferWrapper);
297     if (token == nullptr) {
298         LOG(ERROR) << "SampleDriver::allocate -- HalBufferTracker returned invalid token.";
299         cb(V1_3::ErrorStatus::GENERAL_FAILURE, nullptr, kInvalidBufferToken);
300         return hardware::Void();
301     }
302 
303     const uint32_t tokenValue = token->get();
304     sp<SampleBuffer> sampleBuffer = new SampleBuffer(std::move(bufferWrapper), std::move(token));
305     VLOG(DRIVER) << "SampleDriver::allocate -- successfully allocates the requested memory";
306     cb(V1_3::ErrorStatus::NONE, std::move(sampleBuffer), tokenValue);
307     return hardware::Void();
308 }
309 
run()310 int SampleDriver::run() {
311     android::hardware::configureRpcThreadpool(4, true);
312     if (registerAsService(mName) != android::OK) {
313         LOG(ERROR) << "Could not register service";
314         return 1;
315     }
316     android::hardware::joinRpcThreadpool();
317     LOG(ERROR) << "Service exited!";
318     return 1;
319 }
320 
copyRunTimePoolInfos(const RunTimePoolInfo & srcPool,const RunTimePoolInfo & dstPool)321 static void copyRunTimePoolInfos(const RunTimePoolInfo& srcPool, const RunTimePoolInfo& dstPool) {
322     CHECK(srcPool.getBuffer() != nullptr);
323     CHECK(dstPool.getBuffer() != nullptr);
324     CHECK(srcPool.getSize() == dstPool.getSize());
325     std::copy(srcPool.getBuffer(), srcPool.getBuffer() + srcPool.getSize(), dstPool.getBuffer());
326     dstPool.flush();
327 }
328 
// Copies the whole managed buffer into caller-provided destination memory.
// Fails if the destination cannot be mapped or its size does not validate.
hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyTo(const hardware::hidl_memory& dst) {
    const auto dstPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(dst));
    if (!dstPool.has_value()) {
        LOG(ERROR) << "SampleBuffer::copyTo -- unable to map dst memory.";
        return V1_3::ErrorStatus::GENERAL_FAILURE;
    }
    if (const auto status = convertToV1_3(kBuffer->validateCopyTo(dstPool->getSize()));
        status != V1_3::ErrorStatus::NONE) {
        return status;
    }
    copyRunTimePoolInfos(kBuffer->createRunTimePoolInfo(), dstPool.value());
    return V1_3::ErrorStatus::NONE;
}
344 
copyFromInternal(const hardware::hidl_memory & src,const hardware::hidl_vec<uint32_t> & dimensions,const std::shared_ptr<HalManagedBuffer> & bufferWrapper)345 static V1_3::ErrorStatus copyFromInternal(const hardware::hidl_memory& src,
346                                           const hardware::hidl_vec<uint32_t>& dimensions,
347                                           const std::shared_ptr<HalManagedBuffer>& bufferWrapper) {
348     CHECK(bufferWrapper != nullptr);
349     const auto srcPool = RunTimePoolInfo::createFromMemory(uncheckedConvert(src));
350     if (!srcPool.has_value()) {
351         LOG(ERROR) << "SampleBuffer::copyFrom -- unable to map src memory.";
352         return V1_3::ErrorStatus::GENERAL_FAILURE;
353     }
354     const V1_3::ErrorStatus validationStatus =
355             convertToV1_3(bufferWrapper->validateCopyFrom(dimensions, srcPool->getSize()));
356     if (validationStatus != V1_3::ErrorStatus::NONE) {
357         return validationStatus;
358     }
359     const auto dstPool = bufferWrapper->createRunTimePoolInfo();
360     copyRunTimePoolInfos(srcPool.value(), dstPool);
361     return V1_3::ErrorStatus::NONE;
362 }
363 
// Copies caller memory into the managed buffer. On success the buffer takes
// on the supplied dimensions and becomes initialized; a failed copy leaves the
// contents unspecified, so the buffer is marked uninitialized.
hardware::Return<V1_3::ErrorStatus> SampleBuffer::copyFrom(
        const hardware::hidl_memory& src, const hardware::hidl_vec<uint32_t>& dimensions) {
    const auto status = copyFromInternal(src, dimensions, kBuffer);
    const bool succeeded = (status == V1_3::ErrorStatus::NONE);
    if (succeeded) {
        kBuffer->updateDimensions(dimensions);
    }
    kBuffer->setInitialized(succeeded);
    return status;
}
375 
initialize()376 bool SamplePreparedModel::initialize() {
377     return setRunTimePoolInfosFromCanonicalMemories(&mPoolInfos, uncheckedConvert(mModel.pools));
378 }
379 
380 static std::tuple<V1_3::ErrorStatus, std::vector<RunTimePoolInfo>,
381                   std::vector<std::shared_ptr<HalManagedBuffer>>>
createRunTimePoolInfos(const V1_3::Request & request,const SampleDriver & driver,const SamplePreparedModel * preparedModel)382 createRunTimePoolInfos(const V1_3::Request& request, const SampleDriver& driver,
383                        const SamplePreparedModel* preparedModel) {
384     std::vector<RunTimePoolInfo> requestPoolInfos;
385     std::vector<std::shared_ptr<HalManagedBuffer>> bufferWrappers;
386     requestPoolInfos.reserve(request.pools.size());
387     bufferWrappers.reserve(request.pools.size());
388     for (uint32_t i = 0; i < request.pools.size(); i++) {
389         auto& pool = request.pools[i];
390         switch (pool.getDiscriminator()) {
391             case V1_3::Request::MemoryPool::hidl_discriminator::hidlMemory: {
392                 auto buffer =
393                         RunTimePoolInfo::createFromMemory(uncheckedConvert(pool.hidlMemory()));
394                 if (!buffer.has_value()) {
395                     LOG(ERROR) << "createRuntimeMemoriesFromMemoryPools -- could not map pools";
396                     return {V1_3::ErrorStatus::GENERAL_FAILURE, {}, {}};
397                 }
398                 requestPoolInfos.push_back(std::move(*buffer));
399                 bufferWrappers.push_back(nullptr);
400             } break;
401             case V1_3::Request::MemoryPool::hidl_discriminator::token: {
402                 auto bufferWrapper = driver.getHalBufferTracker()->get(pool.token());
403                 if (bufferWrapper == nullptr) {
404                     return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, {}};
405                 }
406                 const auto validationStatus = convertToV1_3(bufferWrapper->validateRequest(
407                         i, uncheckedConvert(request), preparedModel));
408                 if (validationStatus != V1_3::ErrorStatus::NONE) {
409                     return {validationStatus, {}, {}};
410                 }
411                 requestPoolInfos.push_back(bufferWrapper->createRunTimePoolInfo());
412                 bufferWrappers.push_back(std::move(bufferWrapper));
413             } break;
414         }
415     }
416     return {V1_3::ErrorStatus::NONE, std::move(requestPoolInfos), std::move(bufferWrappers)};
417 }
418 
// Propagates execution results into driver-managed (token-based) memory pools.
//
// On a successful execution, records the actual output dimensions on each
// device-memory output, then marks those buffers initialized. On
// OUTPUT_INSUFFICIENT_SIZE, an insufficiency on a device memory means the
// memory's dimensions were mis-specified at allocation time, which is
// reported as GENERAL_FAILURE instead.
static V1_3::ErrorStatus updateDeviceMemories(
        V1_3::ErrorStatus status, const V1_3::Request& request,
        const std::vector<std::shared_ptr<HalManagedBuffer>>& bufferWrappers,
        const hardware::hidl_vec<V1_2::OutputShape>& outputShapes) {
    if (status == V1_3::ErrorStatus::NONE) {
        // First pass: update the dimensions of every token-backed output pool.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!bufferWrappers[poolIndex]->updateDimensions(outputShapes[i].dimensions)) {
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
        // Second pass: only after every dimension update succeeded, mark the
        // buffers initialized — a partial failure above must not leave any
        // buffer falsely marked as initialized.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    } else if (status == V1_3::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE) {
        // If CpuExecutor reports OUTPUT_INSUFFICIENT_SIZE on a device memory, this is because the
        // dimensions of the device memory are incorrectly specified. The driver should return
        // GENERAL_FAILURE instead in this case.
        for (uint32_t i = 0; i < request.outputs.size(); i++) {
            const uint32_t poolIndex = request.outputs[i].location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                if (!outputShapes[i].isSufficient) {
                    LOG(ERROR) << "Invalid dimensions for output " << i
                               << ": actual shape = " << toString(outputShapes[i].dimensions);
                    return V1_3::ErrorStatus::GENERAL_FAILURE;
                }
            }
        }
    }
    return V1_3::ErrorStatus::NONE;
}
458 
// Runs one inference on the CpuExecutor and delivers the result through
// `callback`. Invoked on a detached worker thread by executeBase().
//
// `driverStart` is the time executeBase() began processing the call; it
// anchors the timeInDriver measurement when `measure` is YES.
template <typename T_IExecutionCallback>
void asyncExecute(const V1_3::Request& request, V1_2::MeasureTiming measure, TimePoint driverStart,
                  const V1_3::Model& model, const SampleDriver& driver,
                  const SamplePreparedModel* preparedModel,
                  const std::vector<RunTimePoolInfo>& poolInfos, const OptionalTimePoint& deadline,
                  const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                  const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                 "SampleDriver::asyncExecute");

    // Map all request pools (plain hidl memory and device-memory tokens).
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, poolStatus, {}, kNoTiming);
        return;
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::asyncExecute");
    CpuExecutor executor = driver.getExecutor();
    // Propagate the optional WHILE-loop timeout and execution deadline to the
    // executor before running.
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    TimePoint driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata (dimensions / initialized flags) based on
    // the execution outcome; a metadata failure overrides the execution status.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        notify(callback, updateStatus, {}, kNoTiming);
        return;
    }

    // Timing is only reported for measured, fully successful executions;
    // otherwise kNoTiming (UINT64_MAX fields) is sent.
    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "SampleDriver::asyncExecute timing = " << toString(timing);
        notify(callback, executionStatus, outputShapes, timing);
    } else {
        notify(callback, executionStatus, outputShapes, kNoTiming);
    }
}
514 
// Common validation/dispatch for the execute*() entry points. Validates the
// callback, request, and deadline, then hands the real work to asyncExecute()
// on a detached thread. The returned status reflects only launch success; the
// execution result itself is delivered through `callback`.
template <typename T_IExecutionCallback>
V1_3::ErrorStatus executeBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                              const V1_3::Model& model, const SampleDriver& driver,
                              const SamplePreparedModel* preparedModel,
                              const std::vector<RunTimePoolInfo>& poolInfos,
                              const V1_3::OptionalTimePoint& halDeadline,
                              const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
                              const sp<T_IExecutionCallback>& callback) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION, "SampleDriver::executeBase");
    VLOG(DRIVER) << "executeBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // Capture the start time first so timeInDriver includes validation work.
    TimePoint driverStart;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    if (callback.get() == nullptr) {
        LOG(ERROR) << "invalid callback passed to executeBase";
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    if (!validateRequest(request, model)) {
        notify(callback, V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming);
        return V1_3::ErrorStatus::INVALID_ARGUMENT;
    }
    // NOTE(review): .value() assumes halDeadline always converts cleanly; a
    // malformed deadline would abort here rather than return an error —
    // confirm upstream validation rules that out.
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        // An already-expired deadline is reported through the callback, but
        // the launch itself is still considered successful (status NONE).
        notify(callback, V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming);
        return V1_3::ErrorStatus::NONE;
    }

    // This thread is intentionally detached because the sample driver service
    // is expected to live forever.
    // NOTE(review): model, driver, and poolInfos are captured by reference;
    // this is only safe if the prepared model and driver outlive every
    // in-flight execution — presumably guaranteed by the service lifetime.
    std::thread([&model, &driver, preparedModel, &poolInfos, request, measure, driverStart,
                 deadline, loopTimeoutDuration, callback] {
        asyncExecute(request, measure, driverStart, model, driver, preparedModel, poolInfos,
                     deadline, loopTimeoutDuration, callback);
    }).detach();

    return V1_3::ErrorStatus::NONE;
}
553 
execute(const V1_0::Request & request,const sp<V1_0::IExecutionCallback> & callback)554 hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute(
555         const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) {
556     const V1_3::ErrorStatus status =
557             executeBase(convertToV1_3(request), V1_2::MeasureTiming::NO, mModel, *mDriver, this,
558                         mPoolInfos, {}, {}, callback);
559     return convertToV1_0(status);
560 }
561 
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measure,const sp<V1_2::IExecutionCallback> & callback)562 hardware::Return<V1_0::ErrorStatus> SamplePreparedModel::execute_1_2(
563         const V1_0::Request& request, V1_2::MeasureTiming measure,
564         const sp<V1_2::IExecutionCallback>& callback) {
565     const V1_3::ErrorStatus status = executeBase(convertToV1_3(request), measure, mModel, *mDriver,
566                                                  this, mPoolInfos, {}, {}, callback);
567     return convertToV1_0(status);
568 }
569 
execute_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,const sp<V1_3::IExecutionCallback> & callback)570 hardware::Return<V1_3::ErrorStatus> SamplePreparedModel::execute_1_3(
571         const V1_3::Request& request, V1_2::MeasureTiming measure,
572         const V1_3::OptionalTimePoint& deadline,
573         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
574         const sp<V1_3::IExecutionCallback>& callback) {
575     return executeBase(request, measure, mModel, *mDriver, this, mPoolInfos, deadline,
576                        loopTimeoutDuration, callback);
577 }
578 
// Synchronous counterpart of executeBase: validates the request, runs the
// model on the CPU executor in the calling thread, and returns the error
// status, output shapes, and (optionally) measured timing directly instead
// of delivering them through a callback.
static std::tuple<V1_3::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing>
executeSynchronouslyBase(const V1_3::Request& request, V1_2::MeasureTiming measure,
                         const V1_3::Model& model, const SampleDriver& driver,
                         const SamplePreparedModel* preparedModel,
                         const std::vector<RunTimePoolInfo>& poolInfos,
                         const V1_3::OptionalTimePoint& halDeadline,
                         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SampleDriver::executeSynchronouslyBase");
    VLOG(DRIVER) << "executeSynchronouslyBase(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // driverStart..driverEnd spans the whole call; deviceStart..deviceEnd
    // spans only executor.run(). Both are only captured when measure == YES.
    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    // Reject malformed requests before mapping any memory pools.
    if (!validateRequest(request, model)) {
        return {V1_3::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
    }
    // NOTE(review): .value() assumes the OptionalTimePoint conversion always
    // succeeds; an invalid deadline would abort here -- confirm it is
    // validated upstream.
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        return {V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {}, kNoTiming};
    }

    // Map the request's pools (including driver-managed device memories) into
    // RunTimePoolInfos usable by the CpuExecutor.
    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SampleDriver::executeSynchronouslyBase");
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, driver, preparedModel);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        return {poolStatus, {}, kNoTiming};
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SampleDriver::executeSynchronouslyBase");
    CpuExecutor executor = driver.getExecutor();
    // Propagate the optional WHILE-loop timeout and deadline to the executor
    // before running.
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (deadline.has_value()) {
        executor.setDeadline(*deadline);
    }
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(model), uncheckedConvert(request), poolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    hardware::hidl_vec<V1_2::OutputShape> outputShapes = convertToV1_2(executor.getOutputShapes());

    // Update device memory metadata.
    const V1_3::ErrorStatus updateStatus =
            updateDeviceMemories(executionStatus, request, bufferWrappers, outputShapes);
    if (updateStatus != V1_3::ErrorStatus::NONE) {
        return {updateStatus, {}, kNoTiming};
    }

    // Timing is only reported for successful, measured executions; all other
    // paths return kNoTiming (UINT64_MAX sentinels).
    if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_3::ErrorStatus::NONE) {
        driverEnd = Clock::now();
        V1_2::Timing timing = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        VLOG(DRIVER) << "executeSynchronouslyBase timing = " << toString(timing);
        return {executionStatus, std::move(outputShapes), timing};
    }
    return {executionStatus, std::move(outputShapes), kNoTiming};
}
644 
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measure,executeSynchronously_cb cb)645 hardware::Return<void> SamplePreparedModel::executeSynchronously(const V1_0::Request& request,
646                                                                  V1_2::MeasureTiming measure,
647                                                                  executeSynchronously_cb cb) {
648     auto [status, outputShapes, timing] = executeSynchronouslyBase(
649             convertToV1_3(request), measure, mModel, *mDriver, this, mPoolInfos, {}, {});
650     cb(convertToV1_0(status), std::move(outputShapes), timing);
651     return hardware::Void();
652 }
653 
executeSynchronously_1_3(const V1_3::Request & request,V1_2::MeasureTiming measure,const V1_3::OptionalTimePoint & deadline,const V1_3::OptionalTimeoutDuration & loopTimeoutDuration,executeSynchronously_1_3_cb cb)654 hardware::Return<void> SamplePreparedModel::executeSynchronously_1_3(
655         const V1_3::Request& request, V1_2::MeasureTiming measure,
656         const V1_3::OptionalTimePoint& deadline,
657         const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, executeSynchronously_1_3_cb cb) {
658     auto [status, outputShapes, timing] = executeSynchronouslyBase(
659             request, measure, mModel, *mDriver, this, mPoolInfos, deadline, loopTimeoutDuration);
660     cb(status, std::move(outputShapes), timing);
661     return hardware::Void();
662 }
663 
664 // The sample driver will finish the execution and then return.
hardware::Return<void> SamplePreparedModel::executeFenced(
        const V1_3::Request& request, const hardware::hidl_vec<hardware::hidl_handle>& waitFor,
        V1_2::MeasureTiming measure, const V1_3::OptionalTimePoint& halDeadline,
        const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
        const V1_3::OptionalTimeoutDuration& duration, executeFenced_cb cb) {
    NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                 "SamplePreparedModel::executeFenced");
    VLOG(DRIVER) << "executeFenced(" << SHOW_IF_DEBUG(toString(request)) << ")";

    // driverStart..driverEnd spans the whole call (including fence waits);
    // deviceStart..deviceEnd spans only executor.run().
    TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
    if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

    // Fenced execution requires fully specified output shapes, hence
    // allowUnspecifiedOutput=false (stricter than the other execute paths).
    if (!validateRequest(request, mModel, /*allowUnspecifiedOutput=*/false)) {
        cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }
    // NOTE(review): .value() assumes the OptionalTimePoint conversion always
    // succeeds; an invalid deadline would abort here -- confirm it is
    // validated upstream.
    const auto deadline = convert(halDeadline).value();
    if (hasDeadlinePassed(deadline)) {
        cb(V1_3::ErrorStatus::MISSED_DEADLINE_PERSISTENT, {} , nullptr);
        return hardware::Void();
    }

    // Wait for the dependent events to signal
    for (const auto& fenceHandle : waitFor) {
        if (!fenceHandle.getNativeHandle()) {
            cb(V1_3::ErrorStatus::INVALID_ARGUMENT, hardware::hidl_handle(nullptr), nullptr);
            return hardware::Void();
        }
        // Blocks indefinitely (timeout -1) until the sync fence signals.
        int syncFenceFd = fenceHandle.getNativeHandle()->data[0];
        if (syncWait(syncFenceFd, -1) != FenceState::SIGNALED) {
            LOG(ERROR) << "syncWait failed";
            cb(V1_3::ErrorStatus::GENERAL_FAILURE, hardware::hidl_handle(nullptr), nullptr);
            return hardware::Void();
        }
    }

    // Update deadline if the timeout duration is closer than the deadline.
    // 'duration' is measured from this point (after the fences signaled).
    auto closestDeadline = deadline;
    if (duration.getDiscriminator() != V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        const auto timeoutDurationDeadline = makeDeadline(duration.nanoseconds());
        if (!closestDeadline.has_value() || *closestDeadline > timeoutDurationDeadline) {
            closestDeadline = timeoutDurationDeadline;
        }
    }

    // Second timing origin: driver time excluding the fence waits, reported
    // separately as timingAfterFence below.
    TimePoint driverStartAfterFence;
    if (measure == V1_2::MeasureTiming::YES) driverStartAfterFence = Clock::now();

    // Map the request's pools (including driver-managed device memories) into
    // RunTimePoolInfos usable by the CpuExecutor.
    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_INPUTS_AND_OUTPUTS,
                        "SamplePreparedModel::executeFenced");
    const auto [poolStatus, requestPoolInfos, bufferWrappers] =
            createRunTimePoolInfos(request, *mDriver, this);
    if (poolStatus != V1_3::ErrorStatus::NONE) {
        cb(poolStatus, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }

    NNTRACE_FULL_SWITCH(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                        "SamplePreparedModel::executeFenced");
    CpuExecutor executor = mDriver->getExecutor();
    if (loopTimeoutDuration.getDiscriminator() !=
        V1_3::OptionalTimeoutDuration::hidl_discriminator::none) {
        executor.setLoopTimeout(loopTimeoutDuration.nanoseconds());
    }
    if (closestDeadline.has_value()) {
        executor.setDeadline(*closestDeadline);
    }
    if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
    int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(request), mPoolInfos,
                         requestPoolInfos);
    if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
    VLOG(DRIVER) << "executor.run returned " << n;
    V1_3::ErrorStatus executionStatus = convertResultCodeToHalErrorStatus(n);
    if (executionStatus != V1_3::ErrorStatus::NONE) {
        cb(executionStatus, hardware::hidl_handle(nullptr), nullptr);
        return hardware::Void();
    }

    // Set output memories to the initialized state.
    // NOTE(review): this condition is always true here -- the error path
    // returned just above.
    if (executionStatus == V1_3::ErrorStatus::NONE) {
        for (const auto& output : request.outputs) {
            const uint32_t poolIndex = output.location.poolIndex;
            const auto& pool = request.pools[poolIndex];
            // Only driver-managed (token) pools track an initialized state.
            if (pool.getDiscriminator() == V1_3::Request::MemoryPool::hidl_discriminator::token) {
                bufferWrappers[poolIndex]->setInitialized(true);
            }
        }
    }

    // Report two timings: since launch (includes fence waits) and after the
    // fences signaled. Both default to the UINT64_MAX "no timing" sentinel.
    V1_2::Timing timingSinceLaunch = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
    V1_2::Timing timingAfterFence = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
    if (measure == V1_2::MeasureTiming::YES) {
        driverEnd = Clock::now();
        timingSinceLaunch = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
        timingAfterFence = {
                .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStartAfterFence))};
        VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << toString(timingSinceLaunch);
        VLOG(DRIVER) << "executeFenced timingAfterFence = " << toString(timingAfterFence);
    }
    // Execution already completed, so no sync fence is returned (null handle);
    // timing is delivered through the fenced-execution callback object.
    sp<SampleFencedExecutionCallback> fencedExecutionCallback =
            new SampleFencedExecutionCallback(timingSinceLaunch, timingAfterFence, executionStatus);
    cb(executionStatus, hardware::hidl_handle(nullptr), fencedExecutionCallback);
    return hardware::Void();
}
772 
773 // BurstExecutorWithCache maps hidl_memory when it is first seen, and preserves
774 // the mapping until either (1) the memory is freed in the runtime, or (2) the
775 // burst object is destroyed. This allows for subsequent executions operating on
776 // pools that have been used before to reuse the mapping instead of mapping and
777 // unmapping the memory on each execution.
class BurstExecutorWithCache : public ExecutionBurstServer::IBurstExecutorWithCache {
   public:
    // The model, driver pointer, and model pool infos are copied/retained for
    // the lifetime of the burst object.
    BurstExecutorWithCache(const V1_3::Model& model, const SampleDriver* driver,
                           const std::vector<RunTimePoolInfo>& poolInfos)
        : mModel(model), mDriver(driver), mModelPoolInfos(poolInfos) {}

    // Returns true only if 'slot' has a cache entry AND the mapping succeeded
    // (createFromMemory may fail, leaving an empty optional in the cache).
    bool isCacheEntryPresent(int32_t slot) const override {
        const auto it = mMemoryCache.find(slot);
        return (it != mMemoryCache.end()) && it->second.has_value();
    }

    // Maps 'memory' and caches the mapping under 'slot', replacing any
    // previous entry for that slot.
    void addCacheEntry(const hardware::hidl_memory& memory, int32_t slot) override {
        mMemoryCache[slot] = RunTimePoolInfo::createFromMemory(uncheckedConvert(memory));
    }

    // Drops the cached mapping for 'slot' (unmaps when the last reference
    // to the RunTimePoolInfo goes away).
    void removeCacheEntry(int32_t slot) override { mMemoryCache.erase(slot); }

    // Executes 'request', whose memory pools are referenced indirectly via
    // 'slots' into the cache. Returns status, output shapes, and timing
    // (timing is only measured when 'measure' is YES).
    std::tuple<V1_0::ErrorStatus, hardware::hidl_vec<V1_2::OutputShape>, V1_2::Timing> execute(
            const V1_0::Request& request, const std::vector<int32_t>& slots,
            V1_2::MeasureTiming measure) override {
        NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
                     "BurstExecutorWithCache::execute");

        TimePoint driverStart, driverEnd, deviceStart, deviceEnd;
        if (measure == V1_2::MeasureTiming::YES) driverStart = Clock::now();

        // ensure all relevant pools are valid
        if (!std::all_of(slots.begin(), slots.end(),
                         [this](int32_t slot) { return isCacheEntryPresent(slot); })) {
            return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
        }

        // finish the request object (for validation)
        hardware::hidl_vec<V1_3::Request::MemoryPool> pools(slots.size());
        std::transform(slots.begin(), slots.end(), pools.begin(), [this](int32_t slot) {
            V1_3::Request::MemoryPool pool;
            pool.hidlMemory(convertToV1_0(mMemoryCache[slot]->getMemory()));
            return pool;
        });
        V1_3::Request fullRequest = {.inputs = request.inputs, .outputs = request.outputs};
        fullRequest.pools = std::move(pools);

        // validate request object against the model
        if (!validateRequest(fullRequest, mModel)) {
            return {V1_0::ErrorStatus::INVALID_ARGUMENT, {}, kNoTiming};
        }

        // select relevant entries from cache
        std::vector<RunTimePoolInfo> requestPoolInfos;
        requestPoolInfos.reserve(slots.size());
        std::transform(slots.begin(), slots.end(), std::back_inserter(requestPoolInfos),
                       [this](int32_t slot) { return *mMemoryCache[slot]; });

        // execution
        // Configuring the loop timeout duration is not supported. This is OK
        // because burst does not support HAL 1.3 and hence does not support
        // WHILE loops.
        CpuExecutor executor = mDriver->getExecutor();
        if (measure == V1_2::MeasureTiming::YES) deviceStart = Clock::now();
        int n = executor.run(uncheckedConvert(mModel), uncheckedConvert(fullRequest),
                             mModelPoolInfos, requestPoolInfos);
        if (measure == V1_2::MeasureTiming::YES) deviceEnd = Clock::now();
        VLOG(DRIVER) << "executor.run returned " << n;
        V1_0::ErrorStatus executionStatus = convertToV1_0(convertResultCodeToHalErrorStatus(n));
        hardware::hidl_vec<V1_2::OutputShape> outputShapes =
                convertToV1_2(executor.getOutputShapes());
        // Timing is only reported for successful, measured executions.
        if (measure == V1_2::MeasureTiming::YES && executionStatus == V1_0::ErrorStatus::NONE) {
            driverEnd = Clock::now();
            V1_2::Timing timing = {
                    .timeOnDevice = uint64_t(microsecondsDuration(deviceEnd, deviceStart)),
                    .timeInDriver = uint64_t(microsecondsDuration(driverEnd, driverStart))};
            VLOG(DRIVER) << "BurstExecutorWithCache::execute timing = " << toString(timing);
            return std::make_tuple(executionStatus, outputShapes, timing);
        } else {
            return std::make_tuple(executionStatus, outputShapes, kNoTiming);
        }
    }

   private:
    const V1_3::Model mModel;
    const SampleDriver* const mDriver;
    const std::vector<RunTimePoolInfo> mModelPoolInfos;
    // Slot -> mapped pool. An entry with an empty optional means the mapping
    // attempt failed; isCacheEntryPresent() treats it as absent.
    std::map<int32_t, std::optional<RunTimePoolInfo>> mMemoryCache;  // cached requestPoolInfos
};
862 
// Amount of time the ExecutionBurstServer should spend polling the FMQ for
// available data before it falls back to waiting on the futex.
static std::chrono::microseconds getPollingTimeWindow() {
    constexpr int32_t kDefaultPollingTimeWindow = 50;
#ifdef NN_DEBUGGABLE
    // On debuggable builds the window may be overridden (but never made
    // negative) through a system property.
    constexpr int32_t kMinPollingTimeWindow = 0;
    return std::chrono::microseconds{
            base::GetIntProperty("debug.nn.sample-driver-burst-polling-window",
                                 kDefaultPollingTimeWindow, kMinPollingTimeWindow)};
#else
    return std::chrono::microseconds{kDefaultPollingTimeWindow};
#endif  // NN_DEBUGGABLE
}
878 
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,configureExecutionBurst_cb cb)879 hardware::Return<void> SamplePreparedModel::configureExecutionBurst(
880         const sp<V1_2::IBurstCallback>& callback,
881         const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
882         const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
883         configureExecutionBurst_cb cb) {
884     NNTRACE_FULL(NNTRACE_LAYER_DRIVER, NNTRACE_PHASE_EXECUTION,
885                  "SampleDriver::configureExecutionBurst");
886 
887     const bool preferPowerOverLatency = (kPreference == V1_1::ExecutionPreference::LOW_POWER);
888     const auto pollingTimeWindow =
889             (preferPowerOverLatency ? std::chrono::microseconds{0} : getPollingTimeWindow());
890 
891     // Alternatively, the burst could be configured via:
892     // const sp<V1_2::IBurstContext> burst =
893     //         ExecutionBurstServer::create(callback, requestChannel,
894     //                                      resultChannel, this,
895     //                                      pollingTimeWindow);
896     //
897     // However, this alternative representation does not include a memory map
898     // caching optimization, and adds overhead.
899     const std::shared_ptr<BurstExecutorWithCache> executorWithCache =
900             std::make_shared<BurstExecutorWithCache>(mModel, mDriver, mPoolInfos);
901     const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(
902             callback, requestChannel, resultChannel, executorWithCache, pollingTimeWindow);
903 
904     if (burst == nullptr) {
905         cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
906     } else {
907         cb(V1_0::ErrorStatus::NONE, burst);
908     }
909 
910     return hardware::Void();
911 }
912 
913 }  // namespace sample_driver
914 }  // namespace nn
915 }  // namespace android
916