1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include <android-base/logging.h>
20 #include <fcntl.h>
21 #include <ftw.h>
22 #include <gtest/gtest.h>
23 #include <hidlmemory/mapping.h>
24 #include <unistd.h>
25 
26 #include <cstdio>
27 #include <cstdlib>
28 #include <random>
29 #include <thread>
30 
31 #include "1.3/Callbacks.h"
32 #include "1.3/Utils.h"
33 #include "GeneratedTestHarness.h"
34 #include "MemoryUtils.h"
35 #include "TestHarness.h"
36 #include "Utils.h"
37 #include "VtsHalNeuralnetworks.h"
38 
39 // Forward declaration of the mobilenet generated test models in
40 // frameworks/ml/nn/runtime/test/generated/.
41 namespace generated_tests::mobilenet_224_gender_basic_fixed {
42 const test_helper::TestModel& get_test_model();
43 }  // namespace generated_tests::mobilenet_224_gender_basic_fixed
44 
45 namespace generated_tests::mobilenet_quantized {
46 const test_helper::TestModel& get_test_model();
47 }  // namespace generated_tests::mobilenet_quantized
48 
49 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
50 
51 using namespace test_helper;
52 using implementation::PreparedModelCallback;
53 using V1_1::ExecutionPreference;
54 using V1_2::Constant;
55 using V1_2::OperationType;
56 
namespace float32_model {

// Alias for the float32 mobilenet model generator; selected when kOperandType is
// OperandType::TENSOR_FLOAT32 (see CompilationCachingTestBase::createTestModel).
constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;

}  // namespace float32_model

