/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <HalInterfaces.h>
#include <SampleDriver.h>
#include <android-base/scopeguard.h>
#include <gtest/gtest.h>

#include <cstdlib>
#include <filesystem>
#include <numeric>
#include <string>
#include <string_view>
#include <tuple>
#include <vector>

#include "HalUtils.h"
#include "Manager.h"
#include "TestNeuralNetworksWrapper.h"

using namespace android::nn;
namespace hardware = android::hardware;
using WrapperResult = test_wrapper::Result;
using Type = test_wrapper::Type;
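// Timing returned by the failing execution stubs below; per the HAL, UINT64_MAX means the
// duration is not available.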
const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
template <typename T>
using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
using android::sp;

namespace android::hardware::neuralnetworks::V1_0 {

::std::ostream& operator<<(::std::ostream& os, V1_3::ErrorStatus errorStatus) {
    return os << toString(errorStatus);
}

}  // namespace android::hardware::neuralnetworks::V1_0

namespace {

enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING };

// Print HasCalledPrepareModel enum for better GTEST failure messages
std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
    switch (hasCalledPrepareModel) {
        case HasCalledPrepareModel::NO:
            return os << "NO";
        case HasCalledPrepareModel::WITHOUT_CACHING:
            return os << "WITHOUT_CACHING";
        case HasCalledPrepareModel::WITH_CACHING:
            return os << "WITH_CACHING";
    }
    CHECK(false) << "HasCalledPrepareModel print called with invalid code "
                 << static_cast<int>(hasCalledPrepareModel);
    return os;
}

// Whether the driver is expected to be registered because it can pass initialization.
bool canDeviceBeRegistered(V1_3::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
    constexpr uint32_t maxNumCacheFiles =
            static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES);
    return error == V1_3::ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
           numDataCache <= maxNumCacheFiles;
}

// Whether the driver supports caching based on the returns from getNumberOfCacheFilesNeeded.
bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
    return numModelCache != 0 || numDataCache != 0;
}

// This is an IDevice for testing purposes which overrides several methods from sample driver:
// - supports all the operations and is faster than cpu fallback.
// - overrides getNumberOfCacheFilesNeeded to report according to given parameters.
// - overrides prepareModelFromCache_1_3 to return error status according to
//   mErrorStatusPrepareFromCache.
// - produces CachingPreparedModel on prepareModel and prepareModelFromCache_1_3.
//
// The cache entry is written by prepareModel_1_3 and is checked later by
// CachingDriver::prepareModelFromCache_1_3.
//
// The CachingDriver has 2 flags mHasCalledPrepareModelFromCache and mHasCalledPrepareModel
// to check if the correct methods are invoked by the runtime.
class CachingDriver : public sample_driver::SampleDriver {
   private:
    static constexpr size_t kCacheSize = 256;

    class CachingPreparedModel : public V1_3::IPreparedModel {
       public:
        CachingPreparedModel() = default;

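        // These tests never execute the prepared model, so every execute* entry point simply
        // reports DEVICE_UNAVAILABLE.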
        hardware::Return<V1_0::ErrorStatus> execute(const V1_0::Request&,
                                                    const sp<V1_0::IExecutionCallback>&) override {
            return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        hardware::Return<V1_0::ErrorStatus> execute_1_2(
                const V1_0::Request&, V1_2::MeasureTiming,
                const sp<V1_2::IExecutionCallback>&) override {
            return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        hardware::Return<V1_3::ErrorStatus> execute_1_3(
                const V1_3::Request&, V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
                const V1_3::OptionalTimeoutDuration&,
                const sp<V1_3::IExecutionCallback>&) override {
            return V1_3::ErrorStatus::DEVICE_UNAVAILABLE;
        }
        hardware::Return<void> executeSynchronously(const V1_0::Request&, V1_2::MeasureTiming,
                                                    executeSynchronously_cb cb) override {
            cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
            return hardware::Void();
        }
        hardware::Return<void> executeSynchronously_1_3(const V1_3::Request&, V1_2::MeasureTiming,
                                                        const V1_3::OptionalTimePoint&,
                                                        const V1_3::OptionalTimeoutDuration&,
                                                        executeSynchronously_1_3_cb cb) override {
            cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
            return hardware::Void();
        }
        hardware::Return<void> configureExecutionBurst(
                const sp<V1_2::IBurstCallback>&, const MQDescriptorSync<V1_2::FmqRequestDatum>&,
                const MQDescriptorSync<V1_2::FmqResultDatum>&,
                configureExecutionBurst_cb cb) override {
            cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
            return hardware::Void();
        }
        hardware::Return<void> executeFenced(const V1_3::Request&,
                                             const hardware::hidl_vec<hardware::hidl_handle>&,
                                             V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
                                             const V1_3::OptionalTimeoutDuration&,
                                             const V1_3::OptionalTimeoutDuration&,
                                             executeFenced_cb cb) override {
            cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, hardware::hidl_handle(nullptr), nullptr);
            return hardware::Void();
        }
    };

   public:
    CachingDriver(std::string_view name, V1_3::ErrorStatus errorStatusGetNumCacheFiles,
                  uint32_t numModelCache, uint32_t numDataCache,
                  V1_3::ErrorStatus errorStatusPrepareFromCache)
        : SampleDriver(name.data()),
          mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
          mNumModelCache(numModelCache),
          mNumDataCache(numDataCache),
          mErrorStatusPrepareFromCache(errorStatusPrepareFromCache) {
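        // Fill the fake model/data cache contents with distinct, deterministic byte patterns so
        // that prepareModelFromCache_1_3 can verify it reads back exactly what prepareModel_1_3
        // wrote.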
        mModelCacheData.resize(kCacheSize);
        std::iota(mModelCacheData.begin(), mModelCacheData.end(), 0);
        mDataCacheData.resize(kCacheSize);
        std::iota(mDataCacheData.begin(), mDataCacheData.end(), 1);
    }
    ~CachingDriver() override {}

    // Reports faster than cpu.
    hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
        android::nn::initVLogMask();
        const V1_0::PerformanceInfo kPerf = {.execTime = 0.1, .powerUsage = 0.1};
        V1_3::Capabilities capabilities = {
                .relaxedFloat32toFloat16PerformanceScalar = kPerf,
                .relaxedFloat32toFloat16PerformanceTensor = kPerf,
                .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
                .ifPerformance = kPerf,
                .whilePerformance = kPerf};
        cb(V1_3::ErrorStatus::NONE, capabilities);
        return hardware::Void();
    }

    // Reports supporting all operations.
    hardware::Return<void> getSupportedOperations_1_3(const V1_3::Model& model,
                                                      getSupportedOperations_1_3_cb cb) override {
        std::vector<bool> supported(model.main.operations.size(), true);
        cb(V1_3::ErrorStatus::NONE, supported);
        return hardware::Void();
    }

    // Reports according to mErrorStatusGetNumCacheFiles, mNumModelCache, and mNumDataCache.
    hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override {
        cb(convertToV1_0(mErrorStatusGetNumCacheFiles), mNumModelCache, mNumDataCache);
        return hardware::Void();
    }

    // Generates CachingPreparedModel.
    // Writes the cache entry per mCacheXData and sets mHasCalledPrepareModel.
    hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
            const V1_3::Model&, V1_1::ExecutionPreference, V1_3::Priority,
            const V1_3::OptionalTimePoint&,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
            const sp<V1_3::IPreparedModelCallback>& cb) override {
        checkNumberOfCacheHandles(modelCacheHandle.size(), dataCacheHandle.size());
        if (modelCacheHandle.size() != 0 || dataCacheHandle.size() != 0) {
            writeToCache(modelCacheHandle, mModelCacheData);
            writeToCache(dataCacheHandle, mDataCacheData);
            mHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
        } else {
            mHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
        }
        cb->notify_1_3(V1_3::ErrorStatus::NONE, new CachingPreparedModel());
        return V1_3::ErrorStatus::NONE;
    }

    // Checks if the cache entry is correct, notifies error status according to
    // mErrorStatusPrepareFromCache, sets mHasCalledPrepareModelFromCache.
    hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
            const V1_3::OptionalTimePoint&,
            const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
            const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
            const sp<V1_3::IPreparedModelCallback>& callback) override {
        readFromCache(modelCacheHandle, mModelCacheData);
        readFromCache(dataCacheHandle, mDataCacheData);
        mHasCalledPrepareModelFromCache = true;
        if (mErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            callback->notify_1_3(mErrorStatusPrepareFromCache, new CachingPreparedModel());
        } else {
            callback->notify_1_3(mErrorStatusPrepareFromCache, nullptr);
        }
        return V1_3::ErrorStatus::NONE;
    }

    bool hasCalledPrepareModelFromCache() const { return mHasCalledPrepareModelFromCache; }
    HasCalledPrepareModel hasCalledPrepareModel() const { return mHasCalledPrepareModel; }

   private:
    // Checks the number of cache files passed to the driver from runtime.
    void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
        if (isCachingSupported(mNumModelCache, mNumDataCache)) {
            if (modelCache != 0 || dataCache != 0) {
                ASSERT_EQ(modelCache, mNumModelCache);
                ASSERT_EQ(dataCache, mNumDataCache);
            }
        } else {
            ASSERT_EQ(modelCache, 0ul);
            ASSERT_EQ(dataCache, 0ul);
        }
    }

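    // Writes kCacheSize bytes of the given cache data to every file descriptor provided by the
    // runtime.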
    void writeToCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
                      const std::vector<uint8_t>& cache) {
        for (uint32_t i = 0; i < handles.size(); ++i) {
            ASSERT_EQ(handles[i]->numFds, 1);
            EXPECT_EQ(write(handles[i]->data[0], cache.data(), kCacheSize),
                      static_cast<ssize_t>(kCacheSize));
        }
    }

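    // Reads each cache file back and checks that its contents match the data originally written
    // by writeToCache.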
    void readFromCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
                       const std::vector<uint8_t>& expected) {
        for (uint32_t i = 0; i < handles.size(); ++i) {
            ASSERT_EQ(handles[i]->numFds, 1);
            std::vector<uint8_t> actual(kCacheSize);
            EXPECT_EQ(read(handles[i]->data[0], actual.data(), kCacheSize),
                      static_cast<ssize_t>(kCacheSize));
            EXPECT_EQ(actual, expected);
        }
    }

    std::vector<uint8_t> mModelCacheData;
    std::vector<uint8_t> mDataCacheData;

    const V1_3::ErrorStatus mErrorStatusGetNumCacheFiles;
    const uint32_t mNumModelCache;
    const uint32_t mNumDataCache;
    const V1_3::ErrorStatus mErrorStatusPrepareFromCache;

    bool mHasCalledPrepareModelFromCache = false;
    HasCalledPrepareModel mHasCalledPrepareModel = HasCalledPrepareModel::NO;
};

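// Builds a simple model computing a broadcast ADD of a {2, 2} matrix and a {2} vector with no
// fused activation.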
void CreateBroadcastAddModel(test_wrapper::Model* model) {
    test_wrapper::OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
    test_wrapper::OperandType vectorType(Type::TENSOR_FLOAT32, {2});
    test_wrapper::OperandType scalarType(Type::INT32, {});
    int32_t activation(ANEURALNETWORKS_FUSED_NONE);
    auto a = model->addOperand(&matrixType);
    auto b = model->addOperand(&vectorType);
    auto c = model->addOperand(&matrixType);
    auto d = model->addOperand(&scalarType);
    model->setOperandValue(d, &activation, sizeof(activation));
    model->addOperation(ANEURALNETWORKS_ADD, {a, b, d}, {c});
    model->identifyInputsAndOutputs({a, b}, {c});
    ASSERT_TRUE(model->isValid());
    ASSERT_EQ(model->finish(), WrapperResult::NO_ERROR);
}

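// Searches the registered devices for one whose name matches deviceName. *outputDevice is left
// unmodified if no match is found, and at most one match is expected.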
void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
    uint32_t numDevices = 0;
    ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
    EXPECT_GE(numDevices, (uint32_t)1);

    int numMatchingDevices = 0;
    for (uint32_t i = 0; i < numDevices; i++) {
        ANeuralNetworksDevice* device = nullptr;
        ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);

        const char* buffer = nullptr;
        ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
        if (deviceName == buffer) {
            *outputDevice = device;
            numMatchingDevices++;
        }
    }

    EXPECT_LE(numMatchingDevices, 1);
}

// Test device registration with a driver parameterized with
// - ErrorStatus returned by getNumberOfCacheFilesNeeded
// - Number of model cache files returned by getNumberOfCacheFilesNeeded
// - Number of data cache files returned by getNumberOfCacheFilesNeeded
using DeviceRegistrationTestParam = std::tuple<V1_3::ErrorStatus, uint32_t, uint32_t>;

class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
   protected:
    static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
    const V1_3::ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
    const uint32_t kNumModelCache = std::get<1>(GetParam());
    const uint32_t kNumDataCache = std::get<2>(GetParam());
    const sp<CachingDriver> kDriver =
            new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
                              kNumDataCache, V1_3::ErrorStatus::NONE);
};

TEST_P(DeviceRegistrationTest, CachingFailure) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }

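    // Register the caching driver; the scope guard restores the runtime's device list when the
    // test exits.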
    DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), kDriver));
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

    // Get the device registered from the test driver, if any.
    const ANeuralNetworksDevice* device = nullptr;
    getDeviceWithName(kDeviceName, &device);

    // Check whether device registration matches expectations.
    const bool isDeviceRegistered = (device != nullptr);
    const bool expectDeviceToBeRegistered =
            canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);
    ASSERT_EQ(isDeviceRegistered, expectDeviceToBeRegistered);
}

// Test model compilation with a driver parameterized with
// - Number of model cache files returned by getNumberOfCacheFilesNeeded
// - Number of data cache files returned by getNumberOfCacheFilesNeeded
// - ErrorStatus returned by prepareModelFromCache_1_3
using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, V1_3::ErrorStatus>;

class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
   protected:
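    // Creates a fresh temporary cache directory and builds the test model for each test case.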
    virtual void SetUp() override {
        char cacheDirTemp[] =
                "/data/local/tmp/AVeryLongDirectoryNameForTestCompilationCachingXXXXXX";
        char* cacheDir = mkdtemp(cacheDirTemp);
        ASSERT_NE(cacheDir, nullptr);
        mCacheDir = cacheDir;
        CreateBroadcastAddModel(&mModel);
    }

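    // Removes the temporary cache directory, but keeps it around when the test has failed to
    // ease debugging.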
    virtual void TearDown() override {
        if (!::testing::Test::HasFailure()) {
            std::filesystem::remove_all(mCacheDir);
        }
    }

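    // Registers the given driver, compiles mModel on it (optionally requesting compilation
    // caching with kToken and mCacheDir), and then restores the runtime's device list.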
    void compileModel(const sp<CachingDriver>& driver, bool withToken) {
        DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), driver));
        const auto cleanup = android::base::make_scope_guard(
                [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

        // Get a handle to the single driver device matching kDeviceName.
        const ANeuralNetworksDevice* device = nullptr;
        getDeviceWithName(kDeviceName, &device);
        ASSERT_NE(device, nullptr);

        // Compile the model with the device.
        ANeuralNetworksCompilation* compilation = nullptr;
        ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
                                                              &compilation),
                  ANEURALNETWORKS_NO_ERROR);
        if (withToken) {
            ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
                                                            kToken.data()),
                      ANEURALNETWORKS_NO_ERROR);
        }
        ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);

        // Free the compilation object.
        ANeuralNetworksCompilation_free(compilation);
    }

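    // Compiles the model once with a fully functional caching driver so that valid cache files
    // exist in mCacheDir before the driver under test is exercised.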
    void createCache() {
        sp<CachingDriver> driver =
                new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache,
                                  kNumDataCache, V1_3::ErrorStatus::NONE);
        compileModel(driver, /*withToken=*/true);
    }

    static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
    const uint32_t kNumModelCache = std::get<0>(GetParam());
    const uint32_t kNumDataCache = std::get<1>(GetParam());
    const V1_3::ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
    const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
    test_wrapper::Model mModel;
    std::string mCacheDir;
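    // All-zero cache token passed to ANeuralNetworksCompilation_setCaching; its exact value does
    // not matter for these tests.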
    const HalCacheToken kToken{};
};

TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // When the cache files do not exist, the runtime should never call prepareModelFromCache_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());

    // The runtime should call prepareModel_1_3, and it should request caching iff caching is
    // supported.
    EXPECT_EQ(driver->hasCalledPrepareModel(), kIsCachingSupported
                                                       ? HasCalledPrepareModel::WITH_CACHING
                                                       : HasCalledPrepareModel::WITHOUT_CACHING);
}

TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    createCache();
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // When cache files exist, the runtime should call prepareModelFromCache_1_3 iff caching is
    // supported.
    EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported);

    HasCalledPrepareModel expectHasCalledPrepareModel;
    if (kIsCachingSupported) {
        if (kErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            // The runtime should not call prepareModel_1_3 when caching is supported and
            // prepareModelFromCache_1_3 succeeds.
            expectHasCalledPrepareModel = HasCalledPrepareModel::NO;
        } else {
            // The runtime should fall back to prepareModel_1_3 and request caching when caching
            // is supported but prepareModelFromCache_1_3 fails.
            expectHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
        }
    } else {
        // The runtime should call prepareModel_1_3 without caching when caching is not supported.
        expectHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
    }
    EXPECT_EQ(driver->hasCalledPrepareModel(), expectHasCalledPrepareModel);
}

TEST_P(CompilationCachingTest, TokenNotProvided) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    sp<CachingDriver> driver =
            new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache, kNumDataCache,
                              kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/false);

    // When no NDK token is provided by the client, the runtime should never call
    // prepareModelFromCache_1_3 or request caching with prepareModel_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
    EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}

static const auto kErrorStatusGetNumCacheFilesChoices =
        testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::DEVICE_UNAVAILABLE);
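// MAX_NUMBER_OF_CACHE_FILES + 1 covers a driver requesting more cache files than the runtime
// allows; such a driver must be rejected at registration time (see canDeviceBeRegistered).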
static const auto kNumCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES),
                        static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
static const auto kNumValidCacheChoices =
        testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES));
static const auto kErrorStatusPrepareFromCacheChoices =
        testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE,
                        V1_3::ErrorStatus::DEVICE_UNAVAILABLE, V1_3::ErrorStatus::INVALID_ARGUMENT);

INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, DeviceRegistrationTest,
                         testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
                                          kNumCacheChoices));

INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
                         testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
                                          kErrorStatusPrepareFromCacheChoices));

}  // namespace