1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "SampleDriverUtils.h"
18 
19 #include <aidl/android/hardware/common/NativeHandle.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_ibinder.h>
22 #include <nnapi/Validation.h>
23 #include <nnapi/hal/aidl/Conversions.h>
24 #include <nnapi/hal/aidl/Utils.h>
25 #include <utils/NativeHandle.h>
26 
27 #include <memory>
28 #include <string>
29 #include <thread>
30 #include <utility>
31 
32 #include "SampleDriver.h"
33 
34 namespace android {
35 namespace nn {
36 namespace sample_driver {
37 
notify(const std::shared_ptr<aidl_hal::IPreparedModelCallback> & callback,const aidl_hal::ErrorStatus & status,const std::shared_ptr<aidl_hal::IPreparedModel> & preparedModel)38 void notify(const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback,
39             const aidl_hal::ErrorStatus& status,
40             const std::shared_ptr<aidl_hal::IPreparedModel>& preparedModel) {
41     const auto ret = callback->notify(status, preparedModel);
42     if (!ret.isOk()) {
43         LOG(ERROR) << "Error when calling IPreparedModelCallback::notify: " << ret.getDescription()
44                    << " " << ret.getMessage();
45     }
46 }
47 
toAStatus(aidl_hal::ErrorStatus errorStatus)48 ndk::ScopedAStatus toAStatus(aidl_hal::ErrorStatus errorStatus) {
49     if (errorStatus == aidl_hal::ErrorStatus::NONE) {
50         return ndk::ScopedAStatus::ok();
51     }
52     return ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(errorStatus));
53 }
54 
toAStatus(aidl_hal::ErrorStatus errorStatus,const std::string & errorMessage)55 ndk::ScopedAStatus toAStatus(aidl_hal::ErrorStatus errorStatus, const std::string& errorMessage) {
56     if (errorStatus == aidl_hal::ErrorStatus::NONE) {
57         return ndk::ScopedAStatus::ok();
58     }
59     return ndk::ScopedAStatus::fromServiceSpecificErrorWithMessage(
60             static_cast<int32_t>(errorStatus), errorMessage.c_str());
61 }
62 
prepareModelBase(aidl_hal::Model && model,const SampleDriver * driver,aidl_hal::ExecutionPreference preference,aidl_hal::Priority priority,int64_t halDeadline,const std::shared_ptr<aidl_hal::IPreparedModelCallback> & callback,bool isFullModelSupported)63 ndk::ScopedAStatus prepareModelBase(
64         aidl_hal::Model&& model, const SampleDriver* driver,
65         aidl_hal::ExecutionPreference preference, aidl_hal::Priority priority, int64_t halDeadline,
66         const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback,
67         bool isFullModelSupported) {
68     const uid_t userId = AIBinder_getCallingUid();
69     if (callback.get() == nullptr) {
70         LOG(ERROR) << "invalid callback passed to prepareModelBase";
71         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
72                          "invalid callback passed to prepareModelBase");
73     }
74     const auto canonicalModel = convert(model);
75     if (!canonicalModel.has_value()) {
76         VLOG(DRIVER) << "invalid model passed to prepareModelBase";
77         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
78         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
79                          "invalid model passed to prepareModelBase");
80     }
81     if (VLOG_IS_ON(DRIVER)) {
82         VLOG(DRIVER) << "prepareModelBase";
83         logModelToInfo(canonicalModel.value());
84     }
85     if (!aidl_hal::utils::valid(preference)) {
86         const std::string log_message =
87                 "invalid execution preference passed to prepareModelBase: " + toString(preference);
88         VLOG(DRIVER) << log_message;
89         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
90         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT, log_message);
91     }
92     if (!aidl_hal::utils::valid(priority)) {
93         const std::string log_message =
94                 "invalid priority passed to prepareModelBase: " + toString(priority);
95         VLOG(DRIVER) << log_message;
96         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
97         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT, log_message);
98     }
99 
100     if (!isFullModelSupported) {
101         VLOG(DRIVER) << "model is not fully supported";
102         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
103         return ndk::ScopedAStatus::ok();
104     }
105 
106     if (halDeadline < -1) {
107         notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
108         return toAStatus(aidl_hal::ErrorStatus::INVALID_ARGUMENT,
109                          "Invalid deadline: " + toString(halDeadline));
110     }
111     const auto deadline = makeDeadline(halDeadline);
112     if (hasDeadlinePassed(deadline)) {
113         notify(callback, aidl_hal::ErrorStatus::MISSED_DEADLINE_PERSISTENT, nullptr);
114         return ndk::ScopedAStatus::ok();
115     }
116 
117     // asynchronously prepare the model from a new, detached thread
118     std::thread(
119             [driver, preference, userId, priority, callback](aidl_hal::Model&& model) {
120                 std::shared_ptr<SamplePreparedModel> preparedModel =
121                         ndk::SharedRefBase::make<SamplePreparedModel>(std::move(model), driver,
122                                                                       preference, userId, priority);
123                 if (!preparedModel->initialize()) {
124                     notify(callback, aidl_hal::ErrorStatus::INVALID_ARGUMENT, nullptr);
125                     return;
126                 }
127                 notify(callback, aidl_hal::ErrorStatus::NONE, preparedModel);
128             },
129             std::move(model))
130             .detach();
131 
132     return ndk::ScopedAStatus::ok();
133 }
134 
135 }  // namespace sample_driver
136 }  // namespace nn
137 }  // namespace android
138