1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <HalInterfaces.h>
#include <SampleDriver.h>
#include <android-base/scopeguard.h>
#include <gtest/gtest.h>
#include <unistd.h>

#include <cstdlib>
#include <filesystem>
#include <numeric>
#include <string>
#include <string_view>
#include <system_error>
#include <tuple>
#include <vector>

#include "HalUtils.h"
#include "Manager.h"
#include "TestNeuralNetworksWrapper.h"
33 
34 using namespace android::nn;
35 namespace hardware = android::hardware;
36 using WrapperResult = test_wrapper::Result;
37 using Type = test_wrapper::Type;
38 const V1_2::Timing kBadTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
39 template <typename T>
40 using MQDescriptorSync = ::android::hardware::MQDescriptorSync<T>;
41 using android::sp;
42 
43 namespace android::hardware::neuralnetworks::V1_0 {
44 
operator <<(::std::ostream & os,V1_3::ErrorStatus errorStatus)45 ::std::ostream& operator<<(::std::ostream& os, V1_3::ErrorStatus errorStatus) {
46     return os << toString(errorStatus);
47 }
48 
49 }  // namespace android::hardware::neuralnetworks::V1_0
50 
51 namespace {
52 
53 enum class HasCalledPrepareModel { NO, WITHOUT_CACHING, WITH_CACHING };
54 
55 // Print HasCalledPrepareModel enum for better GTEST failure messages
operator <<(std::ostream & os,HasCalledPrepareModel hasCalledPrepareModel)56 std::ostream& operator<<(std::ostream& os, HasCalledPrepareModel hasCalledPrepareModel) {
57     switch (hasCalledPrepareModel) {
58         case HasCalledPrepareModel::NO:
59             return os << "NO";
60         case HasCalledPrepareModel::WITHOUT_CACHING:
61             return os << "WITHOUT_CACHING";
62         case HasCalledPrepareModel::WITH_CACHING:
63             return os << "WITH_CACHING";
64     }
65     CHECK(false) << "HasCalledPrepareModel print called with invalid code "
66                  << static_cast<int>(hasCalledPrepareModel);
67     return os;
68 }
69 
70 // Whether the driver is expected to be registered because it can pass initialization.
canDeviceBeRegistered(V1_3::ErrorStatus error,uint32_t numModelCache,uint32_t numDataCache)71 bool canDeviceBeRegistered(V1_3::ErrorStatus error, uint32_t numModelCache, uint32_t numDataCache) {
72     constexpr uint32_t maxNumCacheFiles =
73             static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES);
74     return error == V1_3::ErrorStatus::NONE && numModelCache <= maxNumCacheFiles &&
75            numDataCache <= maxNumCacheFiles;
76 }
77 
// Returns true when getNumberOfCacheFilesNeeded reported at least one model or data cache file,
// i.e. the driver supports compilation caching.
bool isCachingSupported(uint32_t numModelCache, uint32_t numDataCache) {
    const bool noCacheFiles = numModelCache == 0 && numDataCache == 0;
    return !noCacheFiles;
}
82 
83 // This is an IDevice for testing purposes which overrides several methods from sample driver:
84 // - supports all the operations and is faster than cpu fallback.
85 // - overrides getNumberOfCacheFilesNeeded to report according to given parameters.
86 // - overrides prepareModelFromCache_1_3 to return error status according to
87 //   mErrorStatusPrepareFromCache.
88 // - produces CachingPreparedModel on prepareModel and prepareModelFromCache_1_3.
89 //
90 // The cache entry is written by prepareModel_1_3 and is checked later by
91 // CachingDriver::prepareModelFromCache_1_3.
92 //
93 // The CachingDriver has 2 flags mHasCalledPrepareModelFromCache and mHasCalledPrepareModel
94 // to check if the correct methods are invoked by the runtime.
95 class CachingDriver : public sample_driver::SampleDriver {
96    private:
97     static constexpr size_t kCacheSize = 256;
98 
99     class CachingPreparedModel : public V1_3::IPreparedModel {
100        public:
101         CachingPreparedModel() = default;
102 
execute(const V1_0::Request &,const sp<V1_0::IExecutionCallback> &)103         hardware::Return<V1_0::ErrorStatus> execute(const V1_0::Request&,
104                                                     const sp<V1_0::IExecutionCallback>&) override {
105             return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
106         }
execute_1_2(const V1_0::Request &,V1_2::MeasureTiming,const sp<V1_2::IExecutionCallback> &)107         hardware::Return<V1_0::ErrorStatus> execute_1_2(
108                 const V1_0::Request&, V1_2::MeasureTiming,
109                 const sp<V1_2::IExecutionCallback>&) override {
110             return V1_0::ErrorStatus::DEVICE_UNAVAILABLE;
111         }
execute_1_3(const V1_3::Request &,V1_2::MeasureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,const sp<V1_3::IExecutionCallback> &)112         hardware::Return<V1_3::ErrorStatus> execute_1_3(
113                 const V1_3::Request&, V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
114                 const V1_3::OptionalTimeoutDuration&,
115                 const sp<V1_3::IExecutionCallback>&) override {
116             return V1_3::ErrorStatus::DEVICE_UNAVAILABLE;
117         }
executeSynchronously(const V1_0::Request &,V1_2::MeasureTiming,executeSynchronously_cb cb)118         hardware::Return<void> executeSynchronously(const V1_0::Request&, V1_2::MeasureTiming,
119                                                     executeSynchronously_cb cb) override {
120             cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
121             return hardware::Void();
122         }
executeSynchronously_1_3(const V1_3::Request &,V1_2::MeasureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,executeSynchronously_1_3_cb cb)123         hardware::Return<void> executeSynchronously_1_3(const V1_3::Request&, V1_2::MeasureTiming,
124                                                         const V1_3::OptionalTimePoint&,
125                                                         const V1_3::OptionalTimeoutDuration&,
126                                                         executeSynchronously_1_3_cb cb) override {
127             cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, {}, kBadTiming);
128             return hardware::Void();
129         }
configureExecutionBurst(const sp<V1_2::IBurstCallback> &,const MQDescriptorSync<V1_2::FmqRequestDatum> &,const MQDescriptorSync<V1_2::FmqResultDatum> &,configureExecutionBurst_cb cb)130         hardware::Return<void> configureExecutionBurst(
131                 const sp<V1_2::IBurstCallback>&, const MQDescriptorSync<V1_2::FmqRequestDatum>&,
132                 const MQDescriptorSync<V1_2::FmqResultDatum>&,
133                 configureExecutionBurst_cb cb) override {
134             cb(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, nullptr);
135             return hardware::Void();
136         }
executeFenced(const V1_3::Request &,const hardware::hidl_vec<hardware::hidl_handle> &,V1_2::MeasureTiming,const V1_3::OptionalTimePoint &,const V1_3::OptionalTimeoutDuration &,const V1_3::OptionalTimeoutDuration &,executeFenced_cb cb)137         hardware::Return<void> executeFenced(const V1_3::Request&,
138                                              const hardware::hidl_vec<hardware::hidl_handle>&,
139                                              V1_2::MeasureTiming, const V1_3::OptionalTimePoint&,
140                                              const V1_3::OptionalTimeoutDuration&,
141                                              const V1_3::OptionalTimeoutDuration&,
142                                              executeFenced_cb cb) {
143             cb(V1_3::ErrorStatus::DEVICE_UNAVAILABLE, hardware::hidl_handle(nullptr), nullptr);
144             return hardware::Void();
145         }
146     };
147 
148    public:
CachingDriver(std::string_view name,V1_3::ErrorStatus errorStatusGetNumCacheFiles,uint32_t numModelCache,uint32_t numDataCache,V1_3::ErrorStatus errorStatusPrepareFromCache)149     CachingDriver(std::string_view name, V1_3::ErrorStatus errorStatusGetNumCacheFiles,
150                   uint32_t numModelCache, uint32_t numDataCache,
151                   V1_3::ErrorStatus errorStatusPrepareFromCache)
152         : SampleDriver(name.data()),
153           mErrorStatusGetNumCacheFiles(errorStatusGetNumCacheFiles),
154           mNumModelCache(numModelCache),
155           mNumDataCache(numDataCache),
156           mErrorStatusPrepareFromCache(errorStatusPrepareFromCache) {
157         mModelCacheData.resize(kCacheSize);
158         std::iota(mModelCacheData.begin(), mModelCacheData.end(), 0);
159         mDataCacheData.resize(kCacheSize);
160         std::iota(mDataCacheData.begin(), mDataCacheData.end(), 1);
161     }
~CachingDriver()162     ~CachingDriver() override {}
163 
164     // Reports faster than cpu.
getCapabilities_1_3(getCapabilities_1_3_cb cb)165     hardware::Return<void> getCapabilities_1_3(getCapabilities_1_3_cb cb) override {
166         android::nn::initVLogMask();
167         const V1_0::PerformanceInfo kPerf = {.execTime = 0.1, .powerUsage = 0.1};
168         V1_3::Capabilities capabilities = {
169                 .relaxedFloat32toFloat16PerformanceScalar = kPerf,
170                 .relaxedFloat32toFloat16PerformanceTensor = kPerf,
171                 .operandPerformance = nonExtensionOperandPerformance<HalVersion::V1_3>(kPerf),
172                 .ifPerformance = kPerf,
173                 .whilePerformance = kPerf};
174         cb(V1_3::ErrorStatus::NONE, capabilities);
175         return hardware::Void();
176     }
177 
178     // Reports supporting all operations.
getSupportedOperations_1_3(const V1_3::Model & model,getSupportedOperations_1_3_cb cb)179     hardware::Return<void> getSupportedOperations_1_3(const V1_3::Model& model,
180                                                       getSupportedOperations_1_3_cb cb) override {
181         std::vector<bool> supported(model.main.operations.size(), true);
182         cb(V1_3::ErrorStatus::NONE, supported);
183         return hardware::Void();
184     }
185 
186     // Reports according to mGetNumCacheFiles.
getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb)187     hardware::Return<void> getNumberOfCacheFilesNeeded(getNumberOfCacheFilesNeeded_cb cb) override {
188         cb(convertToV1_0(mErrorStatusGetNumCacheFiles), mNumModelCache, mNumDataCache);
189         return hardware::Void();
190     }
191 
192     // Generates CachingPreparedModel.
193     // Writes the cache entry per mCacheXData and sets mHasCalledPrepareModel.
prepareModel_1_3(const V1_3::Model &,V1_1::ExecutionPreference,V1_3::Priority,const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> & modelCacheHandle,const hardware::hidl_vec<hardware::hidl_handle> & dataCacheHandle,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & cb)194     hardware::Return<V1_3::ErrorStatus> prepareModel_1_3(
195             const V1_3::Model&, V1_1::ExecutionPreference, V1_3::Priority,
196             const V1_3::OptionalTimePoint&,
197             const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
198             const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
199             const sp<V1_3::IPreparedModelCallback>& cb) override {
200         checkNumberOfCacheHandles(modelCacheHandle.size(), dataCacheHandle.size());
201         if (modelCacheHandle.size() != 0 || dataCacheHandle.size() != 0) {
202             writeToCache(modelCacheHandle, mModelCacheData);
203             writeToCache(dataCacheHandle, mDataCacheData);
204             mHasCalledPrepareModel = HasCalledPrepareModel::WITH_CACHING;
205         } else {
206             mHasCalledPrepareModel = HasCalledPrepareModel::WITHOUT_CACHING;
207         }
208         cb->notify_1_3(V1_3::ErrorStatus::NONE, new CachingPreparedModel());
209         return V1_3::ErrorStatus::NONE;
210     }
211 
212     // Checks if the cache entry is correct, notifies error status according to
213     // mErrorStatusPrepareFromCache, sets mHasCalledPrepareModelFromCache.
prepareModelFromCache_1_3(const V1_3::OptionalTimePoint &,const hardware::hidl_vec<hardware::hidl_handle> & modelCacheHandle,const hardware::hidl_vec<hardware::hidl_handle> & dataCacheHandle,const HalCacheToken &,const sp<V1_3::IPreparedModelCallback> & callback)214     hardware::Return<V1_3::ErrorStatus> prepareModelFromCache_1_3(
215             const V1_3::OptionalTimePoint&,
216             const hardware::hidl_vec<hardware::hidl_handle>& modelCacheHandle,
217             const hardware::hidl_vec<hardware::hidl_handle>& dataCacheHandle, const HalCacheToken&,
218             const sp<V1_3::IPreparedModelCallback>& callback) override {
219         readFromCache(modelCacheHandle, mModelCacheData);
220         readFromCache(dataCacheHandle, mDataCacheData);
221         mHasCalledPrepareModelFromCache = true;
222         if (mErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
223             callback->notify_1_3(mErrorStatusPrepareFromCache, new CachingPreparedModel());
224         } else {
225             callback->notify_1_3(mErrorStatusPrepareFromCache, nullptr);
226         }
227         return V1_3::ErrorStatus::NONE;
228     };
229 
hasCalledPrepareModelFromCache() const230     bool hasCalledPrepareModelFromCache() const { return mHasCalledPrepareModelFromCache; }
hasCalledPrepareModel() const231     HasCalledPrepareModel hasCalledPrepareModel() const { return mHasCalledPrepareModel; }
232 
233    private:
234     // Checks the number of cache files passed to the driver from runtime.
checkNumberOfCacheHandles(size_t modelCache,size_t dataCache)235     void checkNumberOfCacheHandles(size_t modelCache, size_t dataCache) {
236         if (isCachingSupported(mNumModelCache, mNumDataCache)) {
237             if (modelCache != 0 || dataCache != 0) {
238                 ASSERT_EQ(modelCache, mNumModelCache);
239                 ASSERT_EQ(dataCache, mNumDataCache);
240             }
241         } else {
242             ASSERT_EQ(modelCache, 0ul);
243             ASSERT_EQ(dataCache, 0ul);
244         }
245     }
246 
writeToCache(const hardware::hidl_vec<hardware::hidl_handle> & handles,const std::vector<uint8_t> & cache)247     void writeToCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
248                       const std::vector<uint8_t>& cache) {
249         for (uint32_t i = 0; i < handles.size(); ++i) {
250             ASSERT_EQ(handles[i]->numFds, 1);
251             EXPECT_EQ(write(handles[i]->data[0], cache.data(), kCacheSize),
252                       static_cast<ssize_t>(kCacheSize));
253         }
254     }
255 
readFromCache(const hardware::hidl_vec<hardware::hidl_handle> & handles,const std::vector<uint8_t> & expected)256     void readFromCache(const hardware::hidl_vec<hardware::hidl_handle>& handles,
257                        const std::vector<uint8_t>& expected) {
258         for (uint32_t i = 0; i < handles.size(); ++i) {
259             ASSERT_EQ(handles[i]->numFds, 1);
260             std::vector<uint8_t> actual(kCacheSize);
261             EXPECT_EQ(read(handles[i]->data[0], actual.data(), kCacheSize),
262                       static_cast<ssize_t>(kCacheSize));
263             EXPECT_EQ(actual, expected);
264         }
265     }
266 
267     std::vector<uint8_t> mModelCacheData;
268     std::vector<uint8_t> mDataCacheData;
269 
270     const V1_3::ErrorStatus mErrorStatusGetNumCacheFiles;
271     const uint32_t mNumModelCache;
272     const uint32_t mNumDataCache;
273     const V1_3::ErrorStatus mErrorStatusPrepareFromCache;
274 
275     bool mHasCalledPrepareModelFromCache = false;
276     HasCalledPrepareModel mHasCalledPrepareModel = HasCalledPrepareModel::NO;
277 };
278 
CreateBroadcastAddModel(test_wrapper::Model * model)279 void CreateBroadcastAddModel(test_wrapper::Model* model) {
280     test_wrapper::OperandType matrixType(Type::TENSOR_FLOAT32, {2, 2});
281     test_wrapper::OperandType vectorType(Type::TENSOR_FLOAT32, {2});
282     test_wrapper::OperandType scalarType(Type::INT32, {});
283     int32_t activation(ANEURALNETWORKS_FUSED_NONE);
284     auto a = model->addOperand(&matrixType);
285     auto b = model->addOperand(&vectorType);
286     auto c = model->addOperand(&matrixType);
287     auto d = model->addOperand(&scalarType);
288     model->setOperandValue(d, &activation, sizeof(activation));
289     model->addOperation(ANEURALNETWORKS_ADD, {a, b, d}, {c});
290     model->identifyInputsAndOutputs({a, b}, {c});
291     ASSERT_TRUE(model->isValid());
292     ASSERT_EQ(model->finish(), WrapperResult::NO_ERROR);
293 }
294 
getDeviceWithName(std::string_view deviceName,const ANeuralNetworksDevice ** outputDevice)295 void getDeviceWithName(std::string_view deviceName, const ANeuralNetworksDevice** outputDevice) {
296     uint32_t numDevices = 0;
297     ASSERT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
298     EXPECT_GE(numDevices, (uint32_t)1);
299 
300     int numMatchingDevices = 0;
301     for (uint32_t i = 0; i < numDevices; i++) {
302         ANeuralNetworksDevice* device = nullptr;
303         ASSERT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
304 
305         const char* buffer = nullptr;
306         ASSERT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
307         if (deviceName == buffer) {
308             *outputDevice = device;
309             numMatchingDevices++;
310         }
311     }
312 
313     EXPECT_LE(numMatchingDevices, 1);
314 }
315 
316 // Test device registration with a driver parameterized with
317 // - ErrorStatus returning from getNumberOfCacheFilesNeeded
318 // - Number of model cache files returning from getNumberOfCacheFilesNeeded
319 // - Number of data cache files returning from getNumberOfCacheFilesNeeded
320 using DeviceRegistrationTestParam = std::tuple<V1_3::ErrorStatus, uint32_t, uint32_t>;
321 
322 class DeviceRegistrationTest : public ::testing::TestWithParam<DeviceRegistrationTestParam> {
323    protected:
324     static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
325     const V1_3::ErrorStatus kErrorStatusGetNumCacheFiles = std::get<0>(GetParam());
326     const uint32_t kNumModelCache = std::get<1>(GetParam());
327     const uint32_t kNumDataCache = std::get<2>(GetParam());
328     const sp<CachingDriver> kDriver =
329             new CachingDriver(kDeviceName, kErrorStatusGetNumCacheFiles, kNumModelCache,
330                               kNumDataCache, V1_3::ErrorStatus::NONE);
331 };
332 
TEST_P(DeviceRegistrationTest, CachingFailure) {
    // Skipped when the runtime is restricted to CPU-only execution.
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }

    DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), kDriver));
    const auto cleanup = android::base::make_scope_guard(
            [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });

    // Work out the expected outcome from the driver's getNumberOfCacheFilesNeeded results.
    const bool expectDeviceToBeRegistered =
            canDeviceBeRegistered(kErrorStatusGetNumCacheFiles, kNumModelCache, kNumDataCache);

    // Look the device up by name; it stays nullptr when no device with that name is found.
    const ANeuralNetworksDevice* device = nullptr;
    getDeviceWithName(kDeviceName, &device);

    // Check that device registration matches expectations.
    const bool isDeviceRegistered = device != nullptr;
    ASSERT_EQ(isDeviceRegistered, expectDeviceToBeRegistered);
}
352 
353 // Test model compilation with a driver parameterized with
354 // - Number of model cache files returning from getNumberOfCacheFilesNeeded
355 // - Number of data cache files returning from getNumberOfCacheFilesNeeded
356 // - ErrorStatus returning from prepareModelFromCache_1_3
357 using CompilationCachingTestParam = std::tuple<uint32_t, uint32_t, V1_3::ErrorStatus>;
358 
359 class CompilationCachingTest : public ::testing::TestWithParam<CompilationCachingTestParam> {
360    protected:
SetUp()361     virtual void SetUp() override {
362         char cacheDirTemp[] =
363                 "/data/local/tmp/AVeryLongDirectoryNameForTestCompilationCachingXXXXXX";
364         char* cacheDir = mkdtemp(cacheDirTemp);
365         ASSERT_NE(cacheDir, nullptr);
366         mCacheDir = cacheDir;
367         CreateBroadcastAddModel(&mModel);
368     }
369 
TearDown()370     virtual void TearDown() override {
371         if (!::testing::Test::HasFailure()) {
372             std::filesystem::remove_all(mCacheDir);
373         }
374     }
375 
compileModel(const sp<CachingDriver> & driver,bool withToken)376     void compileModel(const sp<CachingDriver>& driver, bool withToken) {
377         DeviceManager::get()->forTest_registerDevice(makeSharedDevice(kDeviceName.data(), driver));
378         const auto cleanup = android::base::make_scope_guard(
379                 [] { DeviceManager::get()->forTest_reInitializeDeviceList(); });
380 
381         // Get a handle to the single driver device matching kDeviceName.
382         const ANeuralNetworksDevice* device = nullptr;
383         getDeviceWithName(kDeviceName, &device);
384         ASSERT_NE(device, nullptr);
385 
386         // Compile the model with the device.
387         ANeuralNetworksCompilation* compilation = nullptr;
388         ASSERT_EQ(ANeuralNetworksCompilation_createForDevices(mModel.getHandle(), &device, 1,
389                                                               &compilation),
390                   ANEURALNETWORKS_NO_ERROR);
391         if (withToken) {
392             ASSERT_EQ(ANeuralNetworksCompilation_setCaching(compilation, mCacheDir.c_str(),
393                                                             kToken.data()),
394                       ANEURALNETWORKS_NO_ERROR);
395         }
396         ASSERT_EQ(ANeuralNetworksCompilation_finish(compilation), ANEURALNETWORKS_NO_ERROR);
397 
398         // close memory
399         ANeuralNetworksCompilation_free(compilation);
400     }
401 
createCache()402     void createCache() {
403         sp<CachingDriver> driver =
404                 new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE, kNumModelCache,
405                                   kNumDataCache, V1_3::ErrorStatus::NONE);
406         compileModel(driver, /*withToken=*/true);
407     }
408 
409     static constexpr std::string_view kDeviceName = "deviceTestCompilationCaching";
410     const uint32_t kNumModelCache = std::get<0>(GetParam());
411     const uint32_t kNumDataCache = std::get<1>(GetParam());
412     const V1_3::ErrorStatus kErrorStatusPrepareFromCache = std::get<2>(GetParam());
413     const bool kIsCachingSupported = isCachingSupported(kNumModelCache, kNumDataCache);
414     test_wrapper::Model mModel;
415     std::string mCacheDir;
416     const HalCacheToken kToken{};
417 };
418 
TEST_P(CompilationCachingTest, TokenProvidedAndCacheNotExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    const sp<CachingDriver> driver = new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE,
                                                       kNumModelCache, kNumDataCache,
                                                       kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // No cache files exist yet, so the runtime must not try prepareModelFromCache_1_3.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());

    // prepareModel_1_3 must be called, with caching requested exactly when caching is supported.
    const HasCalledPrepareModel expected = kIsCachingSupported
                                                   ? HasCalledPrepareModel::WITH_CACHING
                                                   : HasCalledPrepareModel::WITHOUT_CACHING;
    EXPECT_EQ(driver->hasCalledPrepareModel(), expected);
}
436 
TEST_P(CompilationCachingTest, TokenProvidedAndCacheExist) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    createCache();
    const sp<CachingDriver> driver = new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE,
                                                       kNumModelCache, kNumDataCache,
                                                       kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/true);

    // With cache files present, prepareModelFromCache_1_3 must be tried exactly when caching is
    // supported.
    EXPECT_EQ(driver->hasCalledPrepareModelFromCache(), kIsCachingSupported);

    const HasCalledPrepareModel expectHasCalledPrepareModel = [this] {
        if (!kIsCachingSupported) {
            // Without caching support, prepareModel_1_3 is called without caching.
            return HasCalledPrepareModel::WITHOUT_CACHING;
        }
        if (kErrorStatusPrepareFromCache == V1_3::ErrorStatus::NONE) {
            // A successful prepareModelFromCache_1_3 means prepareModel_1_3 is skipped.
            return HasCalledPrepareModel::NO;
        }
        // A failed prepareModelFromCache_1_3 falls back to prepareModel_1_3 with caching.
        return HasCalledPrepareModel::WITH_CACHING;
    }();
    EXPECT_EQ(driver->hasCalledPrepareModel(), expectHasCalledPrepareModel);
}
468 
TEST_P(CompilationCachingTest, TokenNotProvided) {
    if (DeviceManager::get()->getUseCpuOnly()) {
        return;
    }
    const sp<CachingDriver> driver = new CachingDriver(kDeviceName, V1_3::ErrorStatus::NONE,
                                                       kNumModelCache, kNumDataCache,
                                                       kErrorStatusPrepareFromCache);
    compileModel(driver, /*withToken=*/false);

    // Without a client-provided NDK caching token, the runtime must neither read from the cache
    // nor ask prepareModel_1_3 to populate it.
    EXPECT_FALSE(driver->hasCalledPrepareModelFromCache());
    EXPECT_EQ(driver->hasCalledPrepareModel(), HasCalledPrepareModel::WITHOUT_CACHING);
}
483 
484 static const auto kErrorStatusGetNumCacheFilesChoices =
485         testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::DEVICE_UNAVAILABLE);
486 static const auto kNumCacheChoices =
487         testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES),
488                         static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES) + 1);
489 static const auto kNumValidCacheChoices =
490         testing::Values(0ul, 1ul, static_cast<uint32_t>(V1_2::Constant::MAX_NUMBER_OF_CACHE_FILES));
491 static const auto kErrorStatusPrepareFromCacheChoices =
492         testing::Values(V1_3::ErrorStatus::NONE, V1_3::ErrorStatus::GENERAL_FAILURE,
493                         V1_3::ErrorStatus::DEVICE_UNAVAILABLE, V1_3::ErrorStatus::INVALID_ARGUMENT);
494 
495 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, DeviceRegistrationTest,
496                          testing::Combine(kErrorStatusGetNumCacheFilesChoices, kNumCacheChoices,
497                                           kNumCacheChoices));
498 
499 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
500                          testing::Combine(kNumValidCacheChoices, kNumValidCacheChoices,
501                                           kErrorStatusPrepareFromCacheChoices));
502 
503 }  // namespace
504