1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H 18 #define ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H 19 20 #include <CpuExecutor.h> 21 #include <HalBufferTracker.h> 22 #include <HalInterfaces.h> 23 #include <hwbinder/IPCThreadState.h> 24 25 #include <memory> 26 #include <string> 27 #include <utility> 28 #include <vector> 29 30 #include "NeuralNetworks.h" 31 32 namespace android { 33 namespace nn { 34 namespace sample_driver { 35 36 using hardware::MQDescriptorSync; 37 38 // Manages the data buffer for an operand. 39 class SampleBuffer : public V1_3::IBuffer { 40 public: SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer,std::unique_ptr<HalBufferTracker::Token> token)41 SampleBuffer(std::shared_ptr<HalManagedBuffer> buffer, 42 std::unique_ptr<HalBufferTracker::Token> token) 43 : kBuffer(std::move(buffer)), kToken(std::move(token)) { 44 CHECK(kBuffer != nullptr); 45 CHECK(kToken != nullptr); 46 } 47 hardware::Return<V1_3::ErrorStatus> copyTo(const hardware::hidl_memory& dst) override; 48 hardware::Return<V1_3::ErrorStatus> copyFrom( 49 const hardware::hidl_memory& src, 50 const hardware::hidl_vec<uint32_t>& dimensions) override; 51 52 private: 53 const std::shared_ptr<HalManagedBuffer> kBuffer; 54 const std::unique_ptr<HalBufferTracker::Token> kToken; 55 }; 56 57 // Base class used to create sample drivers for the NN HAL. This class 58 // provides some implementation of the more common functions. 59 // 60 // Since these drivers simulate hardware, they must run the computations 61 // on the CPU. An actual driver would not do that. 62 class SampleDriver : public V1_3::IDevice { 63 public: 64 SampleDriver(const char* name, 65 const IOperationResolver* operationResolver = BuiltinOperationResolver::get()) 66 : mName(name), 67 mOperationResolver(operationResolver), 68 mHalBufferTracker(HalBufferTracker::create()) { 69 android::nn::initVLogMask(); 70 } 71 hardware::Return<void> getCapabilities(getCapabilities_cb cb) override; 72 hardware::Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override; 73 hardware::Return<void> getCapabilities_1_2(getCapabilities_1_2_cb cb) override; 74 hardware::Return<void> getVersionString(getVersionString_cb cb) override; 75 hardware::Return<void> getType(getType_cb cb) override; 76 hardware::Return<void> getSupportedExtensions(getSupportedExtensions_cb) override; 77 hardware::Return<void> getSupportedOperations(const V1_0::Model& model, 78 getSupportedOperations_cb cb) override; 79 hardware::Return<void> getSupportedOperations_1_1(const V1_1::Model& model, 80 getSupportedOperations_1_1_cb cb) override; 81 hardware::Return<void> getSupportedOperations_1_2(const V1_2::Model& model, 82 getSupportedOperations_1_2_cb cb) override; 83 hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override; 84 hardware::Return<V1_0::ErrorStatus> prepareModel( 85 const V1_0::Model& model, const sp<V1_0::IPreparedModelCallback>& callback) override; 86 hardware::Return<V1_0::ErrorStatus> prepareModel_1_1( 87 const V1_1::Model& model, V1_1::ExecutionPreference preference, 88 const sp<V1_0::IPreparedModelCallback>& callback) override; 89 hardware::Return<V1_0::ErrorStatus> prepareModel_1_2( 90 const V1_2::Model& model, V1_1::ExecutionPreference preference, 91 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 92 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 93 const sp<V1_2::IPreparedModelCallback>& callback) override; 94 hardware::Return<V1_3::ErrorStatus> prepareModel_1_3( 95 const V1_3::Model& model, V1_1::ExecutionPreference preference, V1_3::Priority priority, 96 const V1_3::OptionalTimePoint& deadline, 97 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 98 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 99 const sp<V1_3::IPreparedModelCallback>& callback) override; 100 hardware::Return<V1_0::ErrorStatus> prepareModelFromCache( 101 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 102 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 103 const sp<V1_2::IPreparedModelCallback>& callback) override; 104 hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3( 105 const V1_3::OptionalTimePoint& deadline, 106 const hardware::hidl_vec<hardware::hidl_handle>& modelCache, 107 const hardware::hidl_vec<hardware::hidl_handle>& dataCache, const HalCacheToken& token, 108 const sp<V1_3::IPreparedModelCallback>& callback) override; 109 hardware::Return<V1_0::DeviceStatus> getStatus() override; 110 hardware::Return<void> allocate( 111 const V1_3::BufferDesc& desc, 112 const hardware::hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels, 113 const hardware::hidl_vec<V1_3::BufferRole>& inputRoles, 114 const hardware::hidl_vec<V1_3::BufferRole>& outputRoles, allocate_cb cb) override; 115 116 // Starts and runs the driver service. Typically called from main(). 117 // This will return only once the service shuts down. 118 int run(); 119 getExecutor()120 CpuExecutor getExecutor() const { return CpuExecutor(mOperationResolver); } getHalBufferTracker()121 const std::shared_ptr<HalBufferTracker>& getHalBufferTracker() const { 122 return mHalBufferTracker; 123 } 124 125 protected: 126 std::string mName; 127 const IOperationResolver* mOperationResolver; 128 const std::shared_ptr<HalBufferTracker> mHalBufferTracker; 129 }; 130 131 class SamplePreparedModel : public V1_3::IPreparedModel { 132 public: SamplePreparedModel(const V1_3::Model & model,const SampleDriver * driver,V1_1::ExecutionPreference preference,uid_t userId,V1_3::Priority priority)133 SamplePreparedModel(const V1_3::Model& model, const SampleDriver* driver, 134 V1_1::ExecutionPreference preference, uid_t userId, V1_3::Priority priority) 135 : mModel(model), 136 mDriver(driver), 137 kPreference(preference), 138 kUserId(userId), 139 kPriority(priority) { 140 (void)kUserId; 141 (void)kPriority; 142 } 143 bool initialize(); 144 hardware::Return<V1_0::ErrorStatus> execute( 145 const V1_0::Request& request, const sp<V1_0::IExecutionCallback>& callback) override; 146 hardware::Return<V1_0::ErrorStatus> execute_1_2( 147 const V1_0::Request& request, V1_2::MeasureTiming measure, 148 const sp<V1_2::IExecutionCallback>& callback) override; 149 hardware::Return<V1_3::ErrorStatus> execute_1_3( 150 const V1_3::Request& request, V1_2::MeasureTiming measure, 151 const V1_3::OptionalTimePoint& deadline, 152 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, 153 const sp<V1_3::IExecutionCallback>& callback) override; 154 hardware::Return<void> executeSynchronously(const V1_0::Request& request, 155 V1_2::MeasureTiming measure, 156 executeSynchronously_cb cb) override; 157 hardware::Return<void> executeSynchronously_1_3( 158 const V1_3::Request& request, V1_2::MeasureTiming measure, 159 const V1_3::OptionalTimePoint& deadline, 160 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, 161 executeSynchronously_1_3_cb cb) override; 162 hardware::Return<void> configureExecutionBurst( 163 const sp<V1_2::IBurstCallback>& callback, 164 const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel, 165 const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel, 166 configureExecutionBurst_cb cb) override; 167 hardware::Return<void> executeFenced(const V1_3::Request& request, 168 const hardware::hidl_vec<hardware::hidl_handle>& wait_for, 169 V1_2::MeasureTiming measure, 170 const V1_3::OptionalTimePoint& deadline, 171 const V1_3::OptionalTimeoutDuration& loopTimeoutDuration, 172 const V1_3::OptionalTimeoutDuration& duration, 173 executeFenced_cb callback) override; getModel()174 const V1_3::Model* getModel() const { return &mModel; } 175 176 protected: 177 V1_3::Model mModel; 178 const SampleDriver* mDriver; 179 std::vector<RunTimePoolInfo> mPoolInfos; 180 const V1_1::ExecutionPreference kPreference; 181 const uid_t kUserId; 182 const V1_3::Priority kPriority; 183 }; 184 185 class SampleFencedExecutionCallback : public V1_3::IFencedExecutionCallback { 186 public: SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch,V1_2::Timing timingAfterFence,V1_3::ErrorStatus error)187 SampleFencedExecutionCallback(V1_2::Timing timingSinceLaunch, V1_2::Timing timingAfterFence, 188 V1_3::ErrorStatus error) 189 : kTimingSinceLaunch(timingSinceLaunch), 190 kTimingAfterFence(timingAfterFence), 191 kErrorStatus(error) {} getExecutionInfo(getExecutionInfo_cb callback)192 hardware::Return<void> getExecutionInfo(getExecutionInfo_cb callback) override { 193 callback(kErrorStatus, kTimingSinceLaunch, kTimingAfterFence); 194 return hardware::Void(); 195 } 196 197 private: 198 const V1_2::Timing kTimingSinceLaunch; 199 const V1_2::Timing kTimingAfterFence; 200 const V1_3::ErrorStatus kErrorStatus; 201 }; 202 203 } // namespace sample_driver 204 } // namespace nn 205 } // namespace android 206 207 #endif // ANDROID_FRAMEWORKS_ML_NN_DRIVER_SAMPLE_SAMPLE_DRIVER_H 208