1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_aidl_hal_test"
18 
19 #include <android-base/logging.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_interface_utils.h>
22 #include <android/binder_status.h>
23 #include <fcntl.h>
24 #include <ftw.h>
25 #include <gtest/gtest.h>
26 #include <hidlmemory/mapping.h>
27 #include <unistd.h>
28 
#include <algorithm>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iterator>
#include <limits>
#include <random>
#include <thread>
#include <type_traits>
34 
35 #include "Callbacks.h"
36 #include "GeneratedTestHarness.h"
37 #include "MemoryUtils.h"
38 #include "TestHarness.h"
39 #include "Utils.h"
40 #include "VtsHalNeuralnetworks.h"
41 
42 // Forward declaration of the mobilenet generated test models in
43 // frameworks/ml/nn/runtime/test/generated/.
44 namespace generated_tests::mobilenet_224_gender_basic_fixed {
45 const test_helper::TestModel& get_test_model();
46 }  // namespace generated_tests::mobilenet_224_gender_basic_fixed
47 
48 namespace generated_tests::mobilenet_quantized {
49 const test_helper::TestModel& get_test_model();
50 }  // namespace generated_tests::mobilenet_quantized
51 
52 namespace aidl::android::hardware::neuralnetworks::vts::functional {
53 
54 using namespace test_helper;
55 using implementation::PreparedModelCallback;
56 
57 namespace float32_model {
58 
59 constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;
60 
61 }  // namespace float32_model
62 
63 namespace quant8_model {
64 
65 constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;
66 
67 }  // namespace quant8_model
68 
69 namespace {
70 
71 enum class AccessMode { READ_WRITE, READ_ONLY, WRITE_ONLY };
72 
73 // Creates cache handles based on provided file groups.
74 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheFds(const std::vector<std::string> & files,const std::vector<AccessMode> & mode,std::vector<ndk::ScopedFileDescriptor> * fds)75 void createCacheFds(const std::vector<std::string>& files, const std::vector<AccessMode>& mode,
76                     std::vector<ndk::ScopedFileDescriptor>* fds) {
77     fds->clear();
78     fds->reserve(files.size());
79     for (uint32_t i = 0; i < files.size(); i++) {
80         const auto& file = files[i];
81         int fd;
82         if (mode[i] == AccessMode::READ_ONLY) {
83             fd = open(file.c_str(), O_RDONLY);
84         } else if (mode[i] == AccessMode::WRITE_ONLY) {
85             fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
86         } else if (mode[i] == AccessMode::READ_WRITE) {
87             fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
88         } else {
89             FAIL();
90         }
91         ASSERT_GE(fd, 0);
92         fds->emplace_back(fd);
93     }
94 }
95 
createCacheFds(const std::vector<std::string> & files,AccessMode mode,std::vector<ndk::ScopedFileDescriptor> * fds)96 void createCacheFds(const std::vector<std::string>& files, AccessMode mode,
97                     std::vector<ndk::ScopedFileDescriptor>* fds) {
98     createCacheFds(files, std::vector<AccessMode>(files.size(), mode), fds);
99 }
100 
101 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
102 // For simplicity, activation scalar is shared. The second operand is not shared
103 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
104 // data locations in cache.
105 //
106 //                --------- activation --------
107 //                ↓      ↓      ↓             ↓
108 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
109 //                ↑      ↑      ↑             ↑
110 //               [1]    [1]    [1]           [1]
111 //
112 // This function assumes the operation is either ADD or MUL.
113 template <typename CppType, TestOperandType operandType>
createLargeTestModelImpl(TestOperationType op,uint32_t len)114 TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
115     EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);
116 
117     // Model operations and operands.
118     std::vector<TestOperation> operations(len);
119     std::vector<TestOperand> operands(len * 2 + 2);
120 
121     // The activation scalar, value = 0.
122     operands[0] = {
123             .type = TestOperandType::INT32,
124             .dimensions = {},
125             .numberOfConsumers = len,
126             .scale = 0.0f,
127             .zeroPoint = 0,
128             .lifetime = TestOperandLifeTime::CONSTANT_COPY,
129             .data = TestBuffer::createFromVector<int32_t>({0}),
130     };
131 
132     // The buffer value of the constant second operand. The logical value is always 1.0f.
133     CppType bufferValue;
134     // The scale of the first and second operand.
135     float scale1, scale2;
136     if (operandType == TestOperandType::TENSOR_FLOAT32) {
137         bufferValue = 1.0f;
138         scale1 = 0.0f;
139         scale2 = 0.0f;
140     } else if (op == TestOperationType::ADD) {
141         bufferValue = 1;
142         scale1 = 1.0f;
143         scale2 = 1.0f;
144     } else {
145         // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
146         // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
147         bufferValue = 2;
148         scale1 = 1.0f;
149         scale2 = 0.5f;
150     }
151 
152     for (uint32_t i = 0; i < len; i++) {
153         const uint32_t firstInputIndex = i * 2 + 1;
154         const uint32_t secondInputIndex = firstInputIndex + 1;
155         const uint32_t outputIndex = secondInputIndex + 1;
156 
157         // The first operation input.
158         operands[firstInputIndex] = {
159                 .type = operandType,
160                 .dimensions = {1},
161                 .numberOfConsumers = 1,
162                 .scale = scale1,
163                 .zeroPoint = 0,
164                 .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
165                                     : TestOperandLifeTime::TEMPORARY_VARIABLE),
166                 .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
167         };
168 
169         // The second operation input, value = 1.
170         operands[secondInputIndex] = {
171                 .type = operandType,
172                 .dimensions = {1},
173                 .numberOfConsumers = 1,
174                 .scale = scale2,
175                 .zeroPoint = 0,
176                 .lifetime = TestOperandLifeTime::CONSTANT_COPY,
177                 .data = TestBuffer::createFromVector<CppType>({bufferValue}),
178         };
179 
180         // The operation. All operations share the same activation scalar.
181         // The output operand is created as an input in the next iteration of the loop, in the case
182         // of all but the last member of the chain; and after the loop as a model output, in the
183         // case of the last member of the chain.
184         operations[i] = {
185                 .type = op,
186                 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
187                 .outputs = {outputIndex},
188         };
189     }
190 
191     // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
192     // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
193     CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);
194 
195     // The model output.
196     operands.back() = {
197             .type = operandType,
198             .dimensions = {1},
199             .numberOfConsumers = 0,
200             .scale = scale1,
201             .zeroPoint = 0,
202             .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
203             .data = TestBuffer::createFromVector<CppType>({outputResult}),
204     };
205 
206     return {
207             .main = {.operands = std::move(operands),
208                      .operations = std::move(operations),
209                      .inputIndexes = {1},
210                      .outputIndexes = {len * 2 + 1}},
211             .isRelaxed = false,
212     };
213 }
214 
215 }  // namespace
216 
217 // Tag for the compilation caching tests.
218 class CompilationCachingTestBase : public testing::Test {
219   protected:
CompilationCachingTestBase(std::shared_ptr<IDevice> device,OperandType type)220     CompilationCachingTestBase(std::shared_ptr<IDevice> device, OperandType type)
221         : kDevice(std::move(device)), kOperandType(type) {}
222 
SetUp()223     void SetUp() override {
224         testing::Test::SetUp();
225         ASSERT_NE(kDevice.get(), nullptr);
226         const bool deviceIsResponsive =
227                 ndk::ScopedAStatus::fromStatus(AIBinder_ping(kDevice->asBinder().get())).isOk();
228         ASSERT_TRUE(deviceIsResponsive);
229 
230         // Create cache directory. The cache directory and a temporary cache file is always created
231         // to test the behavior of prepareModelFromCache, even when caching is not supported.
232         char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
233         char* cacheDir = mkdtemp(cacheDirTemp);
234         ASSERT_NE(cacheDir, nullptr);
235         mCacheDir = cacheDir;
236         mCacheDir.push_back('/');
237 
238         NumberOfCacheFiles numCacheFiles;
239         const auto ret = kDevice->getNumberOfCacheFilesNeeded(&numCacheFiles);
240         ASSERT_TRUE(ret.isOk());
241 
242         mNumModelCache = numCacheFiles.numModelCache;
243         mNumDataCache = numCacheFiles.numDataCache;
244         ASSERT_GE(mNumModelCache, 0) << "Invalid numModelCache: " << mNumModelCache;
245         ASSERT_GE(mNumDataCache, 0) << "Invalid numDataCache: " << mNumDataCache;
246         mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
247 
248         // Create empty cache files.
249         mTmpCache = mCacheDir + "tmp";
250         for (uint32_t i = 0; i < mNumModelCache; i++) {
251             mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
252         }
253         for (uint32_t i = 0; i < mNumDataCache; i++) {
254             mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
255         }
256         // Placeholder handles, use AccessMode::WRITE_ONLY for createCacheFds to create files.
257         std::vector<ndk::ScopedFileDescriptor> modelHandle, dataHandle, tmpHandle;
258         createCacheFds(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
259         createCacheFds(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
260         createCacheFds({mTmpCache}, AccessMode::WRITE_ONLY, &tmpHandle);
261 
262         if (!mIsCachingSupported) {
263             LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
264                          "support compilation caching.";
265             std::cout << "[          ]   Early termination of test because vendor service does not "
266                          "support compilation caching."
267                       << std::endl;
268         }
269     }
270 
TearDown()271     void TearDown() override {
272         // If the test passes, remove the tmp directory.  Otherwise, keep it for debugging purposes.
273         if (!testing::Test::HasFailure()) {
274             // Recursively remove the cache directory specified by mCacheDir.
275             auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
276                 return remove(entry);
277             };
278             nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
279         }
280         testing::Test::TearDown();
281     }
282 
283     // Model and examples creators. According to kOperandType, the following methods will return
284     // either float32 model/examples or the quant8 variant.
createTestModel()285     TestModel createTestModel() {
286         if (kOperandType == OperandType::TENSOR_FLOAT32) {
287             return float32_model::get_test_model();
288         } else {
289             return quant8_model::get_test_model();
290         }
291     }
292 
createLargeTestModel(OperationType op,uint32_t len)293     TestModel createLargeTestModel(OperationType op, uint32_t len) {
294         if (kOperandType == OperandType::TENSOR_FLOAT32) {
295             return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
296                     static_cast<TestOperationType>(op), len);
297         } else {
298             return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
299                     static_cast<TestOperationType>(op), len);
300         }
301     }
302 
303     // See if the service can handle the model.
isModelFullySupported(const Model & model)304     bool isModelFullySupported(const Model& model) {
305         std::vector<bool> supportedOps;
306         const auto supportedCall = kDevice->getSupportedOperations(model, &supportedOps);
307         EXPECT_TRUE(supportedCall.isOk());
308         EXPECT_EQ(supportedOps.size(), model.main.operations.size());
309         if (!supportedCall.isOk() || supportedOps.size() != model.main.operations.size()) {
310             return false;
311         }
312         return std::all_of(supportedOps.begin(), supportedOps.end(),
313                            [](bool valid) { return valid; });
314     }
315 
saveModelToCache(const Model & model,const std::vector<ndk::ScopedFileDescriptor> & modelCache,const std::vector<ndk::ScopedFileDescriptor> & dataCache,std::shared_ptr<IPreparedModel> * preparedModel=nullptr)316     void saveModelToCache(const Model& model,
317                           const std::vector<ndk::ScopedFileDescriptor>& modelCache,
318                           const std::vector<ndk::ScopedFileDescriptor>& dataCache,
319                           std::shared_ptr<IPreparedModel>* preparedModel = nullptr) {
320         if (preparedModel != nullptr) *preparedModel = nullptr;
321 
322         // Launch prepare model.
323         std::shared_ptr<PreparedModelCallback> preparedModelCallback =
324                 ndk::SharedRefBase::make<PreparedModelCallback>();
325         std::vector<uint8_t> cacheToken(std::begin(mToken), std::end(mToken));
326         const auto prepareLaunchStatus = kDevice->prepareModel(
327                 model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, kNoDeadline,
328                 modelCache, dataCache, cacheToken, preparedModelCallback);
329         ASSERT_TRUE(prepareLaunchStatus.isOk());
330 
331         // Retrieve prepared model.
332         preparedModelCallback->wait();
333         ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
334         if (preparedModel != nullptr) {
335             *preparedModel = preparedModelCallback->getPreparedModel();
336         }
337     }
338 
checkEarlyTermination(ErrorStatus status)339     bool checkEarlyTermination(ErrorStatus status) {
340         if (status == ErrorStatus::GENERAL_FAILURE) {
341             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
342                          "save the prepared model that it does not support.";
343             std::cout << "[          ]   Early termination of test because vendor service cannot "
344                          "save the prepared model that it does not support."
345                       << std::endl;
346             return true;
347         }
348         return false;
349     }
350 
checkEarlyTermination(const Model & model)351     bool checkEarlyTermination(const Model& model) {
352         if (!isModelFullySupported(model)) {
353             LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
354                          "prepare model that it does not support.";
355             std::cout << "[          ]   Early termination of test because vendor service cannot "
356                          "prepare model that it does not support."
357                       << std::endl;
358             return true;
359         }
360         return false;
361     }
362 
363     // If fallbackModel is not provided, call prepareModelFromCache.
364     // If fallbackModel is provided, and prepareModelFromCache returns GENERAL_FAILURE,
365     // then prepareModel(fallbackModel) will be called.
366     // This replicates the behaviour of the runtime when loading a model from cache.
367     // NNAPI Shim depends on this behaviour and may try to load the model from cache in
368     // prepareModel (shim needs model information when loading from cache).
prepareModelFromCache(const std::vector<ndk::ScopedFileDescriptor> & modelCache,const std::vector<ndk::ScopedFileDescriptor> & dataCache,std::shared_ptr<IPreparedModel> * preparedModel,ErrorStatus * status,const Model * fallbackModel=nullptr)369     void prepareModelFromCache(const std::vector<ndk::ScopedFileDescriptor>& modelCache,
370                                const std::vector<ndk::ScopedFileDescriptor>& dataCache,
371                                std::shared_ptr<IPreparedModel>* preparedModel, ErrorStatus* status,
372                                const Model* fallbackModel = nullptr) {
373         // Launch prepare model from cache.
374         std::shared_ptr<PreparedModelCallback> preparedModelCallback =
375                 ndk::SharedRefBase::make<PreparedModelCallback>();
376         std::vector<uint8_t> cacheToken(std::begin(mToken), std::end(mToken));
377         auto prepareLaunchStatus = kDevice->prepareModelFromCache(
378                 kNoDeadline, modelCache, dataCache, cacheToken, preparedModelCallback);
379 
380         // The shim does not support prepareModelFromCache() properly, but it
381         // will still attempt to create a model from cache when modelCache or
382         // dataCache is provided in prepareModel(). Instead of failing straight
383         // away, we try to utilize that other code path when fallbackModel is
384         // set. Note that we cannot verify whether the returned model was
385         // actually prepared from cache in that case.
386         if (!prepareLaunchStatus.isOk() &&
387             prepareLaunchStatus.getExceptionCode() == EX_SERVICE_SPECIFIC &&
388             static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()) ==
389                     ErrorStatus::GENERAL_FAILURE &&
390             mIsCachingSupported && fallbackModel != nullptr) {
391             preparedModelCallback = ndk::SharedRefBase::make<PreparedModelCallback>();
392             prepareLaunchStatus = kDevice->prepareModel(
393                     *fallbackModel, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority,
394                     kNoDeadline, modelCache, dataCache, cacheToken, preparedModelCallback);
395         }
396 
397         ASSERT_TRUE(prepareLaunchStatus.isOk() ||
398                     prepareLaunchStatus.getExceptionCode() == EX_SERVICE_SPECIFIC)
399                 << "prepareLaunchStatus: " << prepareLaunchStatus.getDescription();
400         if (!prepareLaunchStatus.isOk()) {
401             *preparedModel = nullptr;
402             *status = static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError());
403             return;
404         }
405 
406         // Retrieve prepared model.
407         preparedModelCallback->wait();
408         *status = preparedModelCallback->getStatus();
409         *preparedModel = preparedModelCallback->getPreparedModel();
410     }
411 
412     // Replicate behaviour of runtime when loading model from cache.
413     // Test if prepareModelFromCache behaves correctly when faced with bad
414     // arguments. If prepareModelFromCache is not supported (GENERAL_FAILURE),
415     // it attempts to call prepareModel with same arguments, which is expected either
416     // to not support the model (GENERAL_FAILURE) or return a valid model.
verifyModelPreparationBehaviour(const std::vector<ndk::ScopedFileDescriptor> & modelCache,const std::vector<ndk::ScopedFileDescriptor> & dataCache,const Model * model,const TestModel & testModel)417     void verifyModelPreparationBehaviour(const std::vector<ndk::ScopedFileDescriptor>& modelCache,
418                                          const std::vector<ndk::ScopedFileDescriptor>& dataCache,
419                                          const Model* model, const TestModel& testModel) {
420         std::shared_ptr<IPreparedModel> preparedModel;
421         ErrorStatus status;
422 
423         // Verify that prepareModelFromCache fails either due to bad
424         // arguments (INVALID_ARGUMENT) or GENERAL_FAILURE if not supported.
425         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
426                               /*fallbackModel=*/nullptr);
427         if (status != ErrorStatus::INVALID_ARGUMENT) {
428             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
429         }
430         ASSERT_EQ(preparedModel, nullptr);
431 
432         // If caching is not supported, attempt calling prepareModel.
433         if (status == ErrorStatus::GENERAL_FAILURE) {
434             // Fallback with prepareModel should succeed regardless of cache files
435             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
436                                   /*fallbackModel=*/model);
437             // Unless caching is not supported?
438             if (status != ErrorStatus::GENERAL_FAILURE) {
439                 // But if it is, we should see a valid model.
440                 ASSERT_EQ(status, ErrorStatus::NONE);
441                 ASSERT_NE(preparedModel, nullptr);
442                 EvaluatePreparedModel(kDevice, preparedModel, testModel,
443                                       /*testKind=*/TestKind::GENERAL);
444             }
445         }
446     }
447 
448     // Absolute path to the temporary cache directory.
449     std::string mCacheDir;
450 
451     // Groups of file paths for model and data cache in the tmp cache directory, initialized with
452     // size = mNum{Model|Data}Cache. The outer vector corresponds to handles and the inner vector is
453     // for fds held by each handle.
454     std::vector<std::string> mModelCache;
455     std::vector<std::string> mDataCache;
456 
457     // A separate temporary file path in the tmp cache directory.
458     std::string mTmpCache;
459 
460     uint8_t mToken[static_cast<uint32_t>(IDevice::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
461     uint32_t mNumModelCache;
462     uint32_t mNumDataCache;
463     bool mIsCachingSupported;
464 
465     const std::shared_ptr<IDevice> kDevice;
466     // The primary data type of the testModel.
467     const OperandType kOperandType;
468 };
469 
470 using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
471 
472 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
473 // pass running with float32 models and the second pass running with quant8 models.
474 class CompilationCachingTest : public CompilationCachingTestBase,
475                                public testing::WithParamInterface<CompilationCachingTestParam> {
476   protected:
CompilationCachingTest()477     CompilationCachingTest()
478         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
479                                      std::get<OperandType>(GetParam())) {}
480 };
481 
// Saves a compilation to cache, loads it back with prepareModelFromCache (falling back to
// prepareModel as the runtime would), and checks that the loaded model executes correctly.
// Helper calls are wrapped in ASSERT_NO_FATAL_FAILURE because a fatal failure inside a helper
// only returns from that helper, not from the test.
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Create test model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::shared_ptr<IPreparedModel> preparedModel = nullptr;

    // Save the compilation to cache.
    {
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        ASSERT_NO_FATAL_FAILURE(createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache));
        ASSERT_NO_FATAL_FAILURE(createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache));
        ASSERT_NO_FATAL_FAILURE(saveModelToCache(model, modelCache, dataCache));
    }

    // Retrieve preparedModel from cache.
    {
        preparedModel = nullptr;
        ErrorStatus status = ErrorStatus::GENERAL_FAILURE;
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        ASSERT_NO_FATAL_FAILURE(createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache));
        ASSERT_NO_FATAL_FAILURE(createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache));
        ASSERT_NO_FATAL_FAILURE(prepareModelFromCache(modelCache, dataCache, &preparedModel,
                                                      &status, /*fallbackModel=*/&model));
        if (!mIsCachingSupported) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        } else {
            ASSERT_EQ(status, ErrorStatus::NONE);
            ASSERT_NE(preparedModel, nullptr);
        }
    }

    // Execute and verify results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
522 
TEST_P(CompilationCachingTest,CacheSavingAndRetrievalNonZeroOffset)523 TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
524     // Create test HIDL model and compile.
525     const TestModel& testModel = createTestModel();
526     const Model model = createModel(testModel);
527     if (checkEarlyTermination(model)) return;
528     std::shared_ptr<IPreparedModel> preparedModel = nullptr;
529 
530     // Save the compilation to cache.
531     {
532         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
533         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
534         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
535         uint8_t placeholderBytes[] = {0, 0};
536         // Write a placeholder integer to the cache.
537         // The driver should be able to handle non-empty cache and non-zero fd offset.
538         for (uint32_t i = 0; i < modelCache.size(); i++) {
539             ASSERT_EQ(write(modelCache[i].get(), &placeholderBytes, sizeof(placeholderBytes)),
540                       sizeof(placeholderBytes));
541         }
542         for (uint32_t i = 0; i < dataCache.size(); i++) {
543             ASSERT_EQ(write(dataCache[i].get(), &placeholderBytes, sizeof(placeholderBytes)),
544                       sizeof(placeholderBytes));
545         }
546         saveModelToCache(model, modelCache, dataCache);
547     }
548 
549     // Retrieve preparedModel from cache.
550     {
551         preparedModel = nullptr;
552         ErrorStatus status;
553         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
554         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
555         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
556         uint8_t placeholderByte = 0;
557         // Advance the offset of each handle by one byte.
558         // The driver should be able to handle non-zero fd offset.
559         for (uint32_t i = 0; i < modelCache.size(); i++) {
560             ASSERT_GE(read(modelCache[i].get(), &placeholderByte, 1), 0);
561         }
562         for (uint32_t i = 0; i < dataCache.size(); i++) {
563             ASSERT_GE(read(dataCache[i].get(), &placeholderByte, 1), 0);
564         }
565         prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
566                               /*fallbackModel=*/&model);
567         if (!mIsCachingSupported) {
568             ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
569             ASSERT_EQ(preparedModel, nullptr);
570             return;
571         } else if (checkEarlyTermination(status)) {
572             ASSERT_EQ(preparedModel, nullptr);
573             return;
574         } else {
575             ASSERT_EQ(status, ErrorStatus::NONE);
576             ASSERT_NE(preparedModel, nullptr);
577         }
578     }
579 
580     // Execute and verify results.
581     EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
582 }
583 
TEST_P(CompilationCachingTest,SaveToCacheInvalidNumCache)584 TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
585     // Create test HIDL model and compile.
586     const TestModel& testModel = createTestModel();
587     const Model model = createModel(testModel);
588     if (checkEarlyTermination(model)) return;
589 
590     // Test with number of model cache files greater than mNumModelCache.
591     {
592         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
593         // Pass an additional cache file for model cache.
594         mModelCache.push_back({mTmpCache});
595         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
596         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
597         mModelCache.pop_back();
598         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
599         saveModelToCache(model, modelCache, dataCache, &preparedModel);
600         ASSERT_NE(preparedModel, nullptr);
601         // Execute and verify results.
602         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
603         // Check if prepareModelFromCache fails.
604         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
605     }
606 
607     // Test with number of model cache files smaller than mNumModelCache.
608     if (mModelCache.size() > 0) {
609         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
610         // Pop out the last cache file.
611         auto tmp = mModelCache.back();
612         mModelCache.pop_back();
613         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
614         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
615         mModelCache.push_back(tmp);
616         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
617         saveModelToCache(model, modelCache, dataCache, &preparedModel);
618         ASSERT_NE(preparedModel, nullptr);
619         // Execute and verify results.
620         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
621         // Check if prepareModelFromCache fails.
622         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
623     }
624 
625     // Test with number of data cache files greater than mNumDataCache.
626     {
627         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
628         // Pass an additional cache file for data cache.
629         mDataCache.push_back({mTmpCache});
630         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
631         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
632         mDataCache.pop_back();
633         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
634         saveModelToCache(model, modelCache, dataCache, &preparedModel);
635         ASSERT_NE(preparedModel, nullptr);
636         // Execute and verify results.
637         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
638         // Check if prepareModelFromCache fails.
639         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
640     }
641 
642     // Test with number of data cache files smaller than mNumDataCache.
643     if (mDataCache.size() > 0) {
644         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
645         // Pop out the last cache file.
646         auto tmp = mDataCache.back();
647         mDataCache.pop_back();
648         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
649         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
650         mDataCache.push_back(tmp);
651         std::shared_ptr<IPreparedModel> preparedModel = nullptr;
652         saveModelToCache(model, modelCache, dataCache, &preparedModel);
653         ASSERT_NE(preparedModel, nullptr);
654         // Execute and verify results.
655         EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
656         // Check if prepareModelFromCache fails.
657         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
658     }
659 }
660 
TEST_P(CompilationCachingTest,PrepareModelFromCacheInvalidNumCache)661 TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
662     // Create test HIDL model and compile.
663     const TestModel& testModel = createTestModel();
664     const Model model = createModel(testModel);
665     if (checkEarlyTermination(model)) return;
666 
667     // Save the compilation to cache.
668     {
669         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
670         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
671         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
672         saveModelToCache(model, modelCache, dataCache);
673     }
674 
675     // Test with number of model cache files greater than mNumModelCache.
676     {
677         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
678         mModelCache.push_back({mTmpCache});
679         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
680         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
681         mModelCache.pop_back();
682 
683         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
684     }
685 
686     // Test with number of model cache files smaller than mNumModelCache.
687     if (mModelCache.size() > 0) {
688         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
689         auto tmp = mModelCache.back();
690         mModelCache.pop_back();
691         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
692         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
693         mModelCache.push_back(tmp);
694 
695         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
696     }
697 
698     // Test with number of data cache files greater than mNumDataCache.
699     {
700         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
701         mDataCache.push_back({mTmpCache});
702         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
703         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
704         mDataCache.pop_back();
705 
706         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
707     }
708 
709     // Test with number of data cache files smaller than mNumDataCache.
710     if (mDataCache.size() > 0) {
711         std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
712         auto tmp = mDataCache.back();
713         mDataCache.pop_back();
714         createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
715         createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
716         mDataCache.push_back(tmp);
717 
718         verifyModelPreparationBehaviour(modelCache, dataCache, &model, testModel);
719     }
720 }
721 
// Saves to cache while exactly one cache handle is opened read-only. The compilation itself must
// still produce a working prepared model. Preparation from those handles is then checked.
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Build the test model; stop here if the driver cannot run it or does not support caching.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Per-file access modes: all read-write except the single entry currently under test.
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Open one model cache handle at a time as read-only.
    for (uint32_t idx = 0; idx < mNumModelCache; ++idx) {
        std::vector<ndk::ScopedFileDescriptor> modelFds;
        std::vector<ndk::ScopedFileDescriptor> dataFds;
        modelCacheMode[idx] = AccessMode::READ_ONLY;
        createCacheFds(mModelCache, modelCacheMode, &modelFds);
        createCacheFds(mDataCache, dataCacheMode, &dataFds);
        modelCacheMode[idx] = AccessMode::READ_WRITE;

        // The compilation must still succeed and compute the expected results...
        std::shared_ptr<IPreparedModel> compiled;
        saveModelToCache(model, modelFds, dataFds, &compiled);
        ASSERT_NE(compiled, nullptr);
        EvaluatePreparedModel(kDevice, compiled, testModel, /*testKind=*/TestKind::GENERAL);
        // ...then check whether prepareModelFromCache fails with these handles.
        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    }

    // Open one data cache handle at a time as read-only.
    for (uint32_t idx = 0; idx < mNumDataCache; ++idx) {
        std::vector<ndk::ScopedFileDescriptor> modelFds;
        std::vector<ndk::ScopedFileDescriptor> dataFds;
        dataCacheMode[idx] = AccessMode::READ_ONLY;
        createCacheFds(mModelCache, modelCacheMode, &modelFds);
        createCacheFds(mDataCache, dataCacheMode, &dataFds);
        dataCacheMode[idx] = AccessMode::READ_WRITE;

        // The compilation must still succeed and compute the expected results...
        std::shared_ptr<IPreparedModel> compiled;
        saveModelToCache(model, modelFds, dataFds, &compiled);
        ASSERT_NE(compiled, nullptr);
        EvaluatePreparedModel(kDevice, compiled, testModel, /*testKind=*/TestKind::GENERAL);
        // ...then check whether prepareModelFromCache fails with these handles.
        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    }
}
762 
// Saves a compilation normally, then checks preparation from cache when exactly one cache handle
// is opened write-only.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Build the test model; stop here if the driver cannot run it or does not support caching.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Per-file access modes: all read-write except the single entry currently under test.
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Populate the cache with every handle opened read-write.
    {
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelFds);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataFds);
        saveModelToCache(model, modelFds, dataFds);
    }

    // Opens the cache files with the current per-file access modes and checks how preparation
    // from those handles behaves.
    const auto verifyWithCurrentModes = [&]() {
        std::vector<ndk::ScopedFileDescriptor> modelFds, dataFds;
        createCacheFds(mModelCache, modelCacheMode, &modelFds);
        createCacheFds(mDataCache, dataCacheMode, &dataFds);
        verifyModelPreparationBehaviour(modelFds, dataFds, &model, testModel);
    };

    // Make each model cache handle write-only in turn.
    for (uint32_t idx = 0; idx < mNumModelCache; ++idx) {
        modelCacheMode[idx] = AccessMode::WRITE_ONLY;
        verifyWithCurrentModes();
        modelCacheMode[idx] = AccessMode::READ_WRITE;
    }

    // Make each data cache handle write-only in turn.
    for (uint32_t idx = 0; idx < mNumDataCache; ++idx) {
        dataCacheMode[idx] = AccessMode::WRITE_ONLY;
        verifyWithCurrentModes();
        dataCacheMode[idx] = AccessMode::READ_WRITE;
    }
}
800 
801 // Copy file contents between files.
802 // The vector sizes must match.
copyCacheFiles(const std::vector<std::string> & from,const std::vector<std::string> & to)803 static void copyCacheFiles(const std::vector<std::string>& from,
804                            const std::vector<std::string>& to) {
805     constexpr size_t kBufferSize = 1000000;
806     uint8_t buffer[kBufferSize];
807 
808     ASSERT_EQ(from.size(), to.size());
809     for (uint32_t i = 0; i < from.size(); i++) {
810         int fromFd = open(from[i].c_str(), O_RDONLY);
811         int toFd = open(to[i].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
812         ASSERT_GE(fromFd, 0);
813         ASSERT_GE(toFd, 0);
814 
815         ssize_t readBytes;
816         while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
817             ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
818         }
819         ASSERT_GE(readBytes, 0);
820 
821         close(fromFd);
822         close(toFd);
823     }
824 }
825 
// Number of operations in the large test models passed to createLargeTestModel.
constexpr uint32_t kLargeModelSize{100};
// How many times each time-of-check/time-of-use (TOCTOU) test repeats its racy scenario.
constexpr uint32_t kNumIterationsTOCTOU{100};
829 
// Time-of-check/time-of-use test for saving. The driver saves modelAdd's compilation into the
// files listed in mModelCache. Meanwhile, another thread copies modelMul's saved model cache over
// those same files. The following preparation from cache may succeed or fail, but must not crash.
// If it succeeds, the prepared model must compute modelAdd's expected results.
TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    // modelMul gets its own model cache files (the mModelCache paths with a "_mul" suffix). The
    // data cache files in mDataCache are shared with modelAdd.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache.append("_mul");
    }
    {
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        createCacheFds(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
            createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while saving to cache.
            // std::cref passes the path vectors by reference instead of copying them into the
            // thread. Both outlive the thread because it is joined before this scope ends.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            saveModelToCache(modelAdd, modelCache, dataCache);
            thread.join();
        }

        // Retrieve preparedModel from cache.
        {
            std::shared_ptr<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
            createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
                                  /*fallbackModel=*/nullptr);

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
892 
// Time-of-check/time-of-use test for preparing. After modelAdd's compilation has been saved to
// the files listed in mModelCache, another thread copies modelMul's saved model cache over those
// same files while the driver is preparing from them. The preparation may succeed or fail, but
// must not crash. If it succeeds, the prepared model must compute modelAdd's expected results.
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    // modelMul gets its own model cache files (the mModelCache paths with a "_mul" suffix). The
    // data cache files in mDataCache are shared with modelAdd.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache.append("_mul");
    }
    {
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        createCacheFds(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
            createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(modelAdd, modelCache, dataCache);
        }

        // Retrieve preparedModel from cache.
        {
            std::shared_ptr<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
            createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while preparing from cache.
            // std::cref passes the path vectors by reference instead of copying them into the
            // thread. The thread is joined before status and preparedModel are inspected.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status,
                                  /*fallbackModel=*/nullptr);
            thread.join();

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
955 
// Overwrites modelAdd's model cache files with the model cache saved for modelMul (under a
// different token). Preparing from the tampered cache must then fail with GENERAL_FAILURE and
// must not return a prepared model.
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache, using the mModelCache paths with a "_mul" suffix.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache.append("_mul");
    }
    {
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        createCacheFds(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // Save the modelAdd compilation to cache.
    {
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelAdd, modelCache, dataCache);
    }

    // Replace the model cache of modelAdd with modelMul.
    // copyCacheFiles reports errors with ASSERT_*, which only return from copyCacheFiles itself.
    // Stop here if the copy failed, so the check below never runs against an untampered or
    // half-copied cache.
    ASSERT_NO_FATAL_FAILURE(copyCacheFiles(modelCacheMul, mModelCache));

    // Retrieve the preparedModel from cache, expect failure.
    {
        std::shared_ptr<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
        createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
        prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1005 
1006 // TODO(b/179270601): restore kNamedDeviceChoices.
1007 static const auto kOperandTypeChoices =
1008         testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1009 
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1010 std::string printCompilationCachingTest(
1011         const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1012     const auto& [namedDevice, operandType] = info.param;
1013     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1014     return gtestCompliantName(getName(namedDevice) + "_" + type);
1015 }
1016 
// Suppress gtest's error for a parameterized suite that ends up with no instantiated tests
// (e.g. when getNamedDevices() returns an empty list).
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingTest);
// Run every CompilationCachingTest once per (device, operand type) combination. The per-instance
// name suffix is built by printCompilationCachingTest.
INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
                         testing::Combine(testing::ValuesIn(getNamedDevices()),
                                          kOperandTypeChoices),
                         printCompilationCachingTest);
1022 
1023 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1024 
1025 class CompilationCachingSecurityTest
1026     : public CompilationCachingTestBase,
1027       public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1028   protected:
CompilationCachingSecurityTest()1029     CompilationCachingSecurityTest()
1030         : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1031                                      std::get<OperandType>(GetParam())) {}
1032 
SetUp()1033     void SetUp() {
1034         CompilationCachingTestBase::SetUp();
1035         generator.seed(kSeed);
1036     }
1037 
1038     // Get a random integer within a closed range [lower, upper].
1039     template <typename T>
getRandomInt(T lower,T upper)1040     T getRandomInt(T lower, T upper) {
1041         std::uniform_int_distribution<T> dis(lower, upper);
1042         return dis(generator);
1043     }
1044 
1045     // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1046     void flipOneBitOfCache(const std::string& filename, bool* skip) {
1047         FILE* pFile = fopen(filename.c_str(), "r+");
1048         ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1049         long int fileSize = ftell(pFile);
1050         if (fileSize == 0) {
1051             fclose(pFile);
1052             *skip = true;
1053             return;
1054         }
1055         ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1056         int readByte = fgetc(pFile);
1057         ASSERT_NE(readByte, EOF);
1058         ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1059         ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1060         fclose(pFile);
1061         *skip = false;
1062     }
1063 
1064     // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1065     void appendBytesToCache(const std::string& filename, bool* skip) {
1066         FILE* pFile = fopen(filename.c_str(), "a");
1067         uint32_t appendLength = getRandomInt(1, 256);
1068         for (uint32_t i = 0; i < appendLength; i++) {
1069             ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1070         }
1071         fclose(pFile);
1072         *skip = false;
1073     }
1074 
1075     enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1076 
1077     // Test if the driver behaves as expected when given corrupted cache or token.
1078     // The modifier will be invoked after save to cache but before prepare from cache.
1079     // The modifier accepts one pointer argument "skip" as the returning value, indicating
1080     // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1081     void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1082         const TestModel& testModel = createTestModel();
1083         const Model model = createModel(testModel);
1084         if (checkEarlyTermination(model)) return;
1085 
1086         // Save the compilation to cache.
1087         {
1088             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
1089             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
1090             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
1091             saveModelToCache(model, modelCache, dataCache);
1092         }
1093 
1094         bool skip = false;
1095         modifier(&skip);
1096         if (skip) return;
1097 
1098         // Retrieve preparedModel from cache.
1099         {
1100             std::shared_ptr<IPreparedModel> preparedModel = nullptr;
1101             ErrorStatus status;
1102             std::vector<ndk::ScopedFileDescriptor> modelCache, dataCache;
1103             createCacheFds(mModelCache, AccessMode::READ_WRITE, &modelCache);
1104             createCacheFds(mDataCache, AccessMode::READ_WRITE, &dataCache);
1105             prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1106 
1107             switch (expected) {
1108                 case ExpectedResult::GENERAL_FAILURE:
1109                     ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1110                     ASSERT_EQ(preparedModel, nullptr);
1111                     break;
1112                 case ExpectedResult::NOT_CRASH:
1113                     ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1114                     break;
1115                 default:
1116                     FAIL();
1117             }
1118         }
1119     }
1120 
1121     const uint32_t kSeed = std::get<uint32_t>(GetParam());
1122     std::mt19937 generator;
1123 };
1124 
// Flips one random bit in one model cache file per round. Preparing from that cache must then
// fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t fileIndex = 0; fileIndex < mNumModelCache; ++fileIndex) {
        const auto corruptOneFile = [this, fileIndex](bool* skip) {
            flipOneBitOfCache(mModelCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptOneFile);
    }
}
1132 
// Appends random bytes to one model cache file per round. Preparing from that cache must then
// fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t fileIndex = 0; fileIndex < mNumModelCache; ++fileIndex) {
        const auto growOneFile = [this, fileIndex](bool* skip) {
            appendBytesToCache(mModelCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, growOneFile);
    }
}
1140 
// Flips one random bit in one data cache file per round. Preparation may succeed or fail, but a
// model must be returned exactly when the reported status is NONE.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t fileIndex = 0; fileIndex < mNumDataCache; ++fileIndex) {
        const auto corruptOneFile = [this, fileIndex](bool* skip) {
            flipOneBitOfCache(mDataCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corruptOneFile);
    }
}
1148 
// Appends random bytes to one data cache file per round. Preparation may succeed or fail, but a
// model must be returned exactly when the reported status is NONE.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t fileIndex = 0; fileIndex < mNumDataCache; ++fileIndex) {
        const auto growOneFile = [this, fileIndex](bool* skip) {
            appendBytesToCache(mDataCache[fileIndex], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, growOneFile);
    }
}
1156 
// Flips one random bit of mToken between saving and preparing. Preparing with the altered token
// must fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    const auto flipOneBitOfToken = [this](bool* skip) {
        // Draw the byte index first and the bit index second.
        const uint32_t byteIndex =
                getRandomInt(0u, static_cast<uint32_t>(IDevice::BYTE_SIZE_OF_CACHE_TOKEN) - 1);
        const uint32_t bitIndex = getRandomInt(0, 7);
        mToken[byteIndex] ^= (1U << bitIndex);
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, flipOneBitOfToken);
}
1167 
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1168 std::string printCompilationCachingSecurityTest(
1169         const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1170     const auto& [namedDevice, operandType, seed] = info.param;
1171     const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1172     return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1173 }
1174 
// Suppress gtest's error for a parameterized suite that ends up with no instantiated tests
// (e.g. when getNamedDevices() returns an empty list).
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingSecurityTest);
// Run every CompilationCachingSecurityTest for each (device, operand type, seed) combination.
// testing::Range excludes its end value, so the seeds are 0 through 9. The per-instance name
// suffix is built by printCompilationCachingSecurityTest.
INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingSecurityTest,
                         testing::Combine(testing::ValuesIn(getNamedDevices()), kOperandTypeChoices,
                                          testing::Range(0U, 10U)),
                         printCompilationCachingSecurityTest);
1180 
1181 }  // namespace aidl::android::hardware::neuralnetworks::vts::functional
1182