1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "ExecutionBurstController"
18 
19 #include "ExecutionBurstController.h"
20 
21 #include <android-base/logging.h>
22 
#include <algorithm>
#include <chrono>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
32 
33 #include "HalInterfaces.h"
34 #include "Tracing.h"
35 #include "Utils.h"
36 
37 namespace android::nn {
38 namespace {
39 
40 using V1_2::FmqRequestDatum;
41 using V1_2::FmqResultDatum;
42 using V1_2::IBurstCallback;
43 using V1_2::IBurstContext;
44 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
45 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
46 
47 constexpr V1_2::Timing kNoTiming12 = {std::numeric_limits<uint64_t>::max(),
48                                       std::numeric_limits<uint64_t>::max()};
49 
50 class BurstContextDeathHandler : public hardware::hidl_death_recipient {
51    public:
52     using Callback = std::function<void()>;
53 
BurstContextDeathHandler(const Callback & onDeathCallback)54     BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
55         CHECK(onDeathCallback != nullptr);
56     }
57 
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)58     void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
59         LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
60         mOnDeathCallback();
61     }
62 
63    private:
64     const Callback mOnDeathCallback;
65 };
66 
67 }  // anonymous namespace
68 
69 // serialize a request into a packet
serialize(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)70 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, V1_2::MeasureTiming measure,
71                                        const std::vector<int32_t>& slots) {
72     // count how many elements need to be sent for a request
73     size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
74     for (const auto& input : request.inputs) {
75         count += input.dimensions.size();
76     }
77     for (const auto& output : request.outputs) {
78         count += output.dimensions.size();
79     }
80 
81     // create buffer to temporarily store elements
82     std::vector<FmqRequestDatum> data;
83     data.reserve(count);
84 
85     // package packetInfo
86     {
87         FmqRequestDatum datum;
88         datum.packetInformation(
89                 {/*.packetSize=*/static_cast<uint32_t>(count),
90                  /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
91                  /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
92                  /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
93         data.push_back(datum);
94     }
95 
96     // package input data
97     for (const auto& input : request.inputs) {
98         // package operand information
99         FmqRequestDatum datum;
100         datum.inputOperandInformation(
101                 {/*.hasNoValue=*/input.hasNoValue,
102                  /*.location=*/input.location,
103                  /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
104         data.push_back(datum);
105 
106         // package operand dimensions
107         for (uint32_t dimension : input.dimensions) {
108             FmqRequestDatum datum;
109             datum.inputOperandDimensionValue(dimension);
110             data.push_back(datum);
111         }
112     }
113 
114     // package output data
115     for (const auto& output : request.outputs) {
116         // package operand information
117         FmqRequestDatum datum;
118         datum.outputOperandInformation(
119                 {/*.hasNoValue=*/output.hasNoValue,
120                  /*.location=*/output.location,
121                  /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
122         data.push_back(datum);
123 
124         // package operand dimensions
125         for (uint32_t dimension : output.dimensions) {
126             FmqRequestDatum datum;
127             datum.outputOperandDimensionValue(dimension);
128             data.push_back(datum);
129         }
130     }
131 
132     // package pool identifier
133     for (int32_t slot : slots) {
134         FmqRequestDatum datum;
135         datum.poolIdentifier(slot);
136         data.push_back(datum);
137     }
138 
139     // package measureTiming
140     {
141         FmqRequestDatum datum;
142         datum.measureTiming(measure);
143         data.push_back(datum);
144     }
145 
146     // return packet
147     return data;
148 }
149 
150 // deserialize a packet into the result
151 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
deserialize(const std::vector<FmqResultDatum> & data)152 deserialize(const std::vector<FmqResultDatum>& data) {
153     using discriminator = FmqResultDatum::hidl_discriminator;
154 
155     std::vector<V1_2::OutputShape> outputShapes;
156     size_t index = 0;
157 
158     // validate packet information
159     if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
160         LOG(ERROR) << "FMQ Result packet ill-formed";
161         return std::nullopt;
162     }
163 
164     // unpackage packet information
165     const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
166     index++;
167     const uint32_t packetSize = packetInfo.packetSize;
168     const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
169     const uint32_t numberOfOperands = packetInfo.numberOfOperands;
170 
171     // verify packet size
172     if (data.size() != packetSize) {
173         LOG(ERROR) << "FMQ Result packet ill-formed";
174         return std::nullopt;
175     }
176 
177     // unpackage operands
178     for (size_t operand = 0; operand < numberOfOperands; ++operand) {
179         // validate operand information
180         if (data[index].getDiscriminator() != discriminator::operandInformation) {
181             LOG(ERROR) << "FMQ Result packet ill-formed";
182             return std::nullopt;
183         }
184 
185         // unpackage operand information
186         const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
187         index++;
188         const bool isSufficient = operandInfo.isSufficient;
189         const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
190 
191         // unpackage operand dimensions
192         std::vector<uint32_t> dimensions;
193         dimensions.reserve(numberOfDimensions);
194         for (size_t i = 0; i < numberOfDimensions; ++i) {
195             // validate dimension
196             if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
197                 LOG(ERROR) << "FMQ Result packet ill-formed";
198                 return std::nullopt;
199             }
200 
201             // unpackage dimension
202             const uint32_t dimension = data[index].operandDimensionValue();
203             index++;
204 
205             // store result
206             dimensions.push_back(dimension);
207         }
208 
209         // store result
210         outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
211     }
212 
213     // validate execution timing
214     if (data[index].getDiscriminator() != discriminator::executionTiming) {
215         LOG(ERROR) << "FMQ Result packet ill-formed";
216         return std::nullopt;
217     }
218 
219     // unpackage execution timing
220     const V1_2::Timing timing = data[index].executionTiming();
221     index++;
222 
223     // validate packet information
224     if (index != packetSize) {
225         LOG(ERROR) << "FMQ Result packet ill-formed";
226         return std::nullopt;
227     }
228 
229     // return result
230     return std::make_tuple(errorStatus, std::move(outputShapes), timing);
231 }
232 
legacyConvertResultCodeToErrorStatus(int resultCode)233 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
234     return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
235 }
236 
237 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)238 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
239     std::unique_ptr<FmqResultChannel> fmqResultChannel =
240             std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
241     if (!fmqResultChannel->isValid()) {
242         LOG(ERROR) << "Unable to create ResultChannelReceiver";
243         return {nullptr, nullptr};
244     }
245 
246     const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
247     return std::make_pair(
248             std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
249             descriptor);
250 }
251 
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)252 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
253                                              std::chrono::microseconds pollingTimeWindow)
254     : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
255 
256 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
getBlocking()257 ResultChannelReceiver::getBlocking() {
258     const auto packet = getPacketBlocking();
259     if (!packet) {
260         return std::nullopt;
261     }
262 
263     return deserialize(*packet);
264 }
265 
invalidate()266 void ResultChannelReceiver::invalidate() {
267     mValid = false;
268 
269     // force unblock
270     // ExecutionBurstController waits on a result packet after sending a
271     // request. If the driver containing ExecutionBurstServer crashes, the
272     // controller may be waiting on the futex. This force unblock wakes up any
273     // thread waiting on the futex.
274     // TODO: look for a different/better way to signal/notify the futex to
275     // wake up any thread waiting on it
276     FmqResultDatum datum;
277     datum.packetInformation({/*.packetSize=*/0,
278                              /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
279                              /*.numberOfOperands=*/0});
280     mFmqResultChannel->writeBlocking(&datum, 1);
281 }
282 
getPacketBlocking()283 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
284     if (!mValid) {
285         return std::nullopt;
286     }
287 
288     // First spend time polling if results are available in FMQ instead of
289     // waiting on the futex. Polling is more responsive (yielding lower
290     // latencies), but can take up more power, so only poll for a limited period
291     // of time.
292 
293     auto& getCurrentTime = std::chrono::high_resolution_clock::now;
294     const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
295 
296     while (getCurrentTime() < timeToStopPolling) {
297         // if class is being torn down, immediately return
298         if (!mValid.load(std::memory_order_relaxed)) {
299             return std::nullopt;
300         }
301 
302         // Check if data is available. If it is, immediately retrieve it and
303         // return.
304         const size_t available = mFmqResultChannel->availableToRead();
305         if (available > 0) {
306             std::vector<FmqResultDatum> packet(available);
307             const bool success = mFmqResultChannel->read(packet.data(), available);
308             if (!success) {
309                 LOG(ERROR) << "Error receiving packet";
310                 return std::nullopt;
311             }
312             return std::make_optional(std::move(packet));
313         }
314 
315         std::this_thread::yield();
316     }
317 
318     // If we get to this point, we either stopped polling because it was taking
319     // too long or polling was not allowed. Instead, perform a blocking call
320     // which uses a futex to save power.
321 
322     // wait for result packet and read first element of result packet
323     FmqResultDatum datum;
324     bool success = mFmqResultChannel->readBlocking(&datum, 1);
325 
326     // retrieve remaining elements
327     // NOTE: all of the data is already available at this point, so there's no
328     // need to do a blocking wait to wait for more data. This is known because
329     // in FMQ, all writes are published (made available) atomically. Currently,
330     // the producer always publishes the entire packet in one function call, so
331     // if the first element of the packet is available, the remaining elements
332     // are also available.
333     const size_t count = mFmqResultChannel->availableToRead();
334     std::vector<FmqResultDatum> packet(count + 1);
335     std::memcpy(&packet.front(), &datum, sizeof(datum));
336     success &= mFmqResultChannel->read(packet.data() + 1, count);
337 
338     if (!mValid) {
339         return std::nullopt;
340     }
341 
342     // ensure packet was successfully received
343     if (!success) {
344         LOG(ERROR) << "Error receiving packet";
345         return std::nullopt;
346     }
347 
348     return std::make_optional(std::move(packet));
349 }
350 
351 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)352 RequestChannelSender::create(size_t channelLength) {
353     std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
354             std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
355     if (!fmqRequestChannel->isValid()) {
356         LOG(ERROR) << "Unable to create RequestChannelSender";
357         return {nullptr, nullptr};
358     }
359 
360     const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
361     return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
362                           descriptor);
363 }
364 
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)365 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
366     : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
367 
send(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)368 bool RequestChannelSender::send(const V1_0::Request& request, V1_2::MeasureTiming measure,
369                                 const std::vector<int32_t>& slots) {
370     const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
371     return sendPacket(serialized);
372 }
373 
sendPacket(const std::vector<FmqRequestDatum> & packet)374 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
375     if (!mValid) {
376         return false;
377     }
378 
379     if (packet.size() > mFmqRequestChannel->availableToWrite()) {
380         LOG(ERROR)
381                 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
382         return false;
383     }
384 
385     // Always send the packet with "blocking" because this signals the futex and
386     // unblocks the consumer if it is waiting on the futex.
387     return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
388 }
389 
invalidate()390 void RequestChannelSender::invalidate() {
391     mValid = false;
392 }
393 
// IBurstCallback::getMemories: resolves each requested slot to its cached
// memory. Calls cb with INVALID_ARGUMENT and an empty vector if any slot is
// out of range or maps to an invalid memory; otherwise calls cb with NONE and
// the memories, in slot order. Runs under mMutex, including the cb call.
hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
        const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) {
    std::lock_guard<std::mutex> guard(mMutex);

    // get all memories; negative or out-of-range slots yield an empty
    // (invalid) hidl_memory. The sign is checked explicitly rather than
    // relying on the implicit int32_t -> size_t conversion in the comparison.
    hardware::hidl_vec<hardware::hidl_memory> memories(slots.size());
    std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
        const bool inRange = slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size();
        return inRange ? mMemoryCache[slot] : hardware::hidl_memory{};
    });

    // ensure all memories are valid
    if (!std::all_of(memories.begin(), memories.end(),
                     [](const hardware::hidl_memory& memory) { return memory.valid(); })) {
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }

    // return successful
    cb(V1_0::ErrorStatus::NONE, std::move(memories));
    return hardware::Void();
}
415 
getSlots(const hardware::hidl_vec<hardware::hidl_memory> & memories,const std::vector<intptr_t> & keys)416 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
417         const hardware::hidl_vec<hardware::hidl_memory>& memories,
418         const std::vector<intptr_t>& keys) {
419     std::lock_guard<std::mutex> guard(mMutex);
420 
421     // retrieve (or bind) all slots corresponding to memories
422     std::vector<int32_t> slots;
423     slots.reserve(memories.size());
424     for (size_t i = 0; i < memories.size(); ++i) {
425         slots.push_back(getSlotLocked(memories[i], keys[i]));
426     }
427     return slots;
428 }
429 
freeMemory(intptr_t key)430 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
431         intptr_t key) {
432     std::lock_guard<std::mutex> guard(mMutex);
433 
434     auto iter = mMemoryIdToSlot.find(key);
435     if (iter == mMemoryIdToSlot.end()) {
436         return {false, 0};
437     }
438     const int32_t slot = iter->second;
439     mMemoryIdToSlot.erase(key);
440     mMemoryCache[slot] = {};
441     mFreeSlots.push(slot);
442     return {true, slot};
443 }
444 
getSlotLocked(const hardware::hidl_memory & memory,intptr_t key)445 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(
446         const hardware::hidl_memory& memory, intptr_t key) {
447     auto iter = mMemoryIdToSlot.find(key);
448     if (iter == mMemoryIdToSlot.end()) {
449         const int32_t slot = allocateSlotLocked();
450         mMemoryIdToSlot[key] = slot;
451         mMemoryCache[slot] = memory;
452         return slot;
453     } else {
454         const int32_t slot = iter->second;
455         return slot;
456     }
457 }
458 
allocateSlotLocked()459 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
460     constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
461 
462     // if there is a free slot, use it
463     if (mFreeSlots.size() > 0) {
464         const int32_t slot = mFreeSlots.top();
465         mFreeSlots.pop();
466         return slot;
467     }
468 
469     // otherwise use a slot for the first time
470     CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
471     const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
472     mMemoryCache.emplace_back();
473 
474     return slot;
475 }
476 
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)477 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
478         const sp<V1_2::IPreparedModel>& preparedModel,
479         std::chrono::microseconds pollingTimeWindow) {
480     // check inputs
481     if (preparedModel == nullptr) {
482         LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
483         return nullptr;
484     }
485 
486     // create callback object
487     sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
488 
489     // create FMQ objects
490     auto [requestChannelSenderTemp, requestChannelDescriptor] =
491             RequestChannelSender::create(kExecutionBurstChannelLength);
492     auto [resultChannelReceiverTemp, resultChannelDescriptor] =
493             ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
494     std::shared_ptr<RequestChannelSender> requestChannelSender =
495             std::move(requestChannelSenderTemp);
496     std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
497             std::move(resultChannelReceiverTemp);
498 
499     // check FMQ objects
500     if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
501         !resultChannelDescriptor) {
502         LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
503         return nullptr;
504     }
505 
506     // configure burst
507     V1_0::ErrorStatus errorStatus;
508     sp<IBurstContext> burstContext;
509     const hardware::Return<void> ret = preparedModel->configureExecutionBurst(
510             callback, *requestChannelDescriptor, *resultChannelDescriptor,
511             [&errorStatus, &burstContext](V1_0::ErrorStatus status,
512                                           const sp<IBurstContext>& context) {
513                 errorStatus = status;
514                 burstContext = context;
515             });
516 
517     // check burst
518     if (!ret.isOk()) {
519         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
520                    << ret.description();
521         return nullptr;
522     }
523     if (errorStatus != V1_0::ErrorStatus::NONE) {
524         LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
525                    << toString(errorStatus);
526         return nullptr;
527     }
528     if (burstContext == nullptr) {
529         LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
530         return nullptr;
531     }
532 
533     // create death handler object
534     BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
535                                                           resultChannelReceiver] {
536         requestChannelSender->invalidate();
537         resultChannelReceiver->invalidate();
538     };
539     const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
540 
541     // linkToDeath registers a callback that will be invoked on service death to
542     // proactively handle service crashes. If the linkToDeath call fails,
543     // asynchronous calls are susceptible to hangs if the service crashes before
544     // providing the response.
545     const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
546     if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
547         LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
548                       "for the IBurstContext object.";
549         return nullptr;
550     }
551 
552     // make and return controller
553     return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
554                                                       burstContext, callback, deathHandler);
555 }
556 
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hardware::hidl_death_recipient> & deathHandler)557 ExecutionBurstController::ExecutionBurstController(
558         const std::shared_ptr<RequestChannelSender>& requestChannelSender,
559         const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
560         const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
561         const sp<hardware::hidl_death_recipient>& deathHandler)
562     : mRequestChannelSender(requestChannelSender),
563       mResultChannelReceiver(resultChannelReceiver),
564       mBurstContext(burstContext),
565       mMemoryCache(callback),
566       mDeathHandler(deathHandler) {}
567 
~ExecutionBurstController()568 ExecutionBurstController::~ExecutionBurstController() {
569     // It is safe to ignore any errors resulting from this unlinkToDeath call
570     // because the ExecutionBurstController object is already being destroyed
571     // and its underlying IBurstContext object is no longer being used by the NN
572     // runtime.
573     if (mDeathHandler) {
574         mBurstContext->unlinkToDeath(mDeathHandler).isOk();
575     }
576 }
577 
getExecutionResult(V1_0::ErrorStatus status,std::vector<V1_2::OutputShape> outputShapes,V1_2::Timing timing,bool fallback)578 static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult(
579         V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing,
580         bool fallback) {
581     auto [n, checkedOutputShapes, checkedTiming] =
582             getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
583     return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback};
584 }
585 
586 std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool>
compute(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<intptr_t> & memoryIds)587 ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure,
588                                   const std::vector<intptr_t>& memoryIds) {
589     // This is the first point when we know an execution is occurring, so begin
590     // to collect systraces. Note that the first point we can begin collecting
591     // systraces in ExecutionBurstServer is when the RequestChannelReceiver
592     // realizes there is data in the FMQ, so ExecutionBurstServer collects
593     // systraces at different points in the code.
594     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
595 
596     std::lock_guard<std::mutex> guard(mMutex);
597 
598     // send request packet
599     const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
600     const bool success = mRequestChannelSender->send(request, measure, slots);
601     if (!success) {
602         LOG(ERROR) << "Error sending FMQ packet";
603         // only use fallback execution path if the packet could not be sent
604         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
605                                   /*fallback=*/true);
606     }
607 
608     // get result packet
609     const auto result = mResultChannelReceiver->getBlocking();
610     if (!result) {
611         LOG(ERROR) << "Error retrieving FMQ packet";
612         // only use fallback execution path if the packet could not be sent
613         return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
614                                   /*fallback=*/false);
615     }
616 
617     // unpack results and return (only use fallback execution path if the
618     // packet could not be sent)
619     auto [status, outputShapes, timing] = std::move(*result);
620     return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
621 }
622 
freeMemory(intptr_t key)623 void ExecutionBurstController::freeMemory(intptr_t key) {
624     std::lock_guard<std::mutex> guard(mMutex);
625 
626     bool valid;
627     int32_t slot;
628     std::tie(valid, slot) = mMemoryCache->freeMemory(key);
629     if (valid) {
630         mBurstContext->freeMemory(slot).isOk();
631     }
632 }
633 
634 }  // namespace android::nn
635