namespace quant8_model {

// Alias for the quantized mobilenet model generator; selected for the quant8 test pass.
constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;

}  // namespace quant8_model
68 
69 namespace {
70 
71 enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
72 
73 // Creates cache handles based on provided file groups.
74 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,const std::vector<AccessMode> & mode,hidl_vec<hidl_handle> * handles)75 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
76                         const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
77     handles->resize(fileGroups.size());
78     for (uint32_t i = 0; i < fileGroups.size(); i++) {
79         std::vector<int> fds;
80         for (const auto& file : fileGroups[i]) {
81             int fd;
82             if (mode[i] == AccessMode::READ_ONLY) {
83                 fd = open(file.c_str(), O_RDONLY);
84             } else if (mode[i] == AccessMode::WRITE_ONLY) {
85                 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
86             } else if (mode[i] == AccessMode::READ_WRITE) {
87                 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
88             } else {
89                 FAIL();
90             }
91             ASSERT_GE(fd, 0);
92             fds.push_back(fd);
93         }
94         native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
95         ASSERT_NE(cacheNativeHandle, nullptr);
96         std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
97         (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
98     }
99 }
100 
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,AccessMode mode,hidl_vec<hidl_handle> * handles)101 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
102                         hidl_vec<hidl_handle>* handles) {
103     createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
104 }
105 
106 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
107 // For simplicity, activation scalar is shared. The second operand is not shared
108 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
109 // data locations in cache.
110 //
111 //                --------- activation --------
112 //                ↓      ↓      ↓             ↓
113 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
114 //                ↑      ↑      ↑             ↑
115 //               [1]    [1]    [1]           [1]
116 //
117 // This function assumes the operation is either ADD or MUL.
template <typename CppType, TestOperandType operandType>
TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
    EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);

    // Model operations and operands.
    // Operand layout: index 0 is the shared activation scalar; operation i uses
    // firstInput = 2*i + 1, secondInput = 2*i + 2, output = 2*i + 3 (the output of
    // operation i doubles as the first input of operation i + 1). The final output
    // lands at index len * 2 + 1, i.e. operands.back().
    std::vector<TestOperation> operations(len);
    std::vector<TestOperand> operands(len * 2 + 2);

    // The activation scalar, value = 0.
    operands[0] = {
            .type = TestOperandType::INT32,
            .dimensions = {},
            .numberOfConsumers = len,
            .scale = 0.0f,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::CONSTANT_COPY,
            .data = TestBuffer::createFromVector<int32_t>({0}),
    };

    // The buffer value of the constant second operand. The logical value is always 1.0f.
    CppType bufferValue;
    // The scale of the first and second operand.
    float scale1, scale2;
    if (operandType == TestOperandType::TENSOR_FLOAT32) {
        // Float operands carry no quantization parameters.
        bufferValue = 1.0f;
        scale1 = 0.0f;
        scale2 = 0.0f;
    } else if (op == TestOperationType::ADD) {
        bufferValue = 1;
        scale1 = 1.0f;
        scale2 = 1.0f;
    } else {
        // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
        // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
        bufferValue = 2;
        scale1 = 1.0f;
        scale2 = 0.5f;
    }

    for (uint32_t i = 0; i < len; i++) {
        const uint32_t firstInputIndex = i * 2 + 1;
        const uint32_t secondInputIndex = firstInputIndex + 1;
        const uint32_t outputIndex = secondInputIndex + 1;

        // The first operation input. Only the head of the chain (i == 0) is a model input
        // with constant data; the rest are temporaries fed by the previous operation.
        operands[firstInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale1,
                .zeroPoint = 0,
                .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
                                    : TestOperandLifeTime::TEMPORARY_VARIABLE),
                .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
        };

        // The second operation input, value = 1. Deliberately not shared between operations
        // so the driver accumulates a non-trivial amount of constant data (see file comment).
        operands[secondInputIndex] = {
                .type = operandType,
                .dimensions = {1},
                .numberOfConsumers = 1,
                .scale = scale2,
                .zeroPoint = 0,
                .lifetime = TestOperandLifeTime::CONSTANT_COPY,
                .data = TestBuffer::createFromVector<CppType>({bufferValue}),
        };

        // The operation. All operations share the same activation scalar.
        // The output operand is created as an input in the next iteration of the loop, in the case
        // of all but the last member of the chain; and after the loop as a model output, in the
        // case of the last member of the chain.
        operations[i] = {
                .type = op,
                .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
                .outputs = {outputIndex},
        };
    }

    // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
    // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
    CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);

    // The model output.
    operands.back() = {
            .type = operandType,
            .dimensions = {1},
            .numberOfConsumers = 0,
            .scale = scale1,
            .zeroPoint = 0,
            .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
            .data = TestBuffer::createFromVector<CppType>({outputResult}),
    };

    return {
            .main = {.operands = std::move(operands),
                     .operations = std::move(operations),
                     .inputIndexes = {1},
                     .outputIndexes = {len * 2 + 1}},
            .isRelaxed = false,
    };
}
219 
220 }  // namespace
221 
222 // Tag for the compilation caching tests.
223 class CompilationCachingTestBase : public testing::Test {
224   protected:
CompilationCachingTestBase(sp<IDevice> device,OperandType type)225     CompilationCachingTestBase(sp<IDevice> device, OperandType type)
226         : kDevice(std::move(device)), kOperandType(type) {}
227 
SetUp()228     void SetUp() override {
229         testing::Test::SetUp();
230         ASSERT_NE(kDevice.get(), nullptr);
231         const bool deviceIsResponsive = kDevice->ping().isOk();
232         ASSERT_TRUE(deviceIsResponsive);
233 
234         // Create cache directory. The cache directory and a temporary cache file is always created
235         // to test the behavior of prepareModelFromCache_1_3, even when caching is not supported.
236         char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
237         char* cacheDir = mkdtemp(cacheDirTemp);
238         ASSERT_NE(cacheDir, nullptr);
239         mCacheDir = cacheDir;
240         mCacheDir.push_back('/');
241 
242         Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
243                 [this](V1_0::ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
244                     EXPECT_EQ(V1_0::ErrorStatus::NONE, status);
245                     mNumModelCache = numModelCache;
246                     mNumDataCache = numDataCache;
247                 });
248         EXPECT_TRUE(ret.isOk());
249         mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
250 
251         // Create empty cache files.
252         mTmpCache = mCacheDir + "tmp";
253         for (uint32_t i = 0; i < mNumModelCache; i++) {
254             mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
255         }
256         for (uint32_t i = 0; i < mNumDataCache; i++) {
257             mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
258         }
259         // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
260         hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
261         createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
262         createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
263         createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
264 
265         if (!mIsCachingSupported) {
266             LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
267                          "support compilation caching.";
268             std::cout << "[          ]   Early termination of test because vendor service does not "
269                          "support compilation caching."
270                       << std::endl;
271         }
272     }
273 
TearDown()274     void TearDown() override {
275         // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
276         if (!testing::Test::HasFailure()) {
277             // Recursively remove the cache directory specified by mCacheDir.
278             auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
279                 return remove(entry);
280             };
281             nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
282         }
283         testing::Test::TearDown();
284     }
285 
286     // Model and examples creators. According to kOperandType, the following methods will return
287     // either float32 model/examples or the quant8 variant.
createTestModel()288     TestModel createTestModel() {
289         if (kOperandType == OperandType::TENSOR_FLOAT32) {
290             return float32_model::get_test_model();
291         } else {
292             return quant8_model::get_test_model();
293         }
294     }
295 
createLargeTestModel(OperationType op,uint32_t len)296     TestModel createLargeTestModel(OperationType op, uint32_t len) {
297         if (kOperandType == OperandType::TENSOR_FLOAT32) {
298             return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
299                     static_cast<TestOperationType>(op), len);
300         } else {
301             return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
302                     static_cast<TestOperationType>(op), len);
303         }
304     }
305 
306     // See if the service can handle the model.
isModelFullySupported(const Model & model)307     bool isModelFullySupported(const Model& model) {
308         bool fullySupportsModel = false;
309         Return<void> supportedCall = kDevice->getSupportedOperations_1_3(
310                 model,
311                 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
312                     ASSERT_EQ(ErrorStatus::NONE, status);
313                     ASSERT_EQ(supported.size(), model.main.operations.size());
314                     fullySupportsModel = std::all_of(supported.begin(), supported.end(),
315                                                      [](bool valid) { return valid; });
316                 });
317         EXPECT_TRUE(supportedCall.isOk());
318         return fullySupportsModel;
319     }
320 
saveModelToCache(const Model & model,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel=nullptr)321     void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
322                           const hidl_vec<hidl_handle>& dataCache,
323                           sp<IPreparedModel>* preparedModel = nullptr) {
324         if (preparedModel != nullptr) *preparedModel = nullptr;
325 
326         // Launch prepare model.
327         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
328         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
329         Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModel_1_3(
330                 model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, modelCache,
331                 dataCache, cacheToken, preparedModelCallback);
332         ASSERT_TRUE(prepareLaunchStatus.isOk());
333         ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
334 
335         // Retrieve prepared model.
336         preparedModelCallback->wait();
337         ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
338         if (preparedModel != nullptr) {
339             *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
340                                      .withDefault(nullptr);
341         }
342     }
343 
checkEarlyTermination(ErrorStatus status)344     bool checkEarlyTermination(ErrorStatus status) {
345         if (status == ErrorStatus::GENERAL_FAILURE) {
346             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
347                          "save the prepared model that it does not support.";
348             std::cout << "[          ]   Early termination of test because vendor service cannot "
349                          "save the prepared model that it does not support."
350                       << std::endl;
351             return true;
352         }
353         return false;
354     }
355 
checkEarlyTermination(const Model & model)356     bool checkEarlyTermination(const Model& model) {
357         if (!isModelFullySupported(model)) {
358             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
359                          "prepare model that it does not support.";
360             std::cout << "[          ]   Early termination of test because vendor service cannot "
361                          "prepare model that it does not support."
362                       << std::endl;
363             return true;
364         }
365         return false;
366     }
367 
prepareModelFromCache(const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel,ErrorStatus * status)368     void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
369                                const hidl_vec<hidl_handle>& dataCache,
370                                sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
371         // Launch prepare model from cache.
372         sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
373         hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
374         Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache_1_3(
375                 {}, modelCache, dataCache, cacheToken, preparedModelCallback);
376         ASSERT_TRUE(prepareLaunchStatus.isOk());
377         if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
378             *preparedModel = nullptr;
379             *status = static_cast<ErrorStatus>(prepareLaunchStatus);
380             return;
381         }
382 
383         // Retrieve prepared model.
384         preparedModelCallback->wait();
385         *status = preparedModelCallback->getStatus();
386         *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
387                                  .withDefault(nullptr);
388     }
389 
390     // Absolute path to the temporary cache directory.
391     std::string mCacheDir;
392 
393     // Groups of file paths for model and data cache in the tmp cache directory, initialized with
394     // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
395     // and the inner vector is for fds held by each handle.
396     std::vector<std::vector<std::string>> mModelCache;
397     std::vector<std::vector<std::string>> mDataCache;
398 
399     // A separate temporary file path in the tmp cache directory.
400     std::string mTmpCache;
401 
402     uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
403     uint32_t mNumModelCache;
404     uint32_t mNumDataCache;
405     uint32_t mIsCachingSupported;
406 
407     const sp<IDevice> kDevice;
408     // The primary data type of the testModel.
409     const OperandType kOperandType;
410 };
411 
412 using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
413 
414 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
415 // pass running with float32 models and the second pass running with quant8 models.
416 class CompilationCachingTest : public CompilationCachingTestBase,
417                                public testing::WithParamInterface<CompilationCachingTestParam> {
418   protected:
CompilationCachingTest()419     CompilationCachingTest()
420         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
421                                      std::get<OperandType>(GetParam())) {}
422 };
423 
// Verifies that a compilation can be saved to cache, retrieved from cache, and that the
// retrieved prepared model produces correct execution results.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    sp<IPreparedModel> preparedModel = nullptr;
    {
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (!mIsCachingSupported) {
            // Drivers without caching support must report GENERAL_FAILURE here.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        ASSERT_EQ(status, ErrorStatus::NONE);
        ASSERT_NE(preparedModel, nullptr);
    }

    // Execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
463 
// Same as CacheSavingAndRetrieval, but the cache fds are handed to the driver with
// non-empty contents and a non-zero read offset, which the driver must tolerate.
TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    sp<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t dummyBytes[] = {0, 0};
        // Write a dummy integer to the cache.
        // The driver should be able to handle non-empty cache and non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_EQ(write(modelCache[i].getNativeHandle()->data[0], &dummyBytes,
                            sizeof(dummyBytes)),
                      sizeof(dummyBytes));
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_EQ(
                    write(dataCache[i].getNativeHandle()->data[0], &dummyBytes, sizeof(dummyBytes)),
                    sizeof(dummyBytes));
        }
        saveModelToCache(model, modelCache, dataCache);
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        uint8_t dummyByte = 0;
        // Advance the offset of each handle by one byte.
        // The driver should be able to handle non-zero fd offset.
        for (uint32_t i = 0; i < modelCache.size(); i++) {
            ASSERT_GE(read(modelCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
        }
        for (uint32_t i = 0; i < dataCache.size(); i++) {
            ASSERT_GE(read(dataCache[i].getNativeHandle()->data[0], &dummyByte, 1), 0);
        }
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (!mIsCachingSupported) {
            // Drivers without caching support must report GENERAL_FAILURE.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
525 
// Verifies that saving a compilation with the wrong number of model/data cache files still
// yields a usable prepared model (the driver must ignore the invalid cache handles), and
// that the resulting on-disk cache cannot subsequently be loaded.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Test with number of model cache files greater than mNumModelCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for model cache.
        // mModelCache is temporarily grown before creating the handles and restored right
        // after, so later sections see the original configuration.
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        // Either INVALID_ARGUMENT or GENERAL_FAILURE is acceptable; anything else fails.
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an additional cache file for data cache.
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pop out the last cache file.
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
626 
// Validates prepareModelFromCache when the number of model/data cache files
// differs from the counts the driver reported (mNumModelCache/mNumDataCache).
// The driver must fail -- GENERAL_FAILURE or INVALID_ARGUMENT are both
// accepted -- and must never return a prepared model.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Test with number of model cache files greater than mNumModelCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Temporarily append an extra cache file, build the handles, then
        // restore mModelCache so later sections see the original list.
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        // Either GENERAL_FAILURE or INVALID_ARGUMENT is acceptable here.
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of model cache files smaller than mNumModelCache.
    if (mModelCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Temporarily drop the last model cache file; restore it afterwards.
        auto tmp = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files greater than mNumDataCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Test with number of data cache files smaller than mNumDataCache.
    if (mDataCache.size() > 0) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache.push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
707 
// Validates prepareModel_1_3 when one cache handle holds a wrong number of
// fds (each handle must contain exactly one). Compilation itself must still
// succeed and produce a working model -- the driver simply must not cache --
// and the subsequent prepareModelFromCache must then fail with either
// INVALID_ARGUMENT or GENERAL_FAILURE and no prepared model.
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        // The extra fd is added only for handle creation and removed right
        // after, so the fixture state is intact for the next iteration.
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Pass an invalid number of fds for handle i.
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
808 
// Validates prepareModelFromCache when one cache handle holds a wrong number
// of fds (each handle must contain exactly one). A valid cache is written
// first; then, for every handle position, the call is repeated with an extra
// fd and with a missing fd. In every case the driver must fail with either
// GENERAL_FAILURE or INVALID_ARGUMENT and return no prepared model.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // The extra fd is added only for handle creation and removed right
        // after, so the fixture state is intact for the next iteration.
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
889 
// Validates prepareModel_1_3 when one cache fd is opened READ_ONLY instead of
// read-write. Compilation must still succeed and yield a working prepared
// model (the driver must skip caching rather than fail), and the subsequent
// prepareModelFromCache must fail with INVALID_ARGUMENT or GENERAL_FAILURE.
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Downgrade only entry i, create the handles, then restore the mode
        // so the next iteration starts from an all-READ_WRITE baseline.
        modelCacheMode[i] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        dataCacheMode[i] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
942 
// Validates prepareModelFromCache when one cache fd is opened WRITE_ONLY.
// Unlike the other invalid-argument tests above, only GENERAL_FAILURE is
// accepted here (asserted unconditionally), and no model must be returned.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Downgrade only entry i for handle creation, then restore the mode.
        modelCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        dataCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
987 
988 // Copy file contents between file groups.
989 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
990 // The outer vector sizes must match and the inner vectors must have size = 1.
copyCacheFiles(const std::vector<std::vector<std::string>> & from,const std::vector<std::vector<std::string>> & to)991 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
992                            const std::vector<std::vector<std::string>>& to) {
993     constexpr size_t kBufferSize = 1000000;
994     uint8_t buffer[kBufferSize];
995 
996     ASSERT_EQ(from.size(), to.size());
997     for (uint32_t i = 0; i < from.size(); i++) {
998         ASSERT_EQ(from[i].size(), 1u);
999         ASSERT_EQ(to[i].size(), 1u);
1000         int fromFd = open(from[i][0].c_str(), O_RDONLY);
1001         int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1002         ASSERT_GE(fromFd, 0);
1003         ASSERT_GE(toFd, 0);
1004 
1005         ssize_t readBytes;
1006         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1007             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1008         }
1009         ASSERT_GE(readBytes, 0);
1010 
1011         close(fromFd);
1012         close(toFd);
1013     }
1014 }
1015 
// Number of operations in the large test model.
constexpr uint32_t kLargeModelSize = 100;
// The TOCTOU tests below are probabilistic; repeat enough times to give a
// racy driver a realistic chance of being caught.
constexpr uint32_t kNumIterationsTOCTOU = 100;
1019 
// Time-of-check/time-of-use test: while the driver saves modelAdd's
// compilation to cache, a second thread concurrently overwrites the cache
// files with modelMul's previously saved cache. Preparing from the resulting
// cache may fail or succeed, but the driver must not crash, and on success
// the model must still compute modelAdd's results.
TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    // Use distinct filenames ("_mul" suffix) so modelMul's cache survives as
    // the source of the malicious concurrent copy below.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while saving to cache.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            saveModelToCache(modelAdd, modelCache, dataCache);
            thread.join();
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
1081 
// Time-of-check/time-of-use test: the cache content is overwritten with
// modelMul's cache concurrently while the driver runs prepareModelFromCache
// for modelAdd. The preparation may fail or succeed, but the driver must not
// crash, and on success the model must still compute modelAdd's results.
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    // Distinct "_mul"-suffixed filenames keep modelMul's cache available as
    // the source of the malicious concurrent copy below.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(modelAdd, modelCache, dataCache);
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while preparing from cache.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            thread.join();

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
1143 
// Deterministic variant of the TOCTOU tests: after modelAdd is cached, its
// model cache files are wholesale replaced (no race) with modelMul's cache,
// which was saved under a different token. The driver must detect the
// mismatch and fail prepareModelFromCache with GENERAL_FAILURE.
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // Save the modelAdd compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelAdd, modelCache, dataCache);
    }

    // Replace the model cache of modelAdd with modelMul.
    copyCacheFiles(modelCacheMul, mModelCache);

    // Retrieve the preparedModel from cache, expect failure.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1193 
// Axes of the parameter grid: each device reported by getNamedDevices(),
// crossed with the two operand types the cached test models are built for.
static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
static const auto kOperandTypeChoices =
        testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1197 
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1198 std::string printCompilationCachingTest(
1199         const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1200     const auto& [namedDevice, operandType] = info.param;
1201     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1202     return gtestCompliantName(getName(namedDevice) + "_" + type);
1203 }
1204 
// Allow the suite to remain uninstantiated (e.g. no matching device) without
// triggering gtest's missing-instantiation error, then instantiate it over
// the full device x operand-type grid with human-readable test names.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingTest);
INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
                         testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
                         printCompilationCachingTest);
