1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "MockBuffer.h"
18 #include "MockDevice.h"
19 #include "MockPreparedModel.h"
20 
21 #include <aidl/android/hardware/neuralnetworks/BnDevice.h>
22 #include <android/binder_auto_utils.h>
23 #include <android/binder_status.h>
24 #include <gmock/gmock.h>
25 #include <gtest/gtest.h>
26 #include <nnapi/IDevice.h>
27 #include <nnapi/TypeUtils.h>
28 #include <nnapi/Types.h>
29 #include <nnapi/hal/aidl/Device.h>
30 
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>
34 
35 namespace aidl::android::hardware::neuralnetworks::utils {
36 namespace {
37 
38 namespace nn = ::android::nn;
39 using ::testing::_;
40 using ::testing::DoAll;
41 using ::testing::Invoke;
42 using ::testing::InvokeWithoutArgs;
43 using ::testing::SetArgPointee;
44 
45 const nn::Model kSimpleModel = {
46         .main = {.operands = {{.type = nn::OperandType::TENSOR_FLOAT32,
47                                .dimensions = {1},
48                                .lifetime = nn::Operand::LifeTime::SUBGRAPH_INPUT},
49                               {.type = nn::OperandType::TENSOR_FLOAT32,
50                                .dimensions = {1},
51                                .lifetime = nn::Operand::LifeTime::SUBGRAPH_OUTPUT}},
52                  .operations = {{.type = nn::OperationType::RELU, .inputs = {0}, .outputs = {1}}},
53                  .inputIndexes = {0},
54                  .outputIndexes = {1}}};
55 
56 const std::string kName = "Google-MockV1";
57 const std::string kInvalidName = "";
58 const std::shared_ptr<BnDevice> kInvalidDevice;
59 constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<float>::max(),
60                                                 .powerUsage = std::numeric_limits<float>::max()};
61 constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles - 1,
62                                                     .numDataCache = nn::kMaxNumberOfCacheFiles};
63 
__anonacec0d670202null64 constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
65 
createMockDevice()66 std::shared_ptr<MockDevice> createMockDevice() {
67     const auto mockDevice = MockDevice::create();
68 
69     // Setup default actions for each relevant call.
70     ON_CALL(*mockDevice, getVersionString(_))
71             .WillByDefault(DoAll(SetArgPointee<0>(kName), InvokeWithoutArgs(makeStatusOk)));
72     ON_CALL(*mockDevice, getType(_))
73             .WillByDefault(
74                     DoAll(SetArgPointee<0>(DeviceType::OTHER), InvokeWithoutArgs(makeStatusOk)));
75     ON_CALL(*mockDevice, getSupportedExtensions(_))
76             .WillByDefault(DoAll(SetArgPointee<0>(std::vector<Extension>{}),
77                                  InvokeWithoutArgs(makeStatusOk)));
78     ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
79             .WillByDefault(
80                     DoAll(SetArgPointee<0>(kNumberOfCacheFiles), InvokeWithoutArgs(makeStatusOk)));
81     ON_CALL(*mockDevice, getCapabilities(_))
82             .WillByDefault(
83                     DoAll(SetArgPointee<0>(Capabilities{
84                                   .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
85                                   .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
86                                   .ifPerformance = kNoPerformanceInfo,
87                                   .whilePerformance = kNoPerformanceInfo,
88                           }),
89                           InvokeWithoutArgs(makeStatusOk)));
90 
91     // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
92     // uninteresting methods calls.
93     EXPECT_CALL(*mockDevice, getVersionString(_)).Times(testing::AnyNumber());
94     EXPECT_CALL(*mockDevice, getType(_)).Times(testing::AnyNumber());
95     EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(testing::AnyNumber());
96     EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(testing::AnyNumber());
97     EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(testing::AnyNumber());
98 
99     return mockDevice;
100 }
101 
102 constexpr auto makePreparedModelReturnImpl =
103         [](ErrorStatus launchStatus, ErrorStatus returnStatus,
104            const std::shared_ptr<MockPreparedModel>& preparedModel,
__anonacec0d670302(ErrorStatus launchStatus, ErrorStatus returnStatus, const std::shared_ptr<MockPreparedModel>& preparedModel, const std::shared_ptr<IPreparedModelCallback>& cb) 105            const std::shared_ptr<IPreparedModelCallback>& cb) {
106             cb->notify(returnStatus, preparedModel);
107             if (launchStatus == ErrorStatus::NONE) {
108                 return ndk::ScopedAStatus::ok();
109             }
110             return ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(launchStatus));
111         };
112 
makePreparedModelReturn(ErrorStatus launchStatus,ErrorStatus returnStatus,const std::shared_ptr<MockPreparedModel> & preparedModel)113 auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
114                              const std::shared_ptr<MockPreparedModel>& preparedModel) {
115     return [launchStatus, returnStatus, preparedModel](
116                    const Model& /*model*/, ExecutionPreference /*preference*/,
117                    Priority /*priority*/, const int64_t& /*deadline*/,
118                    const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
119                    const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
120                    const std::vector<uint8_t>& /*token*/,
121                    const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
122         return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
123     };
124 }
125 
makePreparedModelFromCacheReturn(ErrorStatus launchStatus,ErrorStatus returnStatus,const std::shared_ptr<MockPreparedModel> & preparedModel)126 auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
127                                       const std::shared_ptr<MockPreparedModel>& preparedModel) {
128     return [launchStatus, returnStatus, preparedModel](
129                    const int64_t& /*deadline*/,
130                    const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
131                    const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
132                    const std::vector<uint8_t>& /*token*/,
133                    const std::shared_ptr<IPreparedModelCallback>& cb) {
134         return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
135     };
136 }
137 
__anonacec0d670602null138 constexpr auto makeGeneralFailure = [] {
139     return ndk::ScopedAStatus::fromServiceSpecificError(
140             static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
141 };
__anonacec0d670702null142 constexpr auto makeGeneralTransportFailure = [] {
143     return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
144 };
__anonacec0d670802null145 constexpr auto makeDeadObjectFailure = [] {
146     return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
147 };
148 
149 }  // namespace
150 
TEST(DeviceTest, invalidName) {
    // Arrange: a perfectly valid driver handle, paired with an empty name.
    const auto driver = MockDevice::create();

    // Act
    const auto created = Device::create(kInvalidName, driver);

    // Assert: the empty name alone is enough to reject construction.
    ASSERT_FALSE(created.has_value());
    const auto& failure = created.error();
    EXPECT_EQ(failure.code, nn::ErrorStatus::INVALID_ARGUMENT);
}
160 
TEST(DeviceTest, invalidDevice) {
    // Act: a valid name, paired with a null driver handle.
    const auto created = Device::create(kName, kInvalidDevice);

    // Assert: construction is rejected as an invalid argument.
    ASSERT_FALSE(created.has_value());
    const auto& failure = created.error();
    EXPECT_EQ(failure.code, nn::ErrorStatus::INVALID_ARGUMENT);
}
169 
TEST(DeviceTest, getVersionStringError) {
    // Arrange: getVersionString answers with a service-specific GENERAL_FAILURE.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getVersionString(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
184 
TEST(DeviceTest, getVersionStringTransportFailure) {
    // Arrange: getVersionString fails at the binder transport level.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getVersionString(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
199 
TEST(DeviceTest, getVersionStringDeadObject) {
    // Arrange: getVersionString reports that the remote object has died.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getVersionString(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
214 
TEST(DeviceTest, getTypeError) {
    // Arrange: getType answers with a service-specific GENERAL_FAILURE.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getType(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
227 
TEST(DeviceTest, getTypeTransportFailure) {
    // Arrange: getType fails at the binder transport level.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getType(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
242 
TEST(DeviceTest, getTypeDeadObject) {
    // Arrange: getType reports that the remote object has died.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getType(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
257 
TEST(DeviceTest, getSupportedExtensionsError) {
    // Arrange: getSupportedExtensions answers with a service-specific GENERAL_FAILURE.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getSupportedExtensions(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
272 
TEST(DeviceTest, getSupportedExtensionsTransportFailure) {
    // Arrange: getSupportedExtensions fails at the binder transport level.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getSupportedExtensions(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
287 
TEST(DeviceTest, getSupportedExtensionsDeadObject) {
    // Arrange: getSupportedExtensions reports that the remote object has died.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getSupportedExtensions(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
302 
TEST(DeviceTest, getNumberOfCacheFilesNeeded) {
    // setup call: the default mock action reports kNumberOfCacheFiles, queried exactly once.
    const auto mockDevice = createMockDevice();
    EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(1);

    // run test
    const auto result = Device::create(kName, mockDevice);

    // verify result: Device reports the counts as a (model, data) pair, in that order.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    // Construct the expected pair directly rather than via std::make_pair<uint32_t, uint32_t>.
    // With explicit template arguments, make_pair takes uint32_t&& parameters, which cannot
    // bind these const lvalue fields. That form only compiled because a type conversion
    // happened to materialize temporaries.
    constexpr std::pair<uint32_t, uint32_t> kNumberOfCacheFilesPair{
            static_cast<uint32_t>(kNumberOfCacheFiles.numModelCache),
            static_cast<uint32_t>(kNumberOfCacheFiles.numDataCache)};
    EXPECT_EQ(result.value()->getNumberOfCacheFilesNeeded(), kNumberOfCacheFilesPair);
}
317 
TEST(DeviceTest, getNumberOfCacheFilesNeededError) {
    // Arrange: getNumberOfCacheFilesNeeded answers with a service-specific GENERAL_FAILURE.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
332 
TEST(DeviceTest, dataCacheFilesExceedsSpecifiedMax) {
    // setup test: only the DATA cache count exceeds the limit. The model cache count keeps the
    // known-good default, so the failure can only come from the data cache check. Previously
    // this test overflowed numModelCache instead, contradicting its name.
    const auto mockDevice = createMockDevice();
    EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
                                    .numModelCache = kNumberOfCacheFiles.numModelCache,
                                    .numDataCache = nn::kMaxNumberOfCacheFiles + 1}),
                            InvokeWithoutArgs(makeStatusOk)));

    // run test
    const auto result = Device::create(kName, mockDevice);

    // verify result
    ASSERT_FALSE(result.has_value());
    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
350 
TEST(DeviceTest, modelCacheFilesExceedsSpecifiedMax) {
    // setup test: only the MODEL cache count exceeds the limit. The data cache count keeps the
    // known-good default, so the failure can only come from the model cache check. Previously
    // this test overflowed numDataCache instead, contradicting its name.
    const auto mockDevice = createMockDevice();
    EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
                                    .numModelCache = nn::kMaxNumberOfCacheFiles + 1,
                                    .numDataCache = kNumberOfCacheFiles.numDataCache}),
                            InvokeWithoutArgs(makeStatusOk)));

    // run test
    const auto result = Device::create(kName, mockDevice);

    // verify result
    ASSERT_FALSE(result.has_value());
    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
368 
TEST(DeviceTest, getNumberOfCacheFilesNeededTransportFailure) {
    // Arrange: getNumberOfCacheFilesNeeded fails at the binder transport level.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
383 
TEST(DeviceTest, getNumberOfCacheFilesNeededDeadObject) {
    // Arrange: getNumberOfCacheFilesNeeded reports that the remote object has died.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
398 
TEST(DeviceTest, getCapabilitiesError) {
    // Arrange: getCapabilities answers with a service-specific GENERAL_FAILURE.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getCapabilities(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
413 
TEST(DeviceTest, getCapabilitiesTransportFailure) {
    // Arrange: getCapabilities fails at the binder transport level.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getCapabilities(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
428 
TEST(DeviceTest, getCapabilitiesDeadObject) {
    // Arrange: getCapabilities reports that the remote object has died.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getCapabilities(_))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto created = Device::create(kName, driver);

    // Assert
    ASSERT_FALSE(created.has_value());
    EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
443 
TEST(DeviceTest, getName) {
    // Arrange: a device built from the default mock.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();

    // Act & Assert: the name given to create is reported back unchanged.
    EXPECT_EQ(device->getName(), kName);
}
455 
TEST(DeviceTest, getFeatureLevel) {
    // Arrange: a device built from the default mock.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();

    // Act & Assert: this adapter reports the Android S feature level.
    EXPECT_EQ(device->getFeatureLevel(), nn::Version::ANDROID_S);
}
467 
TEST(DeviceTest, getCachedData) {
    // Arrange: across the whole test, each driver getter must be called exactly once.
    const auto driver = createMockDevice();
    EXPECT_CALL(*driver, getVersionString(_)).Times(1);
    EXPECT_CALL(*driver, getType(_)).Times(1);
    EXPECT_CALL(*driver, getSupportedExtensions(_)).Times(1);
    EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_)).Times(1);
    EXPECT_CALL(*driver, getCapabilities(_)).Times(1);

    const auto created = Device::create(kName, driver);
    ASSERT_TRUE(created.has_value())
            << "Failed with " << created.error().code << ": " << created.error().message;
    const auto& device = created.value();

    // Act & Assert: query every getter twice. Combined with the Times(1) expectations above,
    // any getter that went back to the driver after create would fail this test.
    EXPECT_EQ(device->getVersionString(), device->getVersionString());
    EXPECT_EQ(device->getType(), device->getType());
    EXPECT_EQ(device->getSupportedExtensions(), device->getSupportedExtensions());
    EXPECT_EQ(device->getNumberOfCacheFilesNeeded(), device->getNumberOfCacheFilesNeeded());
    EXPECT_EQ(device->getCapabilities(), device->getCapabilities());
}
489 
TEST(DeviceTest, getSupportedOperations) {
    // setup call: the driver reports every operation of kSimpleModel as supported.
    const auto mockDevice = createMockDevice();
    const auto device = Device::create(kName, mockDevice).value();
    EXPECT_CALL(*mockDevice, getSupportedOperations(_, _))
            .Times(1)
            .WillOnce(DoAll(
                    SetArgPointee<1>(std::vector<bool>(kSimpleModel.main.operations.size(), true)),
                    InvokeWithoutArgs(makeStatusOk)));

    // run test
    const auto result = device->getSupportedOperations(kSimpleModel);

    // verify result: one entry per operation, all of them true.
    ASSERT_TRUE(result.has_value())
            << "Failed with " << result.error().code << ": " << result.error().message;
    const auto& supportedOperations = result.value();
    EXPECT_EQ(supportedOperations.size(), kSimpleModel.main.operations.size());
    // Qualify Each explicitly, like the other gmock names, instead of relying on
    // argument-dependent lookup through the IsTrue() matcher.
    EXPECT_THAT(supportedOperations, testing::Each(testing::IsTrue()));
}
510 
TEST(DeviceTest, getSupportedOperationsError) {
    // Arrange: the driver rejects the query with a service-specific GENERAL_FAILURE.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    EXPECT_CALL(*driver, getSupportedOperations(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto supported = device->getSupportedOperations(kSimpleModel);

    // Assert
    ASSERT_FALSE(supported.has_value());
    EXPECT_EQ(supported.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
526 
TEST(DeviceTest, getSupportedOperationsTransportFailure) {
    // Arrange: the query fails at the binder transport level.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    EXPECT_CALL(*driver, getSupportedOperations(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto supported = device->getSupportedOperations(kSimpleModel);

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(supported.has_value());
    EXPECT_EQ(supported.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
542 
TEST(DeviceTest, getSupportedOperationsDeadObject) {
    // Arrange: the query reports that the remote object has died.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    EXPECT_CALL(*driver, getSupportedOperations(_, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto supported = device->getSupportedOperations(kSimpleModel);

    // Assert
    ASSERT_FALSE(supported.has_value());
    EXPECT_EQ(supported.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
558 
TEST(DeviceTest, prepareModel) {
    // Arrange: launch succeeds and the callback delivers a real prepared model.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto preparedModel = MockPreparedModel::create();
    const auto succeed =
            makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE, preparedModel);
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _)).Times(1).WillOnce(Invoke(succeed));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert
    ASSERT_TRUE(prepared.has_value())
            << "Failed with " << prepared.error().code << ": " << prepared.error().message;
    EXPECT_NE(prepared.value(), nullptr);
}
578 
TEST(DeviceTest, prepareModelLaunchError) {
    // Arrange: both the synchronous launch and the callback report GENERAL_FAILURE.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto failLaunch = makePreparedModelReturn(ErrorStatus::GENERAL_FAILURE,
                                                    ErrorStatus::GENERAL_FAILURE, nullptr);
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(failLaunch));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
596 
TEST(DeviceTest, prepareModelReturnError) {
    // Arrange: launch succeeds, but the callback reports GENERAL_FAILURE.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto failAsync =
            makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr);
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(failAsync));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
614 
TEST(DeviceTest, prepareModelNullptrError) {
    // Arrange: both statuses claim success, but the callback hands back a null model.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto returnNull = makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE, nullptr);
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(returnNull));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert: a null model is treated as GENERAL_FAILURE.
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
632 
TEST(DeviceTest, prepareModelTransportFailure) {
    // Arrange: the prepareModel call itself fails at the binder transport level.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert: transport errors surface as GENERAL_FAILURE.
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
649 
TEST(DeviceTest, prepareModelDeadObject) {
    // Arrange: the prepareModel call reports that the remote object has died.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
666 
TEST(DeviceTest, prepareModelAsyncCrash) {
    // Arrange: mid-call, the mock signals the device's DeathMonitor that the service died.
    // It then returns ok without ever notifying the prepared-model callback.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto simulateServiceDeath = [&device]() {
        DeathMonitor::serviceDied(device->getDeathMonitor());
        return ndk::ScopedAStatus::ok();
    };
    EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(simulateServiceDeath));

    // Act
    const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
                                               nn::Priority::DEFAULT, {}, {}, {}, {});

    // Assert: the pending preparation fails with DEAD_OBJECT.
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::DEAD_OBJECT);
}
687 
TEST(DeviceTest, prepareModelFromCache) {
    // Arrange: launch succeeds and the callback delivers a real prepared model.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto preparedModel = MockPreparedModel::create();
    const auto succeed =
            makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE, preparedModel);
    EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(succeed));

    // Act
    const auto prepared = device->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_TRUE(prepared.has_value())
            << "Failed with " << prepared.error().code << ": " << prepared.error().message;
    EXPECT_NE(prepared.value(), nullptr);
}
706 
TEST(DeviceTest, prepareModelFromCacheLaunchError) {
    // Arrange: both the synchronous launch and the callback report GENERAL_FAILURE.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto failLaunch = makePreparedModelFromCacheReturn(ErrorStatus::GENERAL_FAILURE,
                                                             ErrorStatus::GENERAL_FAILURE, nullptr);
    EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(failLaunch));

    // Act
    const auto prepared = device->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
723 
TEST(DeviceTest, prepareModelFromCacheReturnError) {
    // Arrange: launch succeeds, but the callback reports GENERAL_FAILURE.
    const auto driver = createMockDevice();
    const auto device = Device::create(kName, driver).value();
    const auto failAsync = makePreparedModelFromCacheReturn(ErrorStatus::NONE,
                                                            ErrorStatus::GENERAL_FAILURE, nullptr);
    EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(failAsync));

    // Act
    const auto prepared = device->prepareModelFromCache({}, {}, {}, {});

    // Assert
    ASSERT_FALSE(prepared.has_value());
    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
740 
TEST(DeviceTest, prepareModelFromCacheNullptrError) {
    // Arrange: both status arguments report success, yet the mocked HAL call supplies a null
    // prepared model.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    const auto halReturn = makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
                                                            nullptr);
    EXPECT_CALL(*mockAidlDevice, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(Invoke(halReturn));

    // Act
    const auto outcome = nnDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert: a "successful" null model is rejected as GENERAL_FAILURE.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
757 
TEST(DeviceTest, prepareModelFromCacheTransportFailure) {
    // Arrange: the mocked HAL call returns the status built by makeGeneralTransportFailure.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    EXPECT_CALL(*mockAidlDevice, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto outcome = nnDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert: the transport error is reported as GENERAL_FAILURE.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
773 
TEST(DeviceTest, prepareModelFromCacheDeadObject) {
    // Arrange: the mocked HAL call returns the status built by makeDeadObjectFailure.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    EXPECT_CALL(*mockAidlDevice, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto outcome = nnDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert: the failure is reported as DEAD_OBJECT.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
}
789 
TEST(DeviceTest, prepareModelFromCacheAsyncCrash) {
    // Arrange: during the HAL call, signal the adapter's death monitor that the service died,
    // then have the call itself return an OK status.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    // Capturing by reference is safe: the action only runs inside this test's scope.
    const auto simulateServiceDeath = [&nnDevice]() {
        DeathMonitor::serviceDied(nnDevice->getDeathMonitor());
        return ndk::ScopedAStatus::ok();
    };
    EXPECT_CALL(*mockAidlDevice, prepareModelFromCache(_, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(simulateServiceDeath));

    // Act
    const auto outcome = nnDevice->prepareModelFromCache({}, {}, {}, {});

    // Assert: despite the OK status, the call is reported as DEAD_OBJECT.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
}
809 
TEST(DeviceTest, allocate) {
    // Arrange: the mocked HAL allocate() writes a DeviceBuffer (a mock buffer with token 1) into
    // its fifth argument (index 4) and returns the status built by makeStatusOk.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    const auto deviceBuffer = DeviceBuffer{.buffer = MockBuffer::create(), .token = 1};
    EXPECT_CALL(*mockAidlDevice, allocate(_, _, _, _, _))
            .Times(1)
            .WillOnce(DoAll(SetArgPointee<4>(deviceBuffer), InvokeWithoutArgs(makeStatusOk)));

    // Act
    const auto outcome = nnDevice->allocate({}, {}, {}, {});

    // Assert: the adapter succeeds and produces a non-null buffer.
    ASSERT_TRUE(outcome.has_value())
            << "Failed with " << outcome.error().code << ": " << outcome.error().message;
    EXPECT_NE(outcome.value(), nullptr);
}
827 
TEST(DeviceTest, allocateError) {
    // Arrange: the mocked HAL allocate() returns the status built by makeGeneralFailure.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    EXPECT_CALL(*mockAidlDevice, allocate(_, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));

    // Act
    const auto outcome = nnDevice->allocate({}, {}, {}, {});

    // Assert: the failure surfaces as GENERAL_FAILURE.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
843 
TEST(DeviceTest, allocateTransportFailure) {
    // Arrange: the mocked HAL allocate() returns the status built by
    // makeGeneralTransportFailure.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    EXPECT_CALL(*mockAidlDevice, allocate(_, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));

    // Act
    const auto outcome = nnDevice->allocate({}, {}, {}, {});

    // Assert: the transport error is reported as GENERAL_FAILURE.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::GENERAL_FAILURE);
}
859 
TEST(DeviceTest, allocateDeadObject) {
    // Arrange: the mocked HAL allocate() returns the status built by makeDeadObjectFailure.
    const auto mockAidlDevice = createMockDevice();
    const auto nnDevice = Device::create(kName, mockAidlDevice).value();
    EXPECT_CALL(*mockAidlDevice, allocate(_, _, _, _, _))
            .Times(1)
            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));

    // Act
    const auto outcome = nnDevice->allocate({}, {}, {}, {});

    // Assert: the failure is reported as DEAD_OBJECT.
    ASSERT_FALSE(outcome.has_value());
    const auto& error = outcome.error();
    EXPECT_EQ(error.code, nn::ErrorStatus::DEAD_OBJECT);
}
875 
876 }  // namespace aidl::android::hardware::neuralnetworks::utils
877