1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H
18 #define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H
19 
#include <android/binder_auto_utils.h>

#include <atomic>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "AidlBufferTracker.h"
#include "AidlHalInterfaces.h"
#include "CpuExecutor.h"
#include "NeuralNetworks.h"
31 
32 namespace android {
33 namespace nn {
34 namespace sample_driver {
35 
36 // Manages the data buffer for an operand.
37 class SampleBuffer : public aidl_hal::BnBuffer {
38    public:
SampleBuffer(std::shared_ptr<AidlManagedBuffer> buffer,std::unique_ptr<AidlBufferTracker::Token> token)39     SampleBuffer(std::shared_ptr<AidlManagedBuffer> buffer,
40                  std::unique_ptr<AidlBufferTracker::Token> token)
41         : kBuffer(std::move(buffer)), kToken(std::move(token)) {
42         CHECK(kBuffer != nullptr);
43         CHECK(kToken != nullptr);
44     }
45     ndk::ScopedAStatus copyFrom(const aidl_hal::Memory& src,
46                                 const std::vector<int32_t>& dimensions) override;
47     ndk::ScopedAStatus copyTo(const aidl_hal::Memory& dst) override;
48 
49    private:
50     const std::shared_ptr<AidlManagedBuffer> kBuffer;
51     const std::unique_ptr<AidlBufferTracker::Token> kToken;
52 };
53 
54 // Base class used to create sample drivers for the NN HAL.  This class
55 // provides some implementation of the more common functions.
56 //
57 // Since these drivers simulate hardware, they must run the computations
58 // on the CPU.  An actual driver would not do that.
59 class SampleDriver : public aidl_hal::BnDevice {
60    public:
61     SampleDriver(const char* name,
62                  const IOperationResolver* operationResolver = BuiltinOperationResolver::get())
63         : mName(name),
64           mOperationResolver(operationResolver),
65           mBufferTracker(AidlBufferTracker::create()) {
66         android::nn::initVLogMask();
67     }
68     ndk::ScopedAStatus allocate(const aidl_hal::BufferDesc& desc,
69                                 const std::vector<aidl_hal::IPreparedModelParcel>& preparedModels,
70                                 const std::vector<aidl_hal::BufferRole>& inputRoles,
71                                 const std::vector<aidl_hal::BufferRole>& outputRoles,
72                                 aidl_hal::DeviceBuffer* buffer) override;
73     ndk::ScopedAStatus getNumberOfCacheFilesNeeded(
74             aidl_hal::NumberOfCacheFiles* numberOfCacheFiles) override;
75     ndk::ScopedAStatus getSupportedExtensions(
76             std::vector<aidl_hal::Extension>* extensions) override;
77     ndk::ScopedAStatus getType(aidl_hal::DeviceType* deviceType) override;
78     ndk::ScopedAStatus getVersionString(std::string* version) override;
79     ndk::ScopedAStatus prepareModel(
80             const aidl_hal::Model& model, aidl_hal::ExecutionPreference preference,
81             aidl_hal::Priority priority, int64_t deadlineNs,
82             const std::vector<ndk::ScopedFileDescriptor>& modelCache,
83             const std::vector<ndk::ScopedFileDescriptor>& dataCache,
84             const std::vector<uint8_t>& token,
85             const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback) override;
86     ndk::ScopedAStatus prepareModelFromCache(
87             int64_t deadlineNs, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
88             const std::vector<ndk::ScopedFileDescriptor>& dataCache,
89             const std::vector<uint8_t>& token,
90             const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback) override;
91 
92     // Starts and runs the driver service.  Typically called from main().
93     // This will return only once the service shuts down.
94     int run();
95 
getExecutor()96     CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); }
getBufferTracker()97     const std::shared_ptr<AidlBufferTracker>& getBufferTracker() const { return mBufferTracker; }
98 
99    protected:
100     std::string mName;
101     const IOperationResolver* mOperationResolver;
102     const std::shared_ptr<AidlBufferTracker> mBufferTracker;
103 };
104 
105 class SamplePreparedModel : public aidl_hal::BnPreparedModel {
106    public:
SamplePreparedModel(aidl_hal::Model && model,const SampleDriver * driver,aidl_hal::ExecutionPreference preference,uid_t userId,aidl_hal::Priority priority)107     SamplePreparedModel(aidl_hal::Model&& model, const SampleDriver* driver,
108                         aidl_hal::ExecutionPreference preference, uid_t userId,
109                         aidl_hal::Priority priority)
110         : mModel(std::move(model)),
111           mDriver(driver),
112           kPreference(preference),
113           kUserId(userId),
114           kPriority(priority) {
115         (void)kUserId;
116         (void)kPriority;
117     }
118     bool initialize();
119     ndk::ScopedAStatus executeSynchronously(const aidl_hal::Request& request, bool measureTiming,
120                                             int64_t deadlineNs, int64_t loopTimeoutDurationNs,
121                                             aidl_hal::ExecutionResult* executionResult) override;
122     ndk::ScopedAStatus executeFenced(const aidl_hal::Request& request,
123                                      const std::vector<ndk::ScopedFileDescriptor>& waitFor,
124                                      bool measureTiming, int64_t deadlineNs,
125                                      int64_t loopTimeoutDurationNs, int64_t durationNs,
126                                      aidl_hal::FencedExecutionResult* executionResult) override;
127     ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<aidl_hal::IBurst>* burst) override;
getModel()128     const aidl_hal::Model* getModel() const { return &mModel; }
129 
130    protected:
131     aidl_hal::Model mModel;
132     const SampleDriver* mDriver;
133     std::vector<RunTimePoolInfo> mPoolInfos;
134     const aidl_hal::ExecutionPreference kPreference;
135     const uid_t kUserId;
136     const aidl_hal::Priority kPriority;
137 };
138 
139 class SampleFencedExecutionCallback : public aidl_hal::BnFencedExecutionCallback {
140    public:
SampleFencedExecutionCallback(aidl_hal::Timing timingSinceLaunch,aidl_hal::Timing timingAfterFence,aidl_hal::ErrorStatus error)141     SampleFencedExecutionCallback(aidl_hal::Timing timingSinceLaunch,
142                                   aidl_hal::Timing timingAfterFence, aidl_hal::ErrorStatus error)
143         : kTimingSinceLaunch(timingSinceLaunch),
144           kTimingAfterFence(timingAfterFence),
145           kErrorStatus(error) {}
getExecutionInfo(aidl_hal::Timing * timingLaunched,aidl_hal::Timing * timingFenced,aidl_hal::ErrorStatus * errorStatus)146     ndk::ScopedAStatus getExecutionInfo(aidl_hal::Timing* timingLaunched,
147                                         aidl_hal::Timing* timingFenced,
148                                         aidl_hal::ErrorStatus* errorStatus) override {
149         *timingLaunched = kTimingSinceLaunch;
150         *timingFenced = kTimingAfterFence;
151         *errorStatus = kErrorStatus;
152         return ndk::ScopedAStatus::ok();
153     }
154 
155    private:
156     const aidl_hal::Timing kTimingSinceLaunch;
157     const aidl_hal::Timing kTimingAfterFence;
158     const aidl_hal::ErrorStatus kErrorStatus;
159 };
160 
161 class SampleBurst : public aidl_hal::BnBurst {
162    public:
163     // Precondition: preparedModel != nullptr
164     explicit SampleBurst(std::shared_ptr<SamplePreparedModel> preparedModel);
165 
166     ndk::ScopedAStatus executeSynchronously(const aidl_hal::Request& request,
167                                             const std::vector<int64_t>& memoryIdentifierTokens,
168                                             bool measureTiming, int64_t deadlineNs,
169                                             int64_t loopTimeoutDurationNs,
170                                             aidl_hal::ExecutionResult* executionResult) override;
171     ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override;
172 
173    protected:
174     std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT;
175     const std::shared_ptr<SamplePreparedModel> kPreparedModel;
176 };
177 
178 }  // namespace sample_driver
179 }  // namespace nn
180 }  // namespace android
181 
182 #endif  // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H
183