1209 
1210 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1211 
1212 class CompilationCachingSecurityTest
1213     : public CompilationCachingTestBase,
1214       public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1215   protected:
CompilationCachingSecurityTest()1216     CompilationCachingSecurityTest()
1217         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1218                                      std::get<OperandType>(GetParam())) {}
1219 
SetUp()1220     void SetUp() {
1221         CompilationCachingTestBase::SetUp();
1222         generator.seed(kSeed);
1223     }
1224 
1225     // Get a random integer within a closed range [lower, upper].
1226     template <typename T>
getRandomInt(T lower,T upper)1227     T getRandomInt(T lower, T upper) {
1228         std::uniform_int_distribution<T> dis(lower, upper);
1229         return dis(generator);
1230     }
1231 
1232     // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1233     void flipOneBitOfCache(const std::string& filename, bool* skip) {
1234         FILE* pFile = fopen(filename.c_str(), "r+");
1235         ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1236         long int fileSize = ftell(pFile);
1237         if (fileSize == 0) {
1238             fclose(pFile);
1239             *skip = true;
1240             return;
1241         }
1242         ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1243         int readByte = fgetc(pFile);
1244         ASSERT_NE(readByte, EOF);
1245         ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1246         ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1247         fclose(pFile);
1248         *skip = false;
1249     }
1250 
1251     // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1252     void appendBytesToCache(const std::string& filename, bool* skip) {
1253         FILE* pFile = fopen(filename.c_str(), "a");
1254         uint32_t appendLength = getRandomInt(1, 256);
1255         for (uint32_t i = 0; i < appendLength; i++) {
1256             ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1257         }
1258         fclose(pFile);
1259         *skip = false;
1260     }
1261 
1262     enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1263 
1264     // Test if the driver behaves as expected when given corrupted cache or token.
1265     // The modifier will be invoked after save to cache but before prepare from cache.
1266     // The modifier accepts one pointer argument "skip" as the returning value, indicating
1267     // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1268     void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1269         const TestModel& testModel = createTestModel();
1270         const Model model = createModel(testModel);
1271         if (checkEarlyTermination(model)) return;
1272 
1273         // Save the compilation to cache.
1274         {
1275             hidl_vec<hidl_handle> modelCache, dataCache;
1276             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1277             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1278             saveModelToCache(model, modelCache, dataCache);
1279         }
1280 
1281         bool skip = false;
1282         modifier(&skip);
1283         if (skip) return;
1284 
1285         // Retrieve preparedModel from cache.
1286         {
1287             sp<IPreparedModel> preparedModel = nullptr;
1288             ErrorStatus status;
1289             hidl_vec<hidl_handle> modelCache, dataCache;
1290             createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1291             createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1292             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1293 
1294             switch (expected) {
1295                 case ExpectedResult::GENERAL_FAILURE:
1296                     ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1297                     ASSERT_EQ(preparedModel, nullptr);
1298                     break;
1299                 case ExpectedResult::NOT_CRASH:
1300                     ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1301                     break;
1302                 default:
1303                     FAIL();
1304             }
1305         }
1306     }
1307 
1308     const uint32_t kSeed = std::get<uint32_t>(GetParam());
1309     std::mt19937 generator;
1310 };
1311 
// Flipping a single bit in any model cache file must make preparation from
// cache fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            flipOneBitOfCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corrupt);
    }
}
1319 
// Appending extra bytes to any model cache file must make preparation from
// cache fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            appendBytesToCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corrupt);
    }
}
1327 
// Flipping a single bit in any data cache file may or may not be detected by
// the driver, but it must not crash.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            flipOneBitOfCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corrupt);
    }
}
1335 
// Appending extra bytes to any data cache file may or may not be detected by
// the driver, but it must not crash.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; index++) {
        const auto corrupt = [this, index](bool* skip) {
            appendBytesToCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corrupt);
    }
}
1343 
// Preparing from cache with a token that differs by one bit from the one used
// when saving must fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    const auto corruptToken = [this](bool* skip) {
        // Flip one random bit of a random byte of mToken.
        const uint32_t byteIndex =
                getRandomInt(0u, static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
        mToken[byteIndex] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptToken);
}
1354 
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1355 std::string printCompilationCachingSecurityTest(
1356         const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1357     const auto& [namedDevice, operandType, seed] = info.param;
1358     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1359     return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1360 }
1361 
// Instantiate the security tests over every available device, both operand
// types (float32 and quant8), and ten different RNG seeds; allow the suite to
// be uninstantiated when no device is present.
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingSecurityTest);
INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                         testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
                                          testing::Range(0U, 10U)),
                         printCompilationCachingSecurityTest);
1367 
1368 }  // namespace android::hardware::neuralnetworks::V1_3::vts::functional
1369