1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H 19 20 #include <android/binder_auto_utils.h> 21 22 #include <memory> 23 #include <string> 24 #include <utility> 25 #include <vector> 26 27 #include "AidlBufferTracker.h" 28 #include "AidlHalInterfaces.h" 29 #include "CpuExecutor.h" 30 #include "NeuralNetworks.h" 31 32 namespace android { 33 namespace nn { 34 namespace sample_driver { 35 36 // Manages the data buffer for an operand. 37 class SampleBuffer : public aidl_hal::BnBuffer { 38 public: SampleBuffer(std::shared_ptr<AidlManagedBuffer> buffer,std::unique_ptr<AidlBufferTracker::Token> token)39 SampleBuffer(std::shared_ptr<AidlManagedBuffer> buffer, 40 std::unique_ptr<AidlBufferTracker::Token> token) 41 : kBuffer(std::move(buffer)), kToken(std::move(token)) { 42 CHECK(kBuffer != nullptr); 43 CHECK(kToken != nullptr); 44 } 45 ndk::ScopedAStatus copyFrom(const aidl_hal::Memory& src, 46 const std::vector<int32_t>& dimensions) override; 47 ndk::ScopedAStatus copyTo(const aidl_hal::Memory& dst) override; 48 49 private: 50 const std::shared_ptr<AidlManagedBuffer> kBuffer; 51 const std::unique_ptr<AidlBufferTracker::Token> kToken; 52 }; 53 54 // Base class used to create sample drivers for the NN HAL. This class 55 // provides some implementation of the more common functions. 56 // 57 // Since these drivers simulate hardware, they must run the computations 58 // on the CPU. An actual driver would not do that. 59 class SampleDriver : public aidl_hal::BnDevice { 60 public: 61 SampleDriver(const char* name, 62 const IOperationResolver* operationResolver = BuiltinOperationResolver::get()) 63 : mName(name), 64 mOperationResolver(operationResolver), 65 mBufferTracker(AidlBufferTracker::create()) { 66 android::nn::initVLogMask(); 67 } 68 ndk::ScopedAStatus allocate(const aidl_hal::BufferDesc& desc, 69 const std::vector<aidl_hal::IPreparedModelParcel>& preparedModels, 70 const std::vector<aidl_hal::BufferRole>& inputRoles, 71 const std::vector<aidl_hal::BufferRole>& outputRoles, 72 aidl_hal::DeviceBuffer* buffer) override; 73 ndk::ScopedAStatus getNumberOfCacheFilesNeeded( 74 aidl_hal::NumberOfCacheFiles* numberOfCacheFiles) override; 75 ndk::ScopedAStatus getSupportedExtensions( 76 std::vector<aidl_hal::Extension>* extensions) override; 77 ndk::ScopedAStatus getType(aidl_hal::DeviceType* deviceType) override; 78 ndk::ScopedAStatus getVersionString(std::string* version) override; 79 ndk::ScopedAStatus prepareModel( 80 const aidl_hal::Model& model, aidl_hal::ExecutionPreference preference, 81 aidl_hal::Priority priority, int64_t deadlineNs, 82 const std::vector<ndk::ScopedFileDescriptor>& modelCache, 83 const std::vector<ndk::ScopedFileDescriptor>& dataCache, 84 const std::vector<uint8_t>& token, 85 const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback) override; 86 ndk::ScopedAStatus prepareModelFromCache( 87 int64_t deadlineNs, const std::vector<ndk::ScopedFileDescriptor>& modelCache, 88 const std::vector<ndk::ScopedFileDescriptor>& dataCache, 89 const std::vector<uint8_t>& token, 90 const std::shared_ptr<aidl_hal::IPreparedModelCallback>& callback) override; 91 92 // Starts and runs the driver service. Typically called from main(). 93 // This will return only once the service shuts down. 94 int run(); 95 getExecutor()96 CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); } getBufferTracker()97 const std::shared_ptr<AidlBufferTracker>& getBufferTracker() const { return mBufferTracker; } 98 99 protected: 100 std::string mName; 101 const IOperationResolver* mOperationResolver; 102 const std::shared_ptr<AidlBufferTracker> mBufferTracker; 103 }; 104 105 class SamplePreparedModel : public aidl_hal::BnPreparedModel { 106 public: SamplePreparedModel(aidl_hal::Model && model,const SampleDriver * driver,aidl_hal::ExecutionPreference preference,uid_t userId,aidl_hal::Priority priority)107 SamplePreparedModel(aidl_hal::Model&& model, const SampleDriver* driver, 108 aidl_hal::ExecutionPreference preference, uid_t userId, 109 aidl_hal::Priority priority) 110 : mModel(std::move(model)), 111 mDriver(driver), 112 kPreference(preference), 113 kUserId(userId), 114 kPriority(priority) { 115 (void)kUserId; 116 (void)kPriority; 117 } 118 bool initialize(); 119 ndk::ScopedAStatus executeSynchronously(const aidl_hal::Request& request, bool measureTiming, 120 int64_t deadlineNs, int64_t loopTimeoutDurationNs, 121 aidl_hal::ExecutionResult* executionResult) override; 122 ndk::ScopedAStatus executeFenced(const aidl_hal::Request& request, 123 const std::vector<ndk::ScopedFileDescriptor>& waitFor, 124 bool measureTiming, int64_t deadlineNs, 125 int64_t loopTimeoutDurationNs, int64_t durationNs, 126 aidl_hal::FencedExecutionResult* executionResult) override; 127 ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<aidl_hal::IBurst>* burst) override; getModel()128 const aidl_hal::Model* getModel() const { return &mModel; } 129 130 protected: 131 aidl_hal::Model mModel; 132 const SampleDriver* mDriver; 133 std::vector<RunTimePoolInfo> mPoolInfos; 134 const aidl_hal::ExecutionPreference kPreference; 135 const uid_t kUserId; 136 const aidl_hal::Priority kPriority; 137 }; 138 139 class SampleFencedExecutionCallback : public aidl_hal::BnFencedExecutionCallback { 140 public: SampleFencedExecutionCallback(aidl_hal::Timing timingSinceLaunch,aidl_hal::Timing timingAfterFence,aidl_hal::ErrorStatus error)141 SampleFencedExecutionCallback(aidl_hal::Timing timingSinceLaunch, 142 aidl_hal::Timing timingAfterFence, aidl_hal::ErrorStatus error) 143 : kTimingSinceLaunch(timingSinceLaunch), 144 kTimingAfterFence(timingAfterFence), 145 kErrorStatus(error) {} getExecutionInfo(aidl_hal::Timing * timingLaunched,aidl_hal::Timing * timingFenced,aidl_hal::ErrorStatus * errorStatus)146 ndk::ScopedAStatus getExecutionInfo(aidl_hal::Timing* timingLaunched, 147 aidl_hal::Timing* timingFenced, 148 aidl_hal::ErrorStatus* errorStatus) override { 149 *timingLaunched = kTimingSinceLaunch; 150 *timingFenced = kTimingAfterFence; 151 *errorStatus = kErrorStatus; 152 return ndk::ScopedAStatus::ok(); 153 } 154 155 private: 156 const aidl_hal::Timing kTimingSinceLaunch; 157 const aidl_hal::Timing kTimingAfterFence; 158 const aidl_hal::ErrorStatus kErrorStatus; 159 }; 160 161 class SampleBurst : public aidl_hal::BnBurst { 162 public: 163 // Precondition: preparedModel != nullptr 164 explicit SampleBurst(std::shared_ptr<SamplePreparedModel> preparedModel); 165 166 ndk::ScopedAStatus executeSynchronously(const aidl_hal::Request& request, 167 const std::vector<int64_t>& memoryIdentifierTokens, 168 bool measureTiming, int64_t deadlineNs, 169 int64_t loopTimeoutDurationNs, 170 aidl_hal::ExecutionResult* executionResult) override; 171 ndk::ScopedAStatus releaseMemoryResource(int64_t memoryIdentifierToken) override; 172 173 protected: 174 std::atomic_flag mExecutionInFlight = ATOMIC_FLAG_INIT; 175 const std::shared_ptr<SamplePreparedModel> kPreparedModel; 176 }; 177 178 } // namespace sample_driver 179 } // namespace nn 180 } // namespace android 181 182 #endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_AIDL_SAMPLE_DRIVER_H 183