1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
19 
20 #include <CpuExecutor.h>
21 #include <HalBufferTracker.h>
22 #include <HalInterfaces.h>
23 #include <hwbinder/IPCThreadState.h>
24 
25 #include <memory>
26 #include <string>
27 #include <utility>
28 #include <vector>
29 
30 #include "NeuralNetworks.h"
31 
32 namespace android {
33 namespace nn {
34 namespace sample_driver {
35 
36 using hardware::MQDescriptorSync;
37 
38 // Manages the data buffer for an operand.
39 class SampleBuffer : public V1_3::IBuffer {
40    public:
SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,std::unique_ptr<HalBufferTracker::Token> token)41     SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,
42                  std::unique_ptr<HalBufferTracker::Token> token)
43         : kBuffer(std::move(buffer)), kToken(std::move(token)) {
44         CHECK(kBuffer != nullptr);
45         CHECK(kToken != nullptr);
46     }
47     hardware::Return<V1_3::ErrorStatus> copyTo(const hardware::hidl_memory& dst) override;
48     hardware::Return<V1_3::ErrorStatus> copyFrom(
49             const hardware::hidl_memory& src,
50             const hardware::hidl_vec<uint32_t>& dimensions) override;
51 
52    private:
53     const std::shared_ptr<HalManagedBuffer> kBuffer;
54     const std::unique_ptr<HalBufferTracker::Token> kToken;
55 };
56 
57 // Base class used to create sample drivers for the NN HAL.  This class
58 // provides some implementation of the more common functions.
59 //
60 // Since these drivers simulate hardware, they must run the computations
61 // on the CPU.  An actual driver would not do that.
62 class SampleDriver : public V1_3::IDevice {
63    public:
64     SampleDriver(const char* name,
65                  const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
66         : mName(name),
67           mOperationResolver(operationResolver),
68           mHalBufferTracker(HalBufferTracker::create()) {
69         android::nn::initVLogMask();
70     }
71     hardware::Return<void> getCapabilities(getCapabilities_cb cb) override;
72     hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override;
73     hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override;
74     hardware::Return<void> getVersionString(getVersionString_cb cb) override;
75     hardware::Return<void> getType(getType_cb cb) override;
76     hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override;
77     hardware::Return<void> getSupportedOperations(const V1_0::Model& model,
78                                                   getSupportedOperations_cb cb) override;
79     hardware::Return<void> getSupportedOperations_1_1(const V1_1::Model& model,
80                                                       getSupportedOperations_1_1_cb cb) override;
81     hardware::Return<void> getSupportedOperations_1_2(const V1_2::Model& model,
82                                                       getSupportedOperations_1_2_cb cb) override;
83     hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override;
84     hardware::Return<V1_0::ErrorStatus> prepareModel(
85             const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) override;
86     hardware::Return<V1_0::ErrorStatus> prepareModel_1_1(
87             const V1_1::Model& model, V1_1::ExecutionPreference preference,
88             const sp<V1_0::IPreparedModelCallback>& callback) override;
89     hardware::Return<V1_0::ErrorStatus> prepareModel_1_2(
90             const V1_2::Model& model, V1_1::ExecutionPreference preference,
91             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
92             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
93             const sp<V1_2::IPreparedModelCallback>& callback) override;
94     hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
95             const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority,
96             const V1_3::OptionalTimePoint& deadline,
97             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
98             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
99             const sp<V1_3::IPreparedModelCallback>& callback) override;
100     hardware::Return<V1_0::ErrorStatus> prepareModelFromCache(
101             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
102             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
103             const sp<V1_2::IPreparedModelCallback>& callback) override;
104     hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
105             const V1_3::OptionalTimePoint& deadline,
106             const hardware::hidl_vec<hardware::hidl_handle>& modelCache,
107             const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token,
108             const sp<V1_3::IPreparedModelCallback>& callback) override;
109     hardware::Return<V1_0::DeviceStatus> getStatus() override;
110     hardware::Return<void> allocate(
111             const V1_3::BufferDesc& desc,
112             const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
113             const hardware::hidl_vec<V1_3::BufferRole>& inputRoles,
114             const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) override;
115 
116     // Starts and runs the driver service.  Typically called from main().
117     // This will return only once the service shuts down.
118     int run();
119 
getExecutor()120     CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
getHalBufferTracker()121     const std::shared_ptr<HalBufferTracker>& getHalBufferTracker() const {
122         return mHalBufferTracker;
123     }
124 
125    protected:
126     std::string mName;
127     const IOperationResolver* mOperationResolver;
128     const std::shared_ptr<HalBufferTracker> mHalBufferTracker;
129 };
130 
131 class SamplePreparedModel : public V1_3::IPreparedModel {
132    public:
SamplePreparedModel(const V1_3::Model & model,const SampleDriver * driver,V1_1::ExecutionPreference preference,uid_t userId,V1_3::Priority priority)133     SamplePreparedModel(const V1_3::Model& model, const SampleDriver* driver,
134                         V1_1::ExecutionPreference preference, uid_t userId, V1_3::Priority priority)
135         : mModel(model),
136           mDriver(driver),
137           kPreference(preference),
138           kUserId(userId),
139           kPriority(priority) {
140         (void)kUserId;
141         (void)kPriority;
142     }
143     bool initialize();
144     hardware::Return<V1_0::ErrorStatus> execute(
145             const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override;
146     hardware::Return<V1_0::ErrorStatus> execute_1_2(
147             const V1_0::Request& request, V1_2::MeasureTiming measure,
148             const sp<V1_2::IExecutionCallback>& callback) override;
149     hardware::Return<V1_3::ErrorStatus> execute_1_3(
150             const V1_3::Request& request, V1_2::MeasureTiming measure,
151             const V1_3::OptionalTimePoint& deadline,
152             const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
153             const sp<V1_3::IExecutionCallback>& callback) override;
154     hardware::Return<void> executeSynchronously(const V1_0::Request& request,
155                                                 V1_2::MeasureTiming measure,
156                                                 executeSynchronously_cb cb) override;
157     hardware::Return<void> executeSynchronously_1_3(
158             const V1_3::Request& request, V1_2::MeasureTiming measure,
159             const V1_3::OptionalTimePoint& deadline,
160             const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
161             executeSynchronously_1_3_cb cb) override;
162     hardware::Return<void> configureExecutionBurst(
163             const sp<V1_2::IBurstCallback>& callback,
164             const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
165             const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
166             configureExecutionBurst_cb cb) override;
167     hardware::Return<void> executeFenced(const V1_3::Request& request,
168                                          const hardware::hidl_vec<hardware::hidl_handle>& wait_for,
169                                          V1_2::MeasureTiming measure,
170                                          const V1_3::OptionalTimePoint& deadline,
171                                          const V1_3::OptionalTimeoutDuration& loopTimeoutDuration,
172                                          const V1_3::OptionalTimeoutDuration& duration,
173                                          executeFenced_cb callback) override;
getModel()174     const V1_3::Model* getModel() const { return &mModel; }
175 
176    protected:
177     V1_3::Model mModel;
178     const SampleDriver* mDriver;
179     std::vector<RunTimePoolInfo> mPoolInfos;
180     const V1_1::ExecutionPreference kPreference;
181     const uid_t kUserId;
182     const V1_3::Priority kPriority;
183 };
184 
185 class SampleFencedExecutionCallback : public V1_3::IFencedExecutionCallback {
186    public:
SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch,V1_2::Timing timingAfterFence,V1_3::ErrorStatus error)187     SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch, V1_2::Timing timingAfterFence,
188                                   V1_3::ErrorStatus error)
189         : kTimingSinceLaunch(timingSinceLaunch),
190           kTimingAfterFence(timingAfterFence),
191           kErrorStatus(error) {}
getExecutionInfo(getExecutionInfo_cb callback)192     hardware::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override {
193         callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence);
194         return hardware::Void();
195     }
196 
197    private:
198     const V1_2::Timing kTimingSinceLaunch;
199     const V1_2::Timing kTimingAfterFence;
200     const V1_3::ErrorStatus kErrorStatus;
201 };
202 
203 }  // namespace sample_driver
204 }  // namespace nn
205 }  // namespace android
206 
207 #endif  // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H
208