1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_aidl_hal_test"
18 
19 #include <aidl/android/hardware/neuralnetworks/RequestMemoryPool.h>
20 #include <android/binder_auto_utils.h>
21 #include <variant>
22 
23 #include <chrono>
24 
25 #include <TestHarness.h>
26 #include <nnapi/hal/aidl/Utils.h>
27 
28 #include "Callbacks.h"
29 #include "GeneratedTestHarness.h"
30 #include "Utils.h"
31 #include "VtsHalNeuralnetworks.h"
32 
33 namespace aidl::android::hardware::neuralnetworks::vts::functional {
34 
35 using ExecutionMutation = std::function<void(Request*)>;
36 
37 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
38 
39 // Primary validation function. This function will take a valid request, apply a
40 // mutation to it to invalidate the request, then pass it to interface calls
41 // that use the request.
validate(const std::shared_ptr<IPreparedModel> & preparedModel,const std::string & message,const Request & originalRequest,const ExecutionMutation & mutate)42 static void validate(const std::shared_ptr<IPreparedModel>& preparedModel,
43                      const std::string& message, const Request& originalRequest,
44                      const ExecutionMutation& mutate) {
45     Request request = utils::clone(originalRequest).value();
46     mutate(&request);
47 
48     // We'd like to test both with timing requested and without timing
49     // requested. Rather than running each test both ways, we'll decide whether
50     // to request timing by hashing the message. We do not use std::hash because
51     // it is not guaranteed stable across executions.
52     char hash = 0;
53     for (auto c : message) {
54         hash ^= c;
55     };
56     bool measure = (hash & 1);
57 
58     // synchronous
59     {
60         SCOPED_TRACE(message + " [executeSynchronously]");
61         ExecutionResult executionResult;
62         const auto executeStatus = preparedModel->executeSynchronously(
63                 request, measure, kNoDeadline, kOmittedTimeoutDuration, &executionResult);
64         ASSERT_FALSE(executeStatus.isOk());
65         ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
66         ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
67                   ErrorStatus::INVALID_ARGUMENT);
68     }
69 
70     // fenced
71     {
72         SCOPED_TRACE(message + " [executeFenced]");
73         FencedExecutionResult executionResult;
74         const auto executeStatus = preparedModel->executeFenced(request, {}, false, kNoDeadline,
75                                                                 kOmittedTimeoutDuration,
76                                                                 kNoDuration, &executionResult);
77         ASSERT_FALSE(executeStatus.isOk());
78         ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
79         ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
80                   ErrorStatus::INVALID_ARGUMENT);
81     }
82 
83     // burst
84     {
85         SCOPED_TRACE(message + " [burst]");
86 
87         // create burst
88         std::shared_ptr<IBurst> burst;
89         auto ret = preparedModel->configureExecutionBurst(&burst);
90         ASSERT_TRUE(ret.isOk()) << ret.getDescription();
91         ASSERT_NE(nullptr, burst.get());
92 
93         // use -1 for all memory identifier tokens
94         const std::vector<int64_t> slots(request.pools.size(), -1);
95 
96         ExecutionResult executionResult;
97         const auto executeStatus = burst->executeSynchronously(
98                 request, slots, measure, kNoDeadline, kOmittedTimeoutDuration, &executionResult);
99         ASSERT_FALSE(executeStatus.isOk());
100         ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
101         ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
102                   ErrorStatus::INVALID_ARGUMENT);
103     }
104 }
105 
createBurst(const std::shared_ptr<IPreparedModel> & preparedModel)106 std::shared_ptr<IBurst> createBurst(const std::shared_ptr<IPreparedModel>& preparedModel) {
107     std::shared_ptr<IBurst> burst;
108     const auto ret = preparedModel->configureExecutionBurst(&burst);
109     if (!ret.isOk()) return nullptr;
110     return burst;
111 }
112 
113 ///////////////////////// REMOVE INPUT ////////////////////////////////////
114 
removeInputTest(const std::shared_ptr<IPreparedModel> & preparedModel,const Request & request)115 static void removeInputTest(const std::shared_ptr<IPreparedModel>& preparedModel,
116                             const Request& request) {
117     for (size_t input = 0; input < request.inputs.size(); ++input) {
118         const std::string message = "removeInput: removed input " + std::to_string(input);
119         validate(preparedModel, message, request, [input](Request* request) {
120             request->inputs.erase(request->inputs.begin() + input);
121         });
122     }
123 }
124 
125 ///////////////////////// REMOVE OUTPUT ////////////////////////////////////
126 
removeOutputTest(const std::shared_ptr<IPreparedModel> & preparedModel,const Request & request)127 static void removeOutputTest(const std::shared_ptr<IPreparedModel>& preparedModel,
128                              const Request& request) {
129     for (size_t output = 0; output < request.outputs.size(); ++output) {
130         const std::string message = "removeOutput: removed Output " + std::to_string(output);
131         validate(preparedModel, message, request, [output](Request* request) {
132             request->outputs.erase(request->outputs.begin() + output);
133         });
134     }
135 }
136 
137 ///////////////////////////// ENTRY POINT //////////////////////////////////
138 
// Entry point: runs every Request-mutation validation test (currently
// input removal and output removal) against the given prepared model.
void validateRequest(const std::shared_ptr<IPreparedModel>& preparedModel, const Request& request) {
    removeInputTest(preparedModel, request);
    removeOutputTest(preparedModel, request);
}
143 
// Validates IBurst error handling: executions with invalid or mismatched
// memory identifier tokens, and releasing a resource for an invalid token,
// must all fail with ErrorStatus::INVALID_ARGUMENT.
void validateBurst(const std::shared_ptr<IPreparedModel>& preparedModel, const Request& request) {
    // create burst
    std::shared_ptr<IBurst> burst;
    auto ret = preparedModel->configureExecutionBurst(&burst);
    ASSERT_TRUE(ret.isOk()) << ret.getDescription();
    ASSERT_NE(nullptr, burst.get());

    // Runs one burst execution with the given memory identifier tokens and
    // asserts it fails with a service-specific INVALID_ARGUMENT error.
    const auto test = [&burst, &request](const std::vector<int64_t>& slots) {
        ExecutionResult executionResult;
        const auto executeStatus =
                burst->executeSynchronously(request, slots, /*measure=*/false, kNoDeadline,
                                            kOmittedTimeoutDuration, &executionResult);
        ASSERT_FALSE(executeStatus.isOk());
        ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
        ASSERT_EQ(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()),
                  ErrorStatus::INVALID_ARGUMENT);
    };

    // Build one token per pool: a fresh sequential token for actual memory
    // pools, -1 (no token) for the other pool variants.
    int64_t currentSlot = 0;
    std::vector<int64_t> slots;
    slots.reserve(request.pools.size());
    for (const auto& pool : request.pools) {
        if (pool.getTag() == RequestMemoryPool::Tag::pool) {
            slots.push_back(currentSlot++);
        } else {
            slots.push_back(-1);
        }
    }

    // Tokens must be >= -1, so -2 is always invalid.
    constexpr int64_t invalidSlot = -2;

    // validate failure when invalid memory identifier token value
    for (size_t i = 0; i < request.pools.size(); ++i) {
        const int64_t oldSlotValue = slots[i];

        // Corrupt one token at a time, then restore it for the next iteration.
        slots[i] = invalidSlot;
        test(slots);

        slots[i] = oldSlotValue;
    }

    // validate failure when request.pools.size() != memoryIdentifierTokens.size()
    // (one token too few; skipped when there are no pools at all)
    if (request.pools.size() > 0) {
        slots = std::vector<int64_t>(request.pools.size() - 1, -1);
        test(slots);
    }

    // validate failure when request.pools.size() != memoryIdentifierTokens.size()
    // (one token too many)
    slots = std::vector<int64_t>(request.pools.size() + 1, -1);
    test(slots);

    // validate failure when invalid memory identifier token value
    // (releasing a resource for a token that was never issued)
    const auto freeStatus = burst->releaseMemoryResource(invalidSlot);
    ASSERT_FALSE(freeStatus.isOk());
    ASSERT_EQ(freeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
    ASSERT_EQ(static_cast<ErrorStatus>(freeStatus.getServiceSpecificError()),
              ErrorStatus::INVALID_ARGUMENT);
}
202 
validateRequestFailure(const std::shared_ptr<IPreparedModel> & preparedModel,const Request & request)203 void validateRequestFailure(const std::shared_ptr<IPreparedModel>& preparedModel,
204                             const Request& request) {
205     SCOPED_TRACE("Expecting request to fail [executeSynchronously]");
206     ExecutionResult executionResult;
207     const auto executeStatus = preparedModel->executeSynchronously(
208             request, false, kNoDeadline, kOmittedTimeoutDuration, &executionResult);
209 
210     ASSERT_FALSE(executeStatus.isOk());
211     ASSERT_EQ(executeStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
212     ASSERT_NE(static_cast<ErrorStatus>(executeStatus.getServiceSpecificError()), ErrorStatus::NONE);
213 }
214 
215 }  // namespace aidl::android::hardware::neuralnetworks::vts::functional
216