1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "ExecutionBurstController"
18
19 #include "ExecutionBurstController.h"
20
21 #include <android-base/logging.h>
22
23 #include <algorithm>
24 #include <cstring>
25 #include <limits>
26 #include <memory>
27 #include <string>
28 #include <thread>
29 #include <tuple>
30 #include <utility>
31 #include <vector>
32
33 #include "HalInterfaces.h"
34 #include "Tracing.h"
35 #include "Utils.h"
36
37 namespace android::nn {
38 namespace {
39
40 using V1_2::FmqRequestDatum;
41 using V1_2::FmqResultDatum;
42 using V1_2::IBurstCallback;
43 using V1_2::IBurstContext;
44 using FmqRequestDescriptor = hardware::MQDescriptorSync<FmqRequestDatum>;
45 using FmqResultDescriptor = hardware::MQDescriptorSync<FmqResultDatum>;
46
47 constexpr V1_2::Timing kNoTiming12 = {std::numeric_limits<uint64_t>::max(),
48 std::numeric_limits<uint64_t>::max()};
49
50 class BurstContextDeathHandler : public hardware::hidl_death_recipient {
51 public:
52 using Callback = std::function<void()>;
53
BurstContextDeathHandler(const Callback & onDeathCallback)54 BurstContextDeathHandler(const Callback& onDeathCallback) : mOnDeathCallback(onDeathCallback) {
55 CHECK(onDeathCallback != nullptr);
56 }
57
serviceDied(uint64_t,const wp<hidl::base::V1_0::IBase> &)58 void serviceDied(uint64_t /*cookie*/, const wp<hidl::base::V1_0::IBase>& /*who*/) override {
59 LOG(ERROR) << "BurstContextDeathHandler::serviceDied -- service unexpectedly died!";
60 mOnDeathCallback();
61 }
62
63 private:
64 const Callback mOnDeathCallback;
65 };
66
67 } // anonymous namespace
68
69 // serialize a request into a packet
serialize(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)70 std::vector<FmqRequestDatum> serialize(const V1_0::Request& request, V1_2::MeasureTiming measure,
71 const std::vector<int32_t>& slots) {
72 // count how many elements need to be sent for a request
73 size_t count = 2 + request.inputs.size() + request.outputs.size() + request.pools.size();
74 for (const auto& input : request.inputs) {
75 count += input.dimensions.size();
76 }
77 for (const auto& output : request.outputs) {
78 count += output.dimensions.size();
79 }
80
81 // create buffer to temporarily store elements
82 std::vector<FmqRequestDatum> data;
83 data.reserve(count);
84
85 // package packetInfo
86 {
87 FmqRequestDatum datum;
88 datum.packetInformation(
89 {/*.packetSize=*/static_cast<uint32_t>(count),
90 /*.numberOfInputOperands=*/static_cast<uint32_t>(request.inputs.size()),
91 /*.numberOfOutputOperands=*/static_cast<uint32_t>(request.outputs.size()),
92 /*.numberOfPools=*/static_cast<uint32_t>(request.pools.size())});
93 data.push_back(datum);
94 }
95
96 // package input data
97 for (const auto& input : request.inputs) {
98 // package operand information
99 FmqRequestDatum datum;
100 datum.inputOperandInformation(
101 {/*.hasNoValue=*/input.hasNoValue,
102 /*.location=*/input.location,
103 /*.numberOfDimensions=*/static_cast<uint32_t>(input.dimensions.size())});
104 data.push_back(datum);
105
106 // package operand dimensions
107 for (uint32_t dimension : input.dimensions) {
108 FmqRequestDatum datum;
109 datum.inputOperandDimensionValue(dimension);
110 data.push_back(datum);
111 }
112 }
113
114 // package output data
115 for (const auto& output : request.outputs) {
116 // package operand information
117 FmqRequestDatum datum;
118 datum.outputOperandInformation(
119 {/*.hasNoValue=*/output.hasNoValue,
120 /*.location=*/output.location,
121 /*.numberOfDimensions=*/static_cast<uint32_t>(output.dimensions.size())});
122 data.push_back(datum);
123
124 // package operand dimensions
125 for (uint32_t dimension : output.dimensions) {
126 FmqRequestDatum datum;
127 datum.outputOperandDimensionValue(dimension);
128 data.push_back(datum);
129 }
130 }
131
132 // package pool identifier
133 for (int32_t slot : slots) {
134 FmqRequestDatum datum;
135 datum.poolIdentifier(slot);
136 data.push_back(datum);
137 }
138
139 // package measureTiming
140 {
141 FmqRequestDatum datum;
142 datum.measureTiming(measure);
143 data.push_back(datum);
144 }
145
146 // return packet
147 return data;
148 }
149
150 // deserialize a packet into the result
151 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
deserialize(const std::vector<FmqResultDatum> & data)152 deserialize(const std::vector<FmqResultDatum>& data) {
153 using discriminator = FmqResultDatum::hidl_discriminator;
154
155 std::vector<V1_2::OutputShape> outputShapes;
156 size_t index = 0;
157
158 // validate packet information
159 if (data.size() == 0 || data[index].getDiscriminator() != discriminator::packetInformation) {
160 LOG(ERROR) << "FMQ Result packet ill-formed";
161 return std::nullopt;
162 }
163
164 // unpackage packet information
165 const FmqResultDatum::PacketInformation& packetInfo = data[index].packetInformation();
166 index++;
167 const uint32_t packetSize = packetInfo.packetSize;
168 const V1_0::ErrorStatus errorStatus = packetInfo.errorStatus;
169 const uint32_t numberOfOperands = packetInfo.numberOfOperands;
170
171 // verify packet size
172 if (data.size() != packetSize) {
173 LOG(ERROR) << "FMQ Result packet ill-formed";
174 return std::nullopt;
175 }
176
177 // unpackage operands
178 for (size_t operand = 0; operand < numberOfOperands; ++operand) {
179 // validate operand information
180 if (data[index].getDiscriminator() != discriminator::operandInformation) {
181 LOG(ERROR) << "FMQ Result packet ill-formed";
182 return std::nullopt;
183 }
184
185 // unpackage operand information
186 const FmqResultDatum::OperandInformation& operandInfo = data[index].operandInformation();
187 index++;
188 const bool isSufficient = operandInfo.isSufficient;
189 const uint32_t numberOfDimensions = operandInfo.numberOfDimensions;
190
191 // unpackage operand dimensions
192 std::vector<uint32_t> dimensions;
193 dimensions.reserve(numberOfDimensions);
194 for (size_t i = 0; i < numberOfDimensions; ++i) {
195 // validate dimension
196 if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
197 LOG(ERROR) << "FMQ Result packet ill-formed";
198 return std::nullopt;
199 }
200
201 // unpackage dimension
202 const uint32_t dimension = data[index].operandDimensionValue();
203 index++;
204
205 // store result
206 dimensions.push_back(dimension);
207 }
208
209 // store result
210 outputShapes.push_back({/*.dimensions=*/dimensions, /*.isSufficient=*/isSufficient});
211 }
212
213 // validate execution timing
214 if (data[index].getDiscriminator() != discriminator::executionTiming) {
215 LOG(ERROR) << "FMQ Result packet ill-formed";
216 return std::nullopt;
217 }
218
219 // unpackage execution timing
220 const V1_2::Timing timing = data[index].executionTiming();
221 index++;
222
223 // validate packet information
224 if (index != packetSize) {
225 LOG(ERROR) << "FMQ Result packet ill-formed";
226 return std::nullopt;
227 }
228
229 // return result
230 return std::make_tuple(errorStatus, std::move(outputShapes), timing);
231 }
232
legacyConvertResultCodeToErrorStatus(int resultCode)233 V1_0::ErrorStatus legacyConvertResultCodeToErrorStatus(int resultCode) {
234 return convertToV1_0(convertResultCodeToErrorStatus(resultCode));
235 }
236
237 std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*>
create(size_t channelLength,std::chrono::microseconds pollingTimeWindow)238 ResultChannelReceiver::create(size_t channelLength, std::chrono::microseconds pollingTimeWindow) {
239 std::unique_ptr<FmqResultChannel> fmqResultChannel =
240 std::make_unique<FmqResultChannel>(channelLength, /*confEventFlag=*/true);
241 if (!fmqResultChannel->isValid()) {
242 LOG(ERROR) << "Unable to create ResultChannelReceiver";
243 return {nullptr, nullptr};
244 }
245
246 const FmqResultDescriptor* descriptor = fmqResultChannel->getDesc();
247 return std::make_pair(
248 std::make_unique<ResultChannelReceiver>(std::move(fmqResultChannel), pollingTimeWindow),
249 descriptor);
250 }
251
ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,std::chrono::microseconds pollingTimeWindow)252 ResultChannelReceiver::ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel,
253 std::chrono::microseconds pollingTimeWindow)
254 : mFmqResultChannel(std::move(fmqResultChannel)), kPollingTimeWindow(pollingTimeWindow) {}
255
256 std::optional<std::tuple<V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, V1_2::Timing>>
getBlocking()257 ResultChannelReceiver::getBlocking() {
258 const auto packet = getPacketBlocking();
259 if (!packet) {
260 return std::nullopt;
261 }
262
263 return deserialize(*packet);
264 }
265
invalidate()266 void ResultChannelReceiver::invalidate() {
267 mValid = false;
268
269 // force unblock
270 // ExecutionBurstController waits on a result packet after sending a
271 // request. If the driver containing ExecutionBurstServer crashes, the
272 // controller may be waiting on the futex. This force unblock wakes up any
273 // thread waiting on the futex.
274 // TODO: look for a different/better way to signal/notify the futex to
275 // wake up any thread waiting on it
276 FmqResultDatum datum;
277 datum.packetInformation({/*.packetSize=*/0,
278 /*.errorStatus=*/V1_0::ErrorStatus::GENERAL_FAILURE,
279 /*.numberOfOperands=*/0});
280 mFmqResultChannel->writeBlocking(&datum, 1);
281 }
282
getPacketBlocking()283 std::optional<std::vector<FmqResultDatum>> ResultChannelReceiver::getPacketBlocking() {
284 if (!mValid) {
285 return std::nullopt;
286 }
287
288 // First spend time polling if results are available in FMQ instead of
289 // waiting on the futex. Polling is more responsive (yielding lower
290 // latencies), but can take up more power, so only poll for a limited period
291 // of time.
292
293 auto& getCurrentTime = std::chrono::high_resolution_clock::now;
294 const auto timeToStopPolling = getCurrentTime() + kPollingTimeWindow;
295
296 while (getCurrentTime() < timeToStopPolling) {
297 // if class is being torn down, immediately return
298 if (!mValid.load(std::memory_order_relaxed)) {
299 return std::nullopt;
300 }
301
302 // Check if data is available. If it is, immediately retrieve it and
303 // return.
304 const size_t available = mFmqResultChannel->availableToRead();
305 if (available > 0) {
306 std::vector<FmqResultDatum> packet(available);
307 const bool success = mFmqResultChannel->read(packet.data(), available);
308 if (!success) {
309 LOG(ERROR) << "Error receiving packet";
310 return std::nullopt;
311 }
312 return std::make_optional(std::move(packet));
313 }
314
315 std::this_thread::yield();
316 }
317
318 // If we get to this point, we either stopped polling because it was taking
319 // too long or polling was not allowed. Instead, perform a blocking call
320 // which uses a futex to save power.
321
322 // wait for result packet and read first element of result packet
323 FmqResultDatum datum;
324 bool success = mFmqResultChannel->readBlocking(&datum, 1);
325
326 // retrieve remaining elements
327 // NOTE: all of the data is already available at this point, so there's no
328 // need to do a blocking wait to wait for more data. This is known because
329 // in FMQ, all writes are published (made available) atomically. Currently,
330 // the producer always publishes the entire packet in one function call, so
331 // if the first element of the packet is available, the remaining elements
332 // are also available.
333 const size_t count = mFmqResultChannel->availableToRead();
334 std::vector<FmqResultDatum> packet(count + 1);
335 std::memcpy(&packet.front(), &datum, sizeof(datum));
336 success &= mFmqResultChannel->read(packet.data() + 1, count);
337
338 if (!mValid) {
339 return std::nullopt;
340 }
341
342 // ensure packet was successfully received
343 if (!success) {
344 LOG(ERROR) << "Error receiving packet";
345 return std::nullopt;
346 }
347
348 return std::make_optional(std::move(packet));
349 }
350
351 std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*>
create(size_t channelLength)352 RequestChannelSender::create(size_t channelLength) {
353 std::unique_ptr<FmqRequestChannel> fmqRequestChannel =
354 std::make_unique<FmqRequestChannel>(channelLength, /*confEventFlag=*/true);
355 if (!fmqRequestChannel->isValid()) {
356 LOG(ERROR) << "Unable to create RequestChannelSender";
357 return {nullptr, nullptr};
358 }
359
360 const FmqRequestDescriptor* descriptor = fmqRequestChannel->getDesc();
361 return std::make_pair(std::make_unique<RequestChannelSender>(std::move(fmqRequestChannel)),
362 descriptor);
363 }
364
RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)365 RequestChannelSender::RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel)
366 : mFmqRequestChannel(std::move(fmqRequestChannel)) {}
367
send(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<int32_t> & slots)368 bool RequestChannelSender::send(const V1_0::Request& request, V1_2::MeasureTiming measure,
369 const std::vector<int32_t>& slots) {
370 const std::vector<FmqRequestDatum> serialized = serialize(request, measure, slots);
371 return sendPacket(serialized);
372 }
373
sendPacket(const std::vector<FmqRequestDatum> & packet)374 bool RequestChannelSender::sendPacket(const std::vector<FmqRequestDatum>& packet) {
375 if (!mValid) {
376 return false;
377 }
378
379 if (packet.size() > mFmqRequestChannel->availableToWrite()) {
380 LOG(ERROR)
381 << "RequestChannelSender::sendPacket -- packet size exceeds size available in FMQ";
382 return false;
383 }
384
385 // Always send the packet with "blocking" because this signals the futex and
386 // unblocks the consumer if it is waiting on the futex.
387 return mFmqRequestChannel->writeBlocking(packet.data(), packet.size());
388 }
389
invalidate()390 void RequestChannelSender::invalidate() {
391 mValid = false;
392 }
393
// Looks up the cached memory for every requested slot and reports the result
// through |cb|.
//
// If every slot maps to a valid memory, |cb| receives ErrorStatus::NONE and
// the memories in the same order as |slots|. If any slot is out of range or
// maps to an invalid (e.g. freed) entry, |cb| receives INVALID_ARGUMENT and an
// empty list.
hardware::Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
        const hardware::hidl_vec<int32_t>& slots, getMemories_cb cb) {
    std::lock_guard<std::mutex> guard(mMutex);

    // Gather the memories. Entries for slots outside of the cache (including
    // negative slots) are left default-constructed.
    hardware::hidl_vec<hardware::hidl_memory> memories(slots.size());
    for (size_t i = 0; i < slots.size(); ++i) {
        const int32_t slot = slots[i];
        const bool inRange = slot >= 0 && static_cast<size_t>(slot) < mMemoryCache.size();
        memories[i] = inRange ? mMemoryCache[slot] : hardware::hidl_memory{};
    }

    // a single invalid memory fails the whole request
    const bool anyInvalid =
            std::any_of(memories.begin(), memories.end(),
                        [](const hardware::hidl_memory& memory) { return !memory.valid(); });
    if (anyInvalid) {
        cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {});
        return hardware::Void();
    }

    cb(V1_0::ErrorStatus::NONE, std::move(memories));
    return hardware::Void();
}
415
getSlots(const hardware::hidl_vec<hardware::hidl_memory> & memories,const std::vector<intptr_t> & keys)416 std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
417 const hardware::hidl_vec<hardware::hidl_memory>& memories,
418 const std::vector<intptr_t>& keys) {
419 std::lock_guard<std::mutex> guard(mMutex);
420
421 // retrieve (or bind) all slots corresponding to memories
422 std::vector<int32_t> slots;
423 slots.reserve(memories.size());
424 for (size_t i = 0; i < memories.size(); ++i) {
425 slots.push_back(getSlotLocked(memories[i], keys[i]));
426 }
427 return slots;
428 }
429
freeMemory(intptr_t key)430 std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
431 intptr_t key) {
432 std::lock_guard<std::mutex> guard(mMutex);
433
434 auto iter = mMemoryIdToSlot.find(key);
435 if (iter == mMemoryIdToSlot.end()) {
436 return {false, 0};
437 }
438 const int32_t slot = iter->second;
439 mMemoryIdToSlot.erase(key);
440 mMemoryCache[slot] = {};
441 mFreeSlots.push(slot);
442 return {true, slot};
443 }
444
getSlotLocked(const hardware::hidl_memory & memory,intptr_t key)445 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(
446 const hardware::hidl_memory& memory, intptr_t key) {
447 auto iter = mMemoryIdToSlot.find(key);
448 if (iter == mMemoryIdToSlot.end()) {
449 const int32_t slot = allocateSlotLocked();
450 mMemoryIdToSlot[key] = slot;
451 mMemoryCache[slot] = memory;
452 return slot;
453 } else {
454 const int32_t slot = iter->second;
455 return slot;
456 }
457 }
458
allocateSlotLocked()459 int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
460 constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
461
462 // if there is a free slot, use it
463 if (mFreeSlots.size() > 0) {
464 const int32_t slot = mFreeSlots.top();
465 mFreeSlots.pop();
466 return slot;
467 }
468
469 // otherwise use a slot for the first time
470 CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
471 const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
472 mMemoryCache.emplace_back();
473
474 return slot;
475 }
476
create(const sp<V1_2::IPreparedModel> & preparedModel,std::chrono::microseconds pollingTimeWindow)477 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
478 const sp<V1_2::IPreparedModel>& preparedModel,
479 std::chrono::microseconds pollingTimeWindow) {
480 // check inputs
481 if (preparedModel == nullptr) {
482 LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
483 return nullptr;
484 }
485
486 // create callback object
487 sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
488
489 // create FMQ objects
490 auto [requestChannelSenderTemp, requestChannelDescriptor] =
491 RequestChannelSender::create(kExecutionBurstChannelLength);
492 auto [resultChannelReceiverTemp, resultChannelDescriptor] =
493 ResultChannelReceiver::create(kExecutionBurstChannelLength, pollingTimeWindow);
494 std::shared_ptr<RequestChannelSender> requestChannelSender =
495 std::move(requestChannelSenderTemp);
496 std::shared_ptr<ResultChannelReceiver> resultChannelReceiver =
497 std::move(resultChannelReceiverTemp);
498
499 // check FMQ objects
500 if (!requestChannelSender || !resultChannelReceiver || !requestChannelDescriptor ||
501 !resultChannelDescriptor) {
502 LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
503 return nullptr;
504 }
505
506 // configure burst
507 V1_0::ErrorStatus errorStatus;
508 sp<IBurstContext> burstContext;
509 const hardware::Return<void> ret = preparedModel->configureExecutionBurst(
510 callback, *requestChannelDescriptor, *resultChannelDescriptor,
511 [&errorStatus, &burstContext](V1_0::ErrorStatus status,
512 const sp<IBurstContext>& context) {
513 errorStatus = status;
514 burstContext = context;
515 });
516
517 // check burst
518 if (!ret.isOk()) {
519 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with description "
520 << ret.description();
521 return nullptr;
522 }
523 if (errorStatus != V1_0::ErrorStatus::NONE) {
524 LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with status "
525 << toString(errorStatus);
526 return nullptr;
527 }
528 if (burstContext == nullptr) {
529 LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
530 return nullptr;
531 }
532
533 // create death handler object
534 BurstContextDeathHandler::Callback onDeathCallback = [requestChannelSender,
535 resultChannelReceiver] {
536 requestChannelSender->invalidate();
537 resultChannelReceiver->invalidate();
538 };
539 const sp<BurstContextDeathHandler> deathHandler = new BurstContextDeathHandler(onDeathCallback);
540
541 // linkToDeath registers a callback that will be invoked on service death to
542 // proactively handle service crashes. If the linkToDeath call fails,
543 // asynchronous calls are susceptible to hangs if the service crashes before
544 // providing the response.
545 const hardware::Return<bool> deathHandlerRet = burstContext->linkToDeath(deathHandler, 0);
546 if (!deathHandlerRet.isOk() || deathHandlerRet != true) {
547 LOG(ERROR) << "ExecutionBurstController::create -- Failed to register a death recipient "
548 "for the IBurstContext object.";
549 return nullptr;
550 }
551
552 // make and return controller
553 return std::make_unique<ExecutionBurstController>(requestChannelSender, resultChannelReceiver,
554 burstContext, callback, deathHandler);
555 }
556
ExecutionBurstController(const std::shared_ptr<RequestChannelSender> & requestChannelSender,const std::shared_ptr<ResultChannelReceiver> & resultChannelReceiver,const sp<IBurstContext> & burstContext,const sp<ExecutionBurstCallback> & callback,const sp<hardware::hidl_death_recipient> & deathHandler)557 ExecutionBurstController::ExecutionBurstController(
558 const std::shared_ptr<RequestChannelSender>& requestChannelSender,
559 const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver,
560 const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback,
561 const sp<hardware::hidl_death_recipient>& deathHandler)
562 : mRequestChannelSender(requestChannelSender),
563 mResultChannelReceiver(resultChannelReceiver),
564 mBurstContext(burstContext),
565 mMemoryCache(callback),
566 mDeathHandler(deathHandler) {}
567
~ExecutionBurstController()568 ExecutionBurstController::~ExecutionBurstController() {
569 // It is safe to ignore any errors resulting from this unlinkToDeath call
570 // because the ExecutionBurstController object is already being destroyed
571 // and its underlying IBurstContext object is no longer being used by the NN
572 // runtime.
573 if (mDeathHandler) {
574 mBurstContext->unlinkToDeath(mDeathHandler).isOk();
575 }
576 }
577
getExecutionResult(V1_0::ErrorStatus status,std::vector<V1_2::OutputShape> outputShapes,V1_2::Timing timing,bool fallback)578 static std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool> getExecutionResult(
579 V1_0::ErrorStatus status, std::vector<V1_2::OutputShape> outputShapes, V1_2::Timing timing,
580 bool fallback) {
581 auto [n, checkedOutputShapes, checkedTiming] =
582 getExecutionResult(convertToV1_3(status), std::move(outputShapes), timing);
583 return {n, convertToV1_2(checkedOutputShapes), convertToV1_2(checkedTiming), fallback};
584 }
585
586 std::tuple<int, std::vector<V1_2::OutputShape>, V1_2::Timing, bool>
compute(const V1_0::Request & request,V1_2::MeasureTiming measure,const std::vector<intptr_t> & memoryIds)587 ExecutionBurstController::compute(const V1_0::Request& request, V1_2::MeasureTiming measure,
588 const std::vector<intptr_t>& memoryIds) {
589 // This is the first point when we know an execution is occurring, so begin
590 // to collect systraces. Note that the first point we can begin collecting
591 // systraces in ExecutionBurstServer is when the RequestChannelReceiver
592 // realizes there is data in the FMQ, so ExecutionBurstServer collects
593 // systraces at different points in the code.
594 NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
595
596 std::lock_guard<std::mutex> guard(mMutex);
597
598 // send request packet
599 const std::vector<int32_t> slots = mMemoryCache->getSlots(request.pools, memoryIds);
600 const bool success = mRequestChannelSender->send(request, measure, slots);
601 if (!success) {
602 LOG(ERROR) << "Error sending FMQ packet";
603 // only use fallback execution path if the packet could not be sent
604 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
605 /*fallback=*/true);
606 }
607
608 // get result packet
609 const auto result = mResultChannelReceiver->getBlocking();
610 if (!result) {
611 LOG(ERROR) << "Error retrieving FMQ packet";
612 // only use fallback execution path if the packet could not be sent
613 return getExecutionResult(V1_0::ErrorStatus::GENERAL_FAILURE, {}, kNoTiming12,
614 /*fallback=*/false);
615 }
616
617 // unpack results and return (only use fallback execution path if the
618 // packet could not be sent)
619 auto [status, outputShapes, timing] = std::move(*result);
620 return getExecutionResult(status, std::move(outputShapes), timing, /*fallback=*/false);
621 }
622
freeMemory(intptr_t key)623 void ExecutionBurstController::freeMemory(intptr_t key) {
624 std::lock_guard<std::mutex> guard(mMutex);
625
626 bool valid;
627 int32_t slot;
628 std::tie(valid, slot) = mMemoryCache->freeMemory(key);
629 if (valid) {
630 mBurstContext->freeMemory(slot).isOk();
631 }
632 }
633
634 } // namespace android::nn
635