1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18
#include <android-base/logging.h>
#include <fcntl.h>
#include <ftw.h>
#include <gtest/gtest.h>
#include <hidlmemory/mapping.h>
#include <unistd.h>

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <random>
#include <thread>

#include "1.3/Callbacks.h"
#include "1.3/Utils.h"
#include "GeneratedTestHarness.h"
#include "MemoryUtils.h"
#include "TestHarness.h"
#include "Utils.h"
#include "VtsHalNeuralnetworks.h"
38
39 // Forward declaration of the mobilenet generated test models in
40 // frameworks/ml/nn/runtime/test/generated/.
41 namespace generated_tests::mobilenet_224_gender_basic_fixed {
42 const test_helper::TestModel& get_test_model();
43 } // namespace generated_tests::mobilenet_224_gender_basic_fixed
44
45 namespace generated_tests::mobilenet_quantized {
46 const test_helper::TestModel& get_test_model();
47 } // namespace generated_tests::mobilenet_quantized
48
49 namespace android::hardware::neuralnetworks::V1_3::vts::functional {
50
51 using namespace test_helper;
52 using implementation::PreparedModelCallback;
53 using V1_1::ExecutionPreference;
54 using V1_2::Constant;
55 using V1_2::OperationType;
56
57 namespace float32_model {
58
59 constexpr auto get_test_model = generated_tests::mobilenet_224_gender_basic_fixed::get_test_model;
60
61 } // namespace float32_model
62
63 namespace quant8_model {
64
65 constexpr auto get_test_model = generated_tests::mobilenet_quantized::get_test_model;
66
67 } // namespace quant8_model
68
69 namespace {
70
// How createCacheHandles opens each cache file.
enum class AccessMode {
    READ_WRITE,  // open with O_RDWR, creating the file if missing
    READ_ONLY,   // open with O_RDONLY
    WRITE_ONLY,  // open with O_WRONLY, creating the file if missing
};
72
73 // Creates cache handles based on provided file groups.
74 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,const std::vector<AccessMode> & mode,hidl_vec<hidl_handle> * handles)75 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups,
76 const std::vector<AccessMode>& mode, hidl_vec<hidl_handle>* handles) {
77 handles->resize(fileGroups.size());
78 for (uint32_t i = 0; i < fileGroups.size(); i++) {
79 std::vector<int> fds;
80 for (const auto& file : fileGroups[i]) {
81 int fd;
82 if (mode[i] == AccessMode::READ_ONLY) {
83 fd = open(file.c_str(), O_RDONLY);
84 } else if (mode[i] == AccessMode::WRITE_ONLY) {
85 fd = open(file.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
86 } else if (mode[i] == AccessMode::READ_WRITE) {
87 fd = open(file.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
88 } else {
89 FAIL();
90 }
91 ASSERT_GE(fd, 0);
92 fds.push_back(fd);
93 }
94 native_handle_t* cacheNativeHandle = native_handle_create(fds.size(), 0);
95 ASSERT_NE(cacheNativeHandle, nullptr);
96 std::copy(fds.begin(), fds.end(), &cacheNativeHandle->data[0]);
97 (*handles)[i].setTo(cacheNativeHandle, /*shouldOwn=*/true);
98 }
99 }
100
createCacheHandles(const std::vector<std::vector<std::string>> & fileGroups,AccessMode mode,hidl_vec<hidl_handle> * handles)101 void createCacheHandles(const std::vector<std::vector<std::string>>& fileGroups, AccessMode mode,
102 hidl_vec<hidl_handle>* handles) {
103 createCacheHandles(fileGroups, std::vector<AccessMode>(fileGroups.size(), mode), handles);
104 }
105
106 // Create a chain of broadcast operations. The second operand is always constant tensor [1].
107 // For simplicity, activation scalar is shared. The second operand is not shared
108 // in the model to let driver maintain a non-trivial size of constant data and the corresponding
109 // data locations in cache.
110 //
111 // --------- activation --------
112 // ↓ ↓ ↓ ↓
113 // E.g. input -> ADD -> ADD -> ADD -> ... -> ADD -> output
114 // ↑ ↑ ↑ ↑
115 // [1] [1] [1] [1]
116 //
117 // This function assumes the operation is either ADD or MUL.
118 template <typename CppType, TestOperandType operandType>
createLargeTestModelImpl(TestOperationType op,uint32_t len)119 TestModel createLargeTestModelImpl(TestOperationType op, uint32_t len) {
120 EXPECT_TRUE(op == TestOperationType::ADD || op == TestOperationType::MUL);
121
122 // Model operations and operands.
123 std::vector<TestOperation> operations(len);
124 std::vector<TestOperand> operands(len * 2 + 2);
125
126 // The activation scalar, value = 0.
127 operands[0] = {
128 .type = TestOperandType::INT32,
129 .dimensions = {},
130 .numberOfConsumers = len,
131 .scale = 0.0f,
132 .zeroPoint = 0,
133 .lifetime = TestOperandLifeTime::CONSTANT_COPY,
134 .data = TestBuffer::createFromVector<int32_t>({0}),
135 };
136
137 // The buffer value of the constant second operand. The logical value is always 1.0f.
138 CppType bufferValue;
139 // The scale of the first and second operand.
140 float scale1, scale2;
141 if (operandType == TestOperandType::TENSOR_FLOAT32) {
142 bufferValue = 1.0f;
143 scale1 = 0.0f;
144 scale2 = 0.0f;
145 } else if (op == TestOperationType::ADD) {
146 bufferValue = 1;
147 scale1 = 1.0f;
148 scale2 = 1.0f;
149 } else {
150 // To satisfy the constraint on quant8 MUL: input0.scale * input1.scale < output.scale,
151 // set input1 to have scale = 0.5f and bufferValue = 2, i.e. 1.0f in floating point.
152 bufferValue = 2;
153 scale1 = 1.0f;
154 scale2 = 0.5f;
155 }
156
157 for (uint32_t i = 0; i < len; i++) {
158 const uint32_t firstInputIndex = i * 2 + 1;
159 const uint32_t secondInputIndex = firstInputIndex + 1;
160 const uint32_t outputIndex = secondInputIndex + 1;
161
162 // The first operation input.
163 operands[firstInputIndex] = {
164 .type = operandType,
165 .dimensions = {1},
166 .numberOfConsumers = 1,
167 .scale = scale1,
168 .zeroPoint = 0,
169 .lifetime = (i == 0 ? TestOperandLifeTime::MODEL_INPUT
170 : TestOperandLifeTime::TEMPORARY_VARIABLE),
171 .data = (i == 0 ? TestBuffer::createFromVector<CppType>({1}) : TestBuffer()),
172 };
173
174 // The second operation input, value = 1.
175 operands[secondInputIndex] = {
176 .type = operandType,
177 .dimensions = {1},
178 .numberOfConsumers = 1,
179 .scale = scale2,
180 .zeroPoint = 0,
181 .lifetime = TestOperandLifeTime::CONSTANT_COPY,
182 .data = TestBuffer::createFromVector<CppType>({bufferValue}),
183 };
184
185 // The operation. All operations share the same activation scalar.
186 // The output operand is created as an input in the next iteration of the loop, in the case
187 // of all but the last member of the chain; and after the loop as a model output, in the
188 // case of the last member of the chain.
189 operations[i] = {
190 .type = op,
191 .inputs = {firstInputIndex, secondInputIndex, /*activation scalar*/ 0},
192 .outputs = {outputIndex},
193 };
194 }
195
196 // For TestOperationType::ADD, output = 1 + 1 * len = len + 1
197 // For TestOperationType::MUL, output = 1 * 1 ^ len = 1
198 CppType outputResult = static_cast<CppType>(op == TestOperationType::ADD ? len + 1u : 1u);
199
200 // The model output.
201 operands.back() = {
202 .type = operandType,
203 .dimensions = {1},
204 .numberOfConsumers = 0,
205 .scale = scale1,
206 .zeroPoint = 0,
207 .lifetime = TestOperandLifeTime::MODEL_OUTPUT,
208 .data = TestBuffer::createFromVector<CppType>({outputResult}),
209 };
210
211 return {
212 .main = {.operands = std::move(operands),
213 .operations = std::move(operations),
214 .inputIndexes = {1},
215 .outputIndexes = {len * 2 + 1}},
216 .isRelaxed = false,
217 };
218 }
219
220 } // namespace
221
222 // Tag for the compilation caching tests.
223 class CompilationCachingTestBase : public testing::Test {
224 protected:
CompilationCachingTestBase(sp<IDevice> device,OperandType type)225 CompilationCachingTestBase(sp<IDevice> device, OperandType type)
226 : kDevice(std::move(device)), kOperandType(type) {}
227
SetUp()228 void SetUp() override {
229 testing::Test::SetUp();
230 ASSERT_NE(kDevice.get(), nullptr);
231 const bool deviceIsResponsive = kDevice->ping().isOk();
232 ASSERT_TRUE(deviceIsResponsive);
233
234 // Create cache directory. The cache directory and a temporary cache file is always created
235 // to test the behavior of prepareModelFromCache_1_3, even when caching is not supported.
236 char cacheDirTemp[] = "/data/local/tmp/TestCompilationCachingXXXXXX";
237 char* cacheDir = mkdtemp(cacheDirTemp);
238 ASSERT_NE(cacheDir, nullptr);
239 mCacheDir = cacheDir;
240 mCacheDir.push_back('/');
241
242 Return<void> ret = kDevice->getNumberOfCacheFilesNeeded(
243 [this](V1_0::ErrorStatus status, uint32_t numModelCache, uint32_t numDataCache) {
244 EXPECT_EQ(V1_0::ErrorStatus::NONE, status);
245 mNumModelCache = numModelCache;
246 mNumDataCache = numDataCache;
247 });
248 EXPECT_TRUE(ret.isOk());
249 mIsCachingSupported = mNumModelCache > 0 || mNumDataCache > 0;
250
251 // Create empty cache files.
252 mTmpCache = mCacheDir + "tmp";
253 for (uint32_t i = 0; i < mNumModelCache; i++) {
254 mModelCache.push_back({mCacheDir + "model" + std::to_string(i)});
255 }
256 for (uint32_t i = 0; i < mNumDataCache; i++) {
257 mDataCache.push_back({mCacheDir + "data" + std::to_string(i)});
258 }
259 // Dummy handles, use AccessMode::WRITE_ONLY for createCacheHandles to create files.
260 hidl_vec<hidl_handle> modelHandle, dataHandle, tmpHandle;
261 createCacheHandles(mModelCache, AccessMode::WRITE_ONLY, &modelHandle);
262 createCacheHandles(mDataCache, AccessMode::WRITE_ONLY, &dataHandle);
263 createCacheHandles({{mTmpCache}}, AccessMode::WRITE_ONLY, &tmpHandle);
264
265 if (!mIsCachingSupported) {
266 LOG(INFO) << "NN VTS: Early termination of test because vendor service does not "
267 "support compilation caching.";
268 std::cout << "[ ] Early termination of test because vendor service does not "
269 "support compilation caching."
270 << std::endl;
271 }
272 }
273
TearDown()274 void TearDown() override {
275 // If the test passes, remove the tmp directory. Otherwise, keep it for debugging purposes.
276 if (!testing::Test::HasFailure()) {
277 // Recursively remove the cache directory specified by mCacheDir.
278 auto callback = [](const char* entry, const struct stat*, int, struct FTW*) {
279 return remove(entry);
280 };
281 nftw(mCacheDir.c_str(), callback, 128, FTW_DEPTH | FTW_MOUNT | FTW_PHYS);
282 }
283 testing::Test::TearDown();
284 }
285
286 // Model and examples creators. According to kOperandType, the following methods will return
287 // either float32 model/examples or the quant8 variant.
createTestModel()288 TestModel createTestModel() {
289 if (kOperandType == OperandType::TENSOR_FLOAT32) {
290 return float32_model::get_test_model();
291 } else {
292 return quant8_model::get_test_model();
293 }
294 }
295
createLargeTestModel(OperationType op,uint32_t len)296 TestModel createLargeTestModel(OperationType op, uint32_t len) {
297 if (kOperandType == OperandType::TENSOR_FLOAT32) {
298 return createLargeTestModelImpl<float, TestOperandType::TENSOR_FLOAT32>(
299 static_cast<TestOperationType>(op), len);
300 } else {
301 return createLargeTestModelImpl<uint8_t, TestOperandType::TENSOR_QUANT8_ASYMM>(
302 static_cast<TestOperationType>(op), len);
303 }
304 }
305
306 // See if the service can handle the model.
isModelFullySupported(const Model & model)307 bool isModelFullySupported(const Model& model) {
308 bool fullySupportsModel = false;
309 Return<void> supportedCall = kDevice->getSupportedOperations_1_3(
310 model,
311 [&fullySupportsModel, &model](ErrorStatus status, const hidl_vec<bool>& supported) {
312 ASSERT_EQ(ErrorStatus::NONE, status);
313 ASSERT_EQ(supported.size(), model.main.operations.size());
314 fullySupportsModel = std::all_of(supported.begin(), supported.end(),
315 [](bool valid) { return valid; });
316 });
317 EXPECT_TRUE(supportedCall.isOk());
318 return fullySupportsModel;
319 }
320
saveModelToCache(const Model & model,const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel=nullptr)321 void saveModelToCache(const Model& model, const hidl_vec<hidl_handle>& modelCache,
322 const hidl_vec<hidl_handle>& dataCache,
323 sp<IPreparedModel>* preparedModel = nullptr) {
324 if (preparedModel != nullptr) *preparedModel = nullptr;
325
326 // Launch prepare model.
327 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
328 hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
329 Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModel_1_3(
330 model, ExecutionPreference::FAST_SINGLE_ANSWER, kDefaultPriority, {}, modelCache,
331 dataCache, cacheToken, preparedModelCallback);
332 ASSERT_TRUE(prepareLaunchStatus.isOk());
333 ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus), ErrorStatus::NONE);
334
335 // Retrieve prepared model.
336 preparedModelCallback->wait();
337 ASSERT_EQ(preparedModelCallback->getStatus(), ErrorStatus::NONE);
338 if (preparedModel != nullptr) {
339 *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
340 .withDefault(nullptr);
341 }
342 }
343
checkEarlyTermination(ErrorStatus status)344 bool checkEarlyTermination(ErrorStatus status) {
345 if (status == ErrorStatus::GENERAL_FAILURE) {
346 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
347 "save the prepared model that it does not support.";
348 std::cout << "[ ] Early termination of test because vendor service cannot "
349 "save the prepared model that it does not support."
350 << std::endl;
351 return true;
352 }
353 return false;
354 }
355
checkEarlyTermination(const Model & model)356 bool checkEarlyTermination(const Model& model) {
357 if (!isModelFullySupported(model)) {
358 LOG(INFO) << "NN VTS: Early termination of test because vendor service cannot "
359 "prepare model that it does not support.";
360 std::cout << "[ ] Early termination of test because vendor service cannot "
361 "prepare model that it does not support."
362 << std::endl;
363 return true;
364 }
365 return false;
366 }
367
prepareModelFromCache(const hidl_vec<hidl_handle> & modelCache,const hidl_vec<hidl_handle> & dataCache,sp<IPreparedModel> * preparedModel,ErrorStatus * status)368 void prepareModelFromCache(const hidl_vec<hidl_handle>& modelCache,
369 const hidl_vec<hidl_handle>& dataCache,
370 sp<IPreparedModel>* preparedModel, ErrorStatus* status) {
371 // Launch prepare model from cache.
372 sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
373 hidl_array<uint8_t, sizeof(mToken)> cacheToken(mToken);
374 Return<ErrorStatus> prepareLaunchStatus = kDevice->prepareModelFromCache_1_3(
375 {}, modelCache, dataCache, cacheToken, preparedModelCallback);
376 ASSERT_TRUE(prepareLaunchStatus.isOk());
377 if (static_cast<ErrorStatus>(prepareLaunchStatus) != ErrorStatus::NONE) {
378 *preparedModel = nullptr;
379 *status = static_cast<ErrorStatus>(prepareLaunchStatus);
380 return;
381 }
382
383 // Retrieve prepared model.
384 preparedModelCallback->wait();
385 *status = preparedModelCallback->getStatus();
386 *preparedModel = IPreparedModel::castFrom(preparedModelCallback->getPreparedModel())
387 .withDefault(nullptr);
388 }
389
390 // Absolute path to the temporary cache directory.
391 std::string mCacheDir;
392
393 // Groups of file paths for model and data cache in the tmp cache directory, initialized with
394 // outer_size = mNum{Model|Data}Cache, inner_size = 1. The outer vector corresponds to handles
395 // and the inner vector is for fds held by each handle.
396 std::vector<std::vector<std::string>> mModelCache;
397 std::vector<std::vector<std::string>> mDataCache;
398
399 // A separate temporary file path in the tmp cache directory.
400 std::string mTmpCache;
401
402 uint8_t mToken[static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN)] = {};
403 uint32_t mNumModelCache;
404 uint32_t mNumDataCache;
405 uint32_t mIsCachingSupported;
406
407 const sp<IDevice> kDevice;
408 // The primary data type of the testModel.
409 const OperandType kOperandType;
410 };
411
412 using CompilationCachingTestParam = std::tuple<NamedDevice, OperandType>;
413
414 // A parameterized fixture of CompilationCachingTestBase. Every test will run twice, with the first
415 // pass running with float32 models and the second pass running with quant8 models.
416 class CompilationCachingTest : public CompilationCachingTestBase,
417 public testing::WithParamInterface<CompilationCachingTestParam> {
418 protected:
CompilationCachingTest()419 CompilationCachingTest()
420 : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
421 std::get<OperandType>(GetParam())) {}
422 };
423
TEST_P(CompilationCachingTest, CacheSavingAndRetrieval) {
    // Build the HIDL model from the test model; stop early if the driver cannot run it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Step 1: compile while handing the driver writable cache handles.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(model, modelHandles, dataHandles);
    }

    // Step 2: rebuild the prepared model purely from the cache.
    sp<IPreparedModel> preparedModel = nullptr;
    {
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (!mIsCachingSupported) {
            // A driver without caching support must refuse with GENERAL_FAILURE.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        ASSERT_EQ(status, ErrorStatus::NONE);
        ASSERT_NE(preparedModel, nullptr);
    }

    // Step 3: the model restored from cache must still produce the expected results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
463
TEST_P(CompilationCachingTest, CacheSavingAndRetrievalNonZeroOffset) {
    // Build the HIDL model from the test model; stop early if the driver cannot run it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Step 1: prefix every cache file with two dummy bytes before saving, so the driver must cope
    // with non-empty cache files and fds whose offset is not zero.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        const uint8_t padding[] = {0, 0};
        for (const auto& handle : modelHandles) {
            ASSERT_EQ(write(handle.getNativeHandle()->data[0], padding, sizeof(padding)),
                      sizeof(padding));
        }
        for (const auto& handle : dataHandles) {
            ASSERT_EQ(write(handle.getNativeHandle()->data[0], padding, sizeof(padding)),
                      sizeof(padding));
        }
        saveModelToCache(model, modelHandles, dataHandles);
    }

    // Step 2: consume one byte from every cache fd, then rebuild the prepared model from cache.
    sp<IPreparedModel> preparedModel = nullptr;
    {
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        uint8_t scratch = 0;
        for (const auto& handle : modelHandles) {
            ASSERT_GE(read(handle.getNativeHandle()->data[0], &scratch, 1), 0);
        }
        for (const auto& handle : dataHandles) {
            ASSERT_GE(read(handle.getNativeHandle()->data[0], &scratch, 1), 0);
        }
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (!mIsCachingSupported) {
            // A driver without caching support must refuse with GENERAL_FAILURE.
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        if (checkEarlyTermination(status)) {
            ASSERT_EQ(preparedModel, nullptr);
            return;
        }
        ASSERT_EQ(status, ErrorStatus::NONE);
        ASSERT_NE(preparedModel, nullptr);
    }

    // Step 3: the model restored from cache must still produce the expected results.
    EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
}
525
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumCache) {
    // Build the HIDL model from the test model; stop early if the driver cannot run it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // In every scenario below:
    // - compiling with a wrong number of cache handles must still yield a working prepared model;
    // - restoring from those same handles must yield no model, with INVALID_ARGUMENT or
    //   GENERAL_FAILURE.

    // Scenario 1: one model cache handle more than mNumModelCache.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        mModelCache.push_back({mTmpCache});  // temporarily append an extra file group
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mModelCache.pop_back();

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Scenario 2: one model cache handle fewer than mNumModelCache (when there is one to drop).
    if (!mModelCache.empty()) {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        std::vector<std::string> droppedGroup = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mModelCache.push_back(std::move(droppedGroup));

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Scenario 3: one data cache handle more than mNumDataCache.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        mDataCache.push_back({mTmpCache});  // temporarily append an extra file group
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mDataCache.pop_back();

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Scenario 4: one data cache handle fewer than mNumDataCache (when there is one to drop).
    if (!mDataCache.empty()) {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        std::vector<std::string> droppedGroup = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mDataCache.push_back(std::move(droppedGroup));

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
626
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumCache) {
    // Build the HIDL model from the test model; stop early if the driver cannot run it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Populate the cache using the correct number of handles first.
    {
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        saveModelToCache(model, modelHandles, dataHandles);
    }

    // Each scenario below restores from a wrong number of cache handles. The driver must return
    // no prepared model and report GENERAL_FAILURE or INVALID_ARGUMENT.

    // Scenario 1: one model cache handle more than mNumModelCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        mModelCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mModelCache.pop_back();
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Scenario 2: one model cache handle fewer than mNumModelCache (when there is one to drop).
    if (!mModelCache.empty()) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        std::vector<std::string> droppedGroup = mModelCache.back();
        mModelCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mModelCache.push_back(std::move(droppedGroup));
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Scenario 3: one data cache handle more than mNumDataCache.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        mDataCache.push_back({mTmpCache});
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mDataCache.pop_back();
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Scenario 4: one data cache handle fewer than mNumDataCache (when there is one to drop).
    if (!mDataCache.empty()) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        std::vector<std::string> droppedGroup = mDataCache.back();
        mDataCache.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        mDataCache.push_back(std::move(droppedGroup));
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
707
TEST_P(CompilationCachingTest, SaveToCacheInvalidNumFd) {
    // Build the HIDL model from the test model; stop early if the driver cannot run it.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // For each handle, corrupt its fd count (2 fds or 0 fds instead of 1). In every case:
    // - compiling must still yield a working prepared model;
    // - restoring from the same handles must yield no model, with INVALID_ARGUMENT or
    //   GENERAL_FAILURE.

    // Model cache handle i carries an extra fd.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        std::vector<std::string>& files = mModelCache[i];
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        files.push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        files.pop_back();

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Model cache handle i carries no fd at all.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        std::vector<std::string>& files = mModelCache[i];
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        std::string removedFile = files.back();
        files.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        files.push_back(removedFile);

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Data cache handle i carries an extra fd.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        std::vector<std::string>& files = mDataCache[i];
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        files.push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        files.pop_back();

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Data cache handle i carries no fd at all.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        std::vector<std::string>& files = mDataCache[i];
        hidl_vec<hidl_handle> modelHandles;
        hidl_vec<hidl_handle> dataHandles;
        std::string removedFile = files.back();
        files.pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelHandles);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataHandles);
        files.push_back(removedFile);

        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelHandles, dataHandles, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);

        preparedModel = nullptr;
        ErrorStatus status;
        prepareModelFromCache(modelHandles, dataHandles, &preparedModel, &status);
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
808
// Saves a valid compilation to cache, then, for every model and data cache handle, builds the
// handle list with that one handle holding either an extra fd (mTmpCache) or no fd at all.
// prepareModelFromCache must reject each malformed list with GENERAL_FAILURE or
// INVALID_ARGUMENT and must not return a prepared model.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidNumFd) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Give handle i an extra fd only while the handles are created, then restore the list.
        mModelCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].pop_back();
        // A fatal assertion inside the helper only returns from the helper; stop here so that
        // `status` is only inspected when the helper completed without a fatal failure.
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in model cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Drop handle i's fd only while the handles are created, then restore the list.
        auto tmp = mModelCache[i].back();
        mModelCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mModelCache[i].push_back(tmp);
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd greater than 1.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        mDataCache[i].push_back(mTmpCache);
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].pop_back();
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with NumFd equal to 0.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        auto tmp = mDataCache[i].back();
        mDataCache[i].pop_back();
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        mDataCache[i].push_back(tmp);
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (status != ErrorStatus::GENERAL_FAILURE) {
            ASSERT_EQ(status, ErrorStatus::INVALID_ARGUMENT);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
889
// For every model and data cache handle, saves a compilation with that one handle opened
// READ_ONLY. Compilation itself must still succeed and the returned prepared model must produce
// correct results, but preparing from the same handles afterwards must fail with
// INVALID_ARGUMENT or GENERAL_FAILURE and return no prepared model.
TEST_P(CompilationCachingTest, SaveToCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Only handle i is READ_ONLY; the mode is reset right after the handles are created.
        modelCacheMode[i] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        // A fatal assertion inside the helper only returns from the helper; stop here so that
        // `status` is only inspected when the helper completed without a fatal failure.
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        hidl_vec<hidl_handle> modelCache, dataCache;
        dataCacheMode[i] = AccessMode::READ_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        sp<IPreparedModel> preparedModel = nullptr;
        saveModelToCache(model, modelCache, dataCache, &preparedModel);
        ASSERT_NE(preparedModel, nullptr);
        // Execute and verify results.
        EvaluatePreparedModel(kDevice, preparedModel, testModel, /*testKind=*/TestKind::GENERAL);
        // Check if prepareModelFromCache fails.
        preparedModel = nullptr;
        ErrorStatus status;
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        if (status != ErrorStatus::INVALID_ARGUMENT) {
            ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        }
        ASSERT_EQ(preparedModel, nullptr);
    }
}
942
// Saves a valid compilation to cache, then, for every model and data cache handle, tries to
// prepare from cache with that one handle opened WRITE_ONLY. Each attempt must fail with
// GENERAL_FAILURE and return no prepared model.
TEST_P(CompilationCachingTest, PrepareModelFromCacheInvalidAccessMode) {
    // Create test HIDL model and compile.
    const TestModel& testModel = createTestModel();
    const Model model = createModel(testModel);
    if (checkEarlyTermination(model)) return;
    std::vector<AccessMode> modelCacheMode(mNumModelCache, AccessMode::READ_WRITE);
    std::vector<AccessMode> dataCacheMode(mNumDataCache, AccessMode::READ_WRITE);

    // Save the compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(model, modelCache, dataCache);
    }

    // Go through each handle in model cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumModelCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        // Only handle i is WRITE_ONLY; the mode is reset right after the handles are created.
        modelCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        modelCacheMode[i] = AccessMode::READ_WRITE;
        // A fatal assertion inside the helper only returns from the helper; stop here so that
        // `status` is only inspected when the helper completed without a fatal failure.
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }

    // Go through each handle in data cache, test with invalid access mode.
    for (uint32_t i = 0; i < mNumDataCache; i++) {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        dataCacheMode[i] = AccessMode::WRITE_ONLY;
        createCacheHandles(mModelCache, modelCacheMode, &modelCache);
        createCacheHandles(mDataCache, dataCacheMode, &dataCache);
        dataCacheMode[i] = AccessMode::READ_WRITE;
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
987
988 // Copy file contents between file groups.
989 // The outer vector corresponds to handles and the inner vector is for fds held by each handle.
990 // The outer vector sizes must match and the inner vectors must have size = 1.
copyCacheFiles(const std::vector<std::vector<std::string>> & from,const std::vector<std::vector<std::string>> & to)991 static void copyCacheFiles(const std::vector<std::vector<std::string>>& from,
992 const std::vector<std::vector<std::string>>& to) {
993 constexpr size_t kBufferSize = 1000000;
994 uint8_t buffer[kBufferSize];
995
996 ASSERT_EQ(from.size(), to.size());
997 for (uint32_t i = 0; i < from.size(); i++) {
998 ASSERT_EQ(from[i].size(), 1u);
999 ASSERT_EQ(to[i].size(), 1u);
1000 int fromFd = open(from[i][0].c_str(), O_RDONLY);
1001 int toFd = open(to[i][0].c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
1002 ASSERT_GE(fromFd, 0);
1003 ASSERT_GE(toFd, 0);
1004
1005 ssize_t readBytes;
1006 while ((readBytes = read(fromFd, &buffer, kBufferSize)) > 0) {
1007 ASSERT_EQ(write(toFd, &buffer, readBytes), readBytes);
1008 }
1009 ASSERT_GE(readBytes, 0);
1010
1011 close(fromFd);
1012 close(toFd);
1013 }
1014 }
1015
// Number of operations in the large MUL/ADD test models built by the TOCTOU and
// cache-replacement tests.
constexpr uint32_t kLargeModelSize{100};
// How many times each probabilistic TOCTOU scenario is repeated.
constexpr uint32_t kNumIterationsTOCTOU{100};
1019
// Saves a large MUL model's compilation into separate "_mul" model cache files. Then, under a
// different token, it repeatedly saves a large ADD model while another thread concurrently
// copies the MUL model cache over the ADD model cache files. Preparing from cache afterwards may
// fail or succeed. If it succeeds, the prepared model must compute the ADD model's results.
TEST_P(CompilationCachingTest, SaveToCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while saving to cache.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            saveModelToCache(modelAdd, modelCache, dataCache);
            thread.join();
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            // No thread is running at this point, so it is safe to bail out here. Doing so
            // means `status` is only inspected when the helper finished without a fatal failure.
            ASSERT_NO_FATAL_FAILURE(
                    prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
1081
// Saves a large MUL model's compilation into separate "_mul" model cache files. Then, under a
// different token, it repeatedly saves a large ADD model and prepares it from cache while
// another thread concurrently copies the MUL model cache over the ADD model cache files. The
// preparation may fail or succeed. If it succeeds, the prepared model must compute the ADD
// model's results.
TEST_P(CompilationCachingTest, PrepareFromCache_TOCTOU) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // This test is probabilistic, so we run it multiple times.
    for (uint32_t i = 0; i < kNumIterationsTOCTOU; i++) {
        // Save the modelAdd compilation to cache.
        {
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
            saveModelToCache(modelAdd, modelCache, dataCache);
        }

        // Retrieve preparedModel from cache.
        {
            sp<IPreparedModel> preparedModel = nullptr;
            ErrorStatus status;
            hidl_vec<hidl_handle> modelCache, dataCache;
            createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
            createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);

            // Spawn a thread to copy the cache content concurrently while preparing from cache.
            std::thread thread(copyCacheFiles, std::cref(modelCacheMul), std::cref(mModelCache));
            prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
            thread.join();
            // Destroying a joinable std::thread calls std::terminate, so the helper cannot be
            // wrapped in ASSERT_NO_FATAL_FAILURE before join(). Check afterwards instead, so
            // that `status` is only inspected when no fatal failure has occurred.
            if (HasFatalFailure()) return;

            // The preparation may fail or succeed, but must not crash. If the preparation succeeds,
            // the prepared model must be executed with the correct result and not crash.
            if (status != ErrorStatus::NONE) {
                ASSERT_EQ(preparedModel, nullptr);
            } else {
                ASSERT_NE(preparedModel, nullptr);
                EvaluatePreparedModel(kDevice, preparedModel, testModelAdd,
                                      /*testKind=*/TestKind::GENERAL);
            }
        }
    }
}
1143
// Saves a large MUL model and a large ADD model (under different tokens and model cache files).
// It then overwrites the ADD model's model cache files with the MUL model's contents.
// Preparing the ADD model from the tampered cache must fail with GENERAL_FAILURE.
TEST_P(CompilationCachingTest, ReplaceSecuritySensitiveCache) {
    if (!mIsCachingSupported) return;

    // Create test models and check if fully supported by the service.
    const TestModel testModelMul = createLargeTestModel(OperationType::MUL, kLargeModelSize);
    const Model modelMul = createModel(testModelMul);
    if (checkEarlyTermination(modelMul)) return;
    const TestModel testModelAdd = createLargeTestModel(OperationType::ADD, kLargeModelSize);
    const Model modelAdd = createModel(testModelAdd);
    if (checkEarlyTermination(modelAdd)) return;

    // Save the modelMul compilation to cache.
    auto modelCacheMul = mModelCache;
    for (auto& cache : modelCacheMul) {
        cache[0].append("_mul");
    }
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(modelCacheMul, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelMul, modelCache, dataCache);
    }

    // Use a different token for modelAdd.
    mToken[0]++;

    // Save the modelAdd compilation to cache.
    {
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        saveModelToCache(modelAdd, modelCache, dataCache);
    }

    // Replace the model cache of modelAdd with modelMul. If the copy fails, the cache is not in
    // the tampered state this test is meant to check, so stop instead of checking it.
    ASSERT_NO_FATAL_FAILURE(copyCacheFiles(modelCacheMul, mModelCache));

    // Retrieve the preparedModel from cache, expect failure.
    {
        sp<IPreparedModel> preparedModel = nullptr;
        ErrorStatus status;
        hidl_vec<hidl_handle> modelCache, dataCache;
        createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
        createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
        // Only inspect `status` if the helper completed without a fatal failure.
        ASSERT_NO_FATAL_FAILURE(
                prepareModelFromCache(modelCache, dataCache, &preparedModel, &status));
        ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
        ASSERT_EQ(preparedModel, nullptr);
    }
}
1193
1194 static const auto kNamedDeviceChoices = testing::ValuesIn(getNamedDevices());
1195 static const auto kOperandTypeChoices =
1196 testing::Values(OperandType::TENSOR_FLOAT32, OperandType::TENSOR_QUANT8_ASYMM);
1197
printCompilationCachingTest(const testing::TestParamInfo<CompilationCachingTestParam> & info)1198 std::string printCompilationCachingTest(
1199 const testing::TestParamInfo<CompilationCachingTestParam>& info) {
1200 const auto& [namedDevice, operandType] = info.param;
1201 const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1202 return gtestCompliantName(getName(namedDevice) + "_" + type);
1203 }
1204
1205 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingTest);
1206 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingTest,
1207 testing::Combine(kNamedDeviceChoices, kOperandTypeChoices),
1208 printCompilationCachingTest);
1209
1210 using CompilationCachingSecurityTestParam = std::tuple<NamedDevice, OperandType, uint32_t>;
1211
1212 class CompilationCachingSecurityTest
1213 : public CompilationCachingTestBase,
1214 public testing::WithParamInterface<CompilationCachingSecurityTestParam> {
1215 protected:
CompilationCachingSecurityTest()1216 CompilationCachingSecurityTest()
1217 : CompilationCachingTestBase(getData(std::get<NamedDevice>(GetParam())),
1218 std::get<OperandType>(GetParam())) {}
1219
SetUp()1220 void SetUp() {
1221 CompilationCachingTestBase::SetUp();
1222 generator.seed(kSeed);
1223 }
1224
1225 // Get a random integer within a closed range [lower, upper].
1226 template <typename T>
getRandomInt(T lower,T upper)1227 T getRandomInt(T lower, T upper) {
1228 std::uniform_int_distribution<T> dis(lower, upper);
1229 return dis(generator);
1230 }
1231
1232 // Randomly flip one single bit of the cache entry.
flipOneBitOfCache(const std::string & filename,bool * skip)1233 void flipOneBitOfCache(const std::string& filename, bool* skip) {
1234 FILE* pFile = fopen(filename.c_str(), "r+");
1235 ASSERT_EQ(fseek(pFile, 0, SEEK_END), 0);
1236 long int fileSize = ftell(pFile);
1237 if (fileSize == 0) {
1238 fclose(pFile);
1239 *skip = true;
1240 return;
1241 }
1242 ASSERT_EQ(fseek(pFile, getRandomInt(0l, fileSize - 1), SEEK_SET), 0);
1243 int readByte = fgetc(pFile);
1244 ASSERT_NE(readByte, EOF);
1245 ASSERT_EQ(fseek(pFile, -1, SEEK_CUR), 0);
1246 ASSERT_NE(fputc(static_cast<uint8_t>(readByte) ^ (1U << getRandomInt(0, 7)), pFile), EOF);
1247 fclose(pFile);
1248 *skip = false;
1249 }
1250
1251 // Randomly append bytes to the cache entry.
appendBytesToCache(const std::string & filename,bool * skip)1252 void appendBytesToCache(const std::string& filename, bool* skip) {
1253 FILE* pFile = fopen(filename.c_str(), "a");
1254 uint32_t appendLength = getRandomInt(1, 256);
1255 for (uint32_t i = 0; i < appendLength; i++) {
1256 ASSERT_NE(fputc(getRandomInt<uint8_t>(0, 255), pFile), EOF);
1257 }
1258 fclose(pFile);
1259 *skip = false;
1260 }
1261
1262 enum class ExpectedResult { GENERAL_FAILURE, NOT_CRASH };
1263
1264 // Test if the driver behaves as expected when given corrupted cache or token.
1265 // The modifier will be invoked after save to cache but before prepare from cache.
1266 // The modifier accepts one pointer argument "skip" as the returning value, indicating
1267 // whether the test should be skipped or not.
testCorruptedCache(ExpectedResult expected,std::function<void (bool *)> modifier)1268 void testCorruptedCache(ExpectedResult expected, std::function<void(bool*)> modifier) {
1269 const TestModel& testModel = createTestModel();
1270 const Model model = createModel(testModel);
1271 if (checkEarlyTermination(model)) return;
1272
1273 // Save the compilation to cache.
1274 {
1275 hidl_vec<hidl_handle> modelCache, dataCache;
1276 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1277 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1278 saveModelToCache(model, modelCache, dataCache);
1279 }
1280
1281 bool skip = false;
1282 modifier(&skip);
1283 if (skip) return;
1284
1285 // Retrieve preparedModel from cache.
1286 {
1287 sp<IPreparedModel> preparedModel = nullptr;
1288 ErrorStatus status;
1289 hidl_vec<hidl_handle> modelCache, dataCache;
1290 createCacheHandles(mModelCache, AccessMode::READ_WRITE, &modelCache);
1291 createCacheHandles(mDataCache, AccessMode::READ_WRITE, &dataCache);
1292 prepareModelFromCache(modelCache, dataCache, &preparedModel, &status);
1293
1294 switch (expected) {
1295 case ExpectedResult::GENERAL_FAILURE:
1296 ASSERT_EQ(status, ErrorStatus::GENERAL_FAILURE);
1297 ASSERT_EQ(preparedModel, nullptr);
1298 break;
1299 case ExpectedResult::NOT_CRASH:
1300 ASSERT_EQ(preparedModel == nullptr, status != ErrorStatus::NONE);
1301 break;
1302 default:
1303 FAIL();
1304 }
1305 }
1306 }
1307
1308 const uint32_t kSeed = std::get<uint32_t>(GetParam());
1309 std::mt19937 generator;
1310 };
1311
// Flipping one random bit in any single model cache file must make preparation from cache fail
// with GENERAL_FAILURE. Empty files are skipped by the modifier.
TEST_P(CompilationCachingSecurityTest, CorruptedModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; ++index) {
        const auto corruptEntry = [this, index](bool* skip) {
            flipOneBitOfCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, corruptEntry);
    }
}
1319
// Appending random trailing bytes to any single model cache file must make preparation from
// cache fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongLengthModelCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumModelCache; ++index) {
        const auto growEntry = [this, index](bool* skip) {
            appendBytesToCache(mModelCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::GENERAL_FAILURE, growEntry);
    }
}
1327
// Flipping one random bit in any single data cache file may or may not be detected. The driver
// must not crash, and it must report a non-NONE status exactly when it returns no model.
TEST_P(CompilationCachingSecurityTest, CorruptedDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; ++index) {
        const auto corruptEntry = [this, index](bool* skip) {
            flipOneBitOfCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, corruptEntry);
    }
}
1335
// Appending random trailing bytes to any single data cache file may or may not be detected. The
// driver must not crash, and it must report a non-NONE status exactly when it returns no model.
TEST_P(CompilationCachingSecurityTest, WrongLengthDataCache) {
    if (!mIsCachingSupported) return;
    for (uint32_t index = 0; index < mNumDataCache; ++index) {
        const auto growEntry = [this, index](bool* skip) {
            appendBytesToCache(mDataCache[index][0], skip);
        };
        testCorruptedCache(ExpectedResult::NOT_CRASH, growEntry);
    }
}
1343
// Preparing from cache with a token that differs by one random bit from the token used for
// saving must fail with GENERAL_FAILURE.
TEST_P(CompilationCachingSecurityTest, WrongToken) {
    if (!mIsCachingSupported) return;
    const auto flipTokenBit = [this](bool* skip) {
        // Pick a random byte of mToken first, then a random bit within it, and flip that bit.
        const uint32_t lastByte = static_cast<uint32_t>(Constant::BYTE_SIZE_OF_CACHE_TOKEN) - 1;
        const uint32_t byteIndex = getRandomInt(0u, lastByte);
        mToken[byteIndex] ^= (1U << getRandomInt(0, 7));
        *skip = false;
    };
    testCorruptedCache(ExpectedResult::GENERAL_FAILURE, flipTokenBit);
}
1354
printCompilationCachingSecurityTest(const testing::TestParamInfo<CompilationCachingSecurityTestParam> & info)1355 std::string printCompilationCachingSecurityTest(
1356 const testing::TestParamInfo<CompilationCachingSecurityTestParam>& info) {
1357 const auto& [namedDevice, operandType, seed] = info.param;
1358 const std::string type = (operandType == OperandType::TENSOR_FLOAT32 ? "float32" : "quant8");
1359 return gtestCompliantName(getName(namedDevice) + "_" + type + "_" + std::to_string(seed));
1360 }
1361
1362 GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(CompilationCachingSecurityTest);
1363 INSTANTIATE_TEST_SUITE_P(TestCompilationCaching, CompilationCachingSecurityTest,
1364 testing::Combine(kNamedDeviceChoices, kOperandTypeChoices,
1365 testing::Range(0U, 10U)),
1366 printCompilationCachingSecurityTest);
1367
1368 } // namespace android::hardware::neuralnetworks::V1_3::vts::functional
1369