1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "SampleDriverUtils.h"
18
19 #include <aidl/android/hardware/common/NativeHandle.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_ibinder.h>
22 #include <nnapi/Validation.h>
23 #include <nnapi/hal/aidl/Conversions.h>
24 #include <nnapi/hal/aidl/Utils.h>
25 #include <utils/NativeHandle.h>
26
27 #include <memory>
28 #include <string>
29 #include <thread>
30 #include <utility>
31
32 #include "SampleDriver.h"
33
34 namespace android {
35 namespace nn {
36 namespace sample_driver {
37
notify(const std::shared_ptr<aidl_hal::IPreparedModelCallback> & callback,const aidl_hal::ErrorStatus & status,const std::shared_ptr<aidl_hal::IPreparedModel> & preparedModel)38 void notify(const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback,
39 const aidl_hal::ErrorStatus& status,
40 const std::shared_ptr<aidl_hal::IPreparedModel>& preparedModel) {
41 const auto ret = callback->notify(status, preparedModel);
42 if (!ret.isOk()) {
43 LOG(ERROR) << "Error when calling IPreparedModelCallback::notify: " << ret.getDescription()
44 << " " << ret.getMessage();
45 }
46 }
47
toAStatus(aidl_hal::ErrorStatus errorStatus)48 ndk::ScopedAStatus toAStatus(aidl_hal::ErrorStatus errorStatus) {
49 if (errorStatus == aidl_hal::ErrorStatus::NONE) {
50 return ndk::ScopedAStatus::ok();
51 }
52 return ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(errorStatus));
53 }
54
toAStatus(aidl_hal::ErrorStatus errorStatus,const std::string & errorMessage)55 ndk::ScopedAStatus toAStatus(aidl_hal::ErrorStatus errorStatus, const std::string& errorMessage) {
56 if (errorStatus == aidl_hal::ErrorStatus::NONE) {
57 return ndk::ScopedAStatus::ok();
58 }
59 return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
60 static_cast<int32_t>(errorStatus), errorMessage.c_str());
61 }
62
prepareModelBase(aidl_hal::Model && model,const SampleDriver * driver,aidl_hal::ExecutionPreference preference,aidl_hal::Priority priority,int64_t halDeadline,const std::shared_ptr<aidl_hal::IPreparedModelCallback> & callback,bool isFullModelSupported)63 ndk::ScopedAStatus prepareModelBase(
64 aidl_hal::Model&& model, const SampleDriver* driver,
65 aidl_hal::ExecutionPreference preference, aidl_hal::Priority priority, int64_t halDeadline,
66 const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback,
67 bool isFullModelSupported) {
68 const uid_t userId = AIBinder_getCallingUid();
69 if (callback.get() == nullptr) {
70 LOG(ERROR) << "invalid callback passed to prepareModelBase";
71 return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
72 "invalid callback passed to prepareModelBase");
73 }
74 const auto canonicalModel = convert(model);
75 if (!canonicalModel.has_value()) {
76 VLOG(DRIVER) << "invalid model passed to prepareModelBase";
77 notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
78 return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
79 "invalid model passed to prepareModelBase");
80 }
81 if (VLOG_IS_ON(DRIVER)) {
82 VLOG(DRIVER) << "prepareModelBase";
83 logModelToInfo(canonicalModel.value());
84 }
85 if (!aidl_hal::utils::valid(preference)) {
86 const std::string log_message =
87 "invalid execution preference passed to prepareModelBase: " + toString(preference);
88 VLOG(DRIVER) << log_message;
89 notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
90 return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT, log_message);
91 }
92 if (!aidl_hal::utils::valid(priority)) {
93 const std::string log_message =
94 "invalid priority passed to prepareModelBase: " + toString(priority);
95 VLOG(DRIVER) << log_message;
96 notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
97 return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT, log_message);
98 }
99
100 if (!isFullModelSupported) {
101 VLOG(DRIVER) << "model is not fully supported";
102 notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
103 return ndk::ScopedAStatus::ok();
104 }
105
106 if (halDeadline < -1) {
107 notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
108 return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
109 "Invalid deadline: " + toString(halDeadline));
110 }
111 const auto deadline = makeDeadline(halDeadline);
112 if (hasDeadlinePassed(deadline)) {
113 notify(callback, aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT, nullptr);
114 return ndk::ScopedAStatus::ok();
115 }
116
117 // asynchronously prepare the model from a new, detached thread
118 std::thread(
119 [driver, preference, userId, priority, callback](aidl_hal::Model&& model) {
120 std::shared_ptr<SamplePreparedModel> preparedModel =
121 ndk::SharedRefBase::make<SamplePreparedModel>(std::move(model), driver,
122 preference, userId, priority);
123 if (!preparedModel->initialize()) {
124 notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
125 return;
126 }
127 notify(callback, aidl_hal::ErrorStatus::NONE, preparedModel);
128 },
129 std::move(model))
130 .detach();
131
132 return ndk::ScopedAStatus::ok();
133 }
134
135 } // namespace sample_driver
136 } // namespace nn
137 } // namespace android
138