/*
 * Copyright (C) 2021 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "neuralnetworks_aidl_hal_test"

#include <aidl/android/hardware/common/NativeHandle.h>
#include <android/binder_auto_utils.h>
#include <android/binder_enums.h>
#include <android/binder_interface_utils.h>
#include <nnapi/TypeUtils.h>
#include <nnapi/hal/aidl/Conversions.h>
#include <nnapi/hal/aidl/Utils.h>

#include <algorithm>
#include <functional>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "Callbacks.h"
#include "GeneratedTestHarness.h"
#include "Utils.h"
#include "VtsHalNeuralnetworks.h"

namespace aidl::android::hardware::neuralnetworks::vts::functional {

using common::NativeHandle;
using implementation::PreparedModelCallback;

using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*, Priority*)>;

///////////////////////// UTILITY FUNCTIONS /////////////////////////

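// Calls IDevice::getSupportedOperations with the (mutated, invalid) model and verifies that the
// service rejects it with a service-specific error of INVALID_ARGUMENT.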
static void validateGetSupportedOperations(const std::shared_ptr<IDevice>& device,
                                           const std::string& message, const Model& model) {
    SCOPED_TRACE(message + " [getSupportedOperations]");

    std::vector<bool> supported;
    const auto retStatus = device->getSupportedOperations(model, &supported);

    ASSERT_FALSE(retStatus.isOk());
    ASSERT_EQ(retStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
    ASSERT_EQ(static_cast<ErrorStatus>(retStatus.getServiceSpecificError()),
              ErrorStatus::INVALID_ARGUMENT);
}

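// Calls IDevice::prepareModel with the (mutated, invalid) arguments and verifies that both the
// launch status and the callback status report INVALID_ARGUMENT and that no prepared model is
// returned.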
static void validatePrepareModel(const std::shared_ptr<IDevice>& device, const std::string& message,
                                 const Model& model, ExecutionPreference preference,
                                 Priority priority) {
    SCOPED_TRACE(message + " [prepareModel]");

    std::shared_ptr<PreparedModelCallback> preparedModelCallback =
            ndk::SharedRefBase::make<PreparedModelCallback>();
    const auto prepareLaunchStatus =
            device->prepareModel(model, preference, priority, kNoDeadline, {}, {}, kEmptyCacheToken,
                                 preparedModelCallback);
    ASSERT_FALSE(prepareLaunchStatus.isOk());
    ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
    ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
              ErrorStatus::INVALID_ARGUMENT);

    preparedModelCallback->wait();
    ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
    ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
    std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
    ASSERT_EQ(nullptr, preparedModel.get());
}

static bool validExecutionPreference(ExecutionPreference preference) {
    return preference == ExecutionPreference::LOW_POWER ||
           preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
           preference == ExecutionPreference::SUSTAINED_SPEED;
}

static bool validExecutionPriority(Priority priority) {
    return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
}

// Primary validation function. This function takes a valid model, applies a
// mutation to invalidate the model, the execution preference, or the priority,
// then passes the mutated arguments to getSupportedOperations and/or
// prepareModel and verifies that the call is rejected with INVALID_ARGUMENT.
static void validate(const std::shared_ptr<IDevice>& device, const std::string& message,
                     const Model& originalModel, const PrepareModelMutation& mutate) {
    Model model = utils::clone(originalModel).value();
    ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
    Priority priority = kDefaultPriority;
    mutate(&model, &preference, &priority);

    if (validExecutionPreference(preference) && validExecutionPriority(priority)) {
        validateGetSupportedOperations(device, message, model);
    }

    validatePrepareModel(device, message, model, preference, priority);
}

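// Appends a scalar INT32 operand with SUBGRAPH_INPUT lifetime to the main subgraph and returns
// its index.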
static uint32_t addOperand(Model* model) {
    model->main.operands.push_back({
            .type = OperandType::INT32,
            .dimensions = {},
            .scale = 0.0f,
            .zeroPoint = 0,
            .lifetime = OperandLifeTime::SUBGRAPH_INPUT,
            .location = {.poolIndex = 0, .offset = 0, .length = 0},
    });
    return model->main.operands.size() - 1;
}

static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
    uint32_t index = addOperand(model);
    model->main.operands[index].lifetime = lifetime;
    return index;
}

// If we introduce a CONSTANT_COPY for an operand of size operandSize,
// how much will this increase the size of the model? This assumes
// that we can (re)use all of model.operandValues for the operand
// value.
static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
    const size_t operandValuesSize = model.operandValues.size();
    return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
}

// Highly specialized utility routine for converting an operand to
// CONSTANT_COPY lifetime.
//
// Expects that:
// - operand has a known size
// - operand->lifetime has already been set to CONSTANT_COPY
// - operand->location has been zeroed out
//
// Does the following:
// - initializes operand->location to point to the beginning of model->operandValues
// - resizes model->operandValues (if necessary) to be large enough for the operand
//   value, padding it with zeroes on the end
//
// Potential problem:
// By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
// operand with unspecified (but deterministic) data. This means that the model may be invalidated
// in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
// graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
// value). For now, this should be fine because it just means we're not testing what we think we're
// testing in certain cases; but we can handwave this and assume we're probabilistically likely to
// exercise the validation code over the span of the entire test set and operand space.
//
// Aborts if the specified operand type is an extension type or OEM type.
static void becomeConstantCopy(Model* model, Operand* operand) {
    // sizeOfData will abort if the specified type is an extension type or OEM type.
    const size_t sizeOfOperand = sizeOfData(*operand);
    EXPECT_NE(sizeOfOperand, size_t(0));
    operand->location.poolIndex = 0;
    operand->location.offset = 0;
    operand->location.length = sizeOfOperand;
    if (model->operandValues.size() < sizeOfOperand) {
        model->operandValues.resize(sizeOfOperand);
    }
}

// The sizeForBinder() functions estimate the size of the
// representation of a value when sent to binder. It's probably a bit
// of an under-estimate, because we don't know the size of the
// metadata in the binder format (e.g., representation of the size of
// a vector); but at least it adds up "big" things like vector
// contents. However, it doesn't treat inter-field or end-of-struct
// padding in a methodical way -- there's no attempt to be consistent
// in whether or not padding in the native (C++) representation
// contributes to the estimated size for the binder representation;
// and there's no attempt to understand what padding (if any) is
// needed in the binder representation.
//
// This assumes that non-metadata uses a fixed length encoding (e.g.,
// a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
// using an encoding whose length is related to the magnitude of the
// encoded value).

template <typename Type>
static size_t sizeForBinder(const Type& val) {
    static_assert(std::is_trivially_copyable_v<std::remove_reference_t<Type>>,
                  "expected a trivially copyable type");
    return sizeof(val);
}

template <typename Type>
static size_t sizeForBinder(const std::vector<Type>& vec) {
    return std::accumulate(vec.begin(), vec.end(), 0,
                           [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
}

template <>
size_t sizeForBinder(const SymmPerChannelQuantParams& symmPerChannelQuantParams) {
    size_t size = 0;

    size += sizeForBinder(symmPerChannelQuantParams.scales);
    size += sizeForBinder(symmPerChannelQuantParams.channelDim);

    return size;
}

template <>
size_t sizeForBinder(const std::optional<OperandExtraParams>& optionalExtraParams) {
    if (!optionalExtraParams.has_value()) {
        return 0;
    }
    const auto& extraParams = optionalExtraParams.value();
    using Tag = OperandExtraParams::Tag;
    switch (extraParams.getTag()) {
        case Tag::channelQuant:
            return sizeForBinder(extraParams.get<Tag::channelQuant>());
        case Tag::extension:
            return sizeForBinder(extraParams.get<Tag::extension>());
    }
    LOG(FATAL) << "Unrecognized extraParams tag: " << static_cast<int>(extraParams.getTag());
    return 0;
}

template <>
size_t sizeForBinder(const Operand& operand) {
    size_t size = 0;

    size += sizeForBinder(operand.type);
    size += sizeForBinder(operand.dimensions);
    size += sizeForBinder(operand.scale);
    size += sizeForBinder(operand.zeroPoint);
    size += sizeForBinder(operand.lifetime);
    size += sizeForBinder(operand.location);
    size += sizeForBinder(operand.extraParams);

    return size;
}

template <>
size_t sizeForBinder(const Operation& operation) {
    size_t size = 0;

    size += sizeForBinder(operation.type);
    size += sizeForBinder(operation.inputs);
    size += sizeForBinder(operation.outputs);

    return size;
}

template <>
size_t sizeForBinder(const std::string& name) {
    return name.size();
}

template <>
size_t sizeForBinder(const Memory& memory) {
    // This is just a guess.

    size_t size = sizeof(Memory);

    // Only hardwareBuffer type memory has dynamic memory that needs to be accounted for (in the
    // form of a NativeHandle type). The other types of memory (MappableFile, Ashmem) use a
    // single file descriptor (with metadata) instead.
    if (memory.getTag() == Memory::Tag::hardwareBuffer) {
        const NativeHandle& handle = memory.get<Memory::Tag::hardwareBuffer>().handle;
        size += sizeof(decltype(handle.fds)::value_type) * handle.fds.size();
        size += sizeof(decltype(handle.ints)::value_type) * handle.ints.size();
    }

    return size;
}

template <>
size_t sizeForBinder(const Subgraph& subgraph) {
    size_t size = 0;

    size += sizeForBinder(subgraph.operands);
    size += sizeForBinder(subgraph.operations);
    size += sizeForBinder(subgraph.inputIndexes);
    size += sizeForBinder(subgraph.outputIndexes);

    return size;
}

template <>
size_t sizeForBinder(const ExtensionNameAndPrefix& extensionNameToPrefix) {
    size_t size = 0;

    size += sizeForBinder(extensionNameToPrefix.name);
    size += sizeForBinder(extensionNameToPrefix.prefix);

    return size;
}

template <>
size_t sizeForBinder(const Model& model) {
    size_t size = 0;

    size += sizeForBinder(model.main);
    size += sizeForBinder(model.referenced);
    size += sizeForBinder(model.operandValues);
    size += sizeForBinder(model.pools);
    size += sizeForBinder(model.relaxComputationFloat32toFloat16);
    size += sizeForBinder(model.extensionNameToPrefix);

    return size;
}

// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Will our representation fit under this limit? There are two complications:
// - Our representation size is just approximate (see sizeForBinder()).
// - This object may not be the only occupant of the Binder transaction buffer.
// So we'll be very conservative: We want the representation size to be no
// larger than half the transaction buffer size.
//
// If our representation grows large enough that it still fits within
// the transaction buffer but combined with other transactions may
// exceed the buffer size, then we may see intermittent HAL transport
// errors.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // Instead of using this fixed buffer size, we might instead be able to use
    // ProcessState::self()->getMmapSize(). However, this has a potential
    // problem: The binder/mmap size of the current process does not necessarily
    // indicate the binder/mmap size of the service (i.e., the other process).
    // The only way it would be a good indication is if both the current process
    // and the service use the default size.
    static const size_t kHalfBufferSize = 1024 * 1024 / 2;

    return representationSize > kHalfBufferSize;
}

///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////

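// Moves a reader operation to the front of the operation list (so it reads an operand before that
// operand is written) or a writer operation to the back (so it writes an operand after that
// operand is read), violating execution order rules.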
static void mutateExecutionOrderTest(const std::shared_ptr<IDevice>& device, const Model& model,
                                     const std::vector<uint32_t>& numberOfConsumers) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        const Operation& operationObj = model.main.operations[operation];
        for (uint32_t input : operationObj.inputs) {
            if (model.main.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
                model.main.operands[input].lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
                // This operation reads an operand written by some
                // other operation. Move this operation to the
                // beginning of the sequence, ensuring that it reads
                // the operand before that operand is written, thereby
                // violating execution order rules.
                const std::string message = "mutateExecutionOrderTest: operation " +
                                            std::to_string(operation) + " is a reader";
                validate(device, message, model,
                         [operation](Model* model, ExecutionPreference*, Priority*) {
                             auto& operations = model->main.operations;
                             std::rotate(operations.begin(), operations.begin() + operation,
                                         operations.begin() + operation + 1);
                         });
                break;  // only need to do this once per operation
            }
        }
        for (uint32_t output : operationObj.outputs) {
            if (numberOfConsumers[output] > 0) {
                // This operation writes an operand read by some other
                // operation. Move this operation to the end of the
                // sequence, ensuring that it writes the operand after
                // that operand is read, thereby violating execution
                // order rules.
                const std::string message = "mutateExecutionOrderTest: operation " +
                                            std::to_string(operation) + " is a writer";
                validate(device, message, model,
                         [operation](Model* model, ExecutionPreference*, Priority*) {
                             auto& operations = model->main.operations;
                             std::rotate(operations.begin() + operation,
                                         operations.begin() + operation + 1, operations.end());
                         });
                break;  // only need to do this once per operation
            }
        }
    }
}

///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////

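// Operand type values that fall outside the valid OperandType enum range.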
static const int32_t invalidOperandTypes[] = {
        -1,
        static_cast<int32_t>(*(ndk::enum_range<OperandType>().end() - 1)) + 1,
};

static void mutateOperandTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        for (int32_t invalidOperandType : invalidOperandTypes) {
            const std::string message = "mutateOperandTypeTest: operand " +
                                        std::to_string(operand) + " set to value " +
                                        std::to_string(invalidOperandType);
            validate(device, message, model,
                     [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
                         model->main.operands[operand].type =
                                 static_cast<OperandType>(invalidOperandType);
                     });
        }
    }
}

///////////////////////// VALIDATE OPERAND RANK /////////////////////////

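// Returns a rank that is invalid for the given operand type (a non-zero rank for scalar types),
// or 0 if the rank test should be skipped for that type.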
static uint32_t getInvalidRank(OperandType type) {
    switch (type) {
        case OperandType::FLOAT16:
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
            return 1;
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
            return 0;
        default:
            return 0;
    }
}

static void mutateOperandRankTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        const uint32_t invalidRank = getInvalidRank(model.main.operands[operand].type);
        if (invalidRank == 0) {
            continue;
        }
        const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
                                    " has rank of " + std::to_string(invalidRank);
        validate(device, message, model,
                 [operand, invalidRank](Model* model, ExecutionPreference*, Priority*) {
                     model->main.operands[operand].dimensions =
                             std::vector<int32_t>(invalidRank, 0);
                 });
    }
}

///////////////////////// VALIDATE OPERAND SCALE /////////////////////////

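// Returns a scale that is invalid for the given operand type: non-zero for types that must have
// scale 0, negative for TENSOR_INT32, and zero for quantized types that require a positive scale.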
static float getInvalidScale(OperandType type) {
    switch (type) {
        case OperandType::FLOAT16:
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::SUBGRAPH:
            return 1.0f;
        case OperandType::TENSOR_INT32:
            return -1.0f;
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
            return 0.0f;
        default:
            return 0.0f;
    }
}

static void mutateOperandScaleTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        const float invalidScale = getInvalidScale(model.main.operands[operand].type);
        const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
                                    " has scale of " + std::to_string(invalidScale);
        validate(device, message, model,
                 [operand, invalidScale](Model* model, ExecutionPreference*, Priority*) {
                     model->main.operands[operand].scale = invalidScale;
                 });
    }
}

///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////

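// Returns zero point values that are invalid for the given operand type: non-zero for types that
// require zeroPoint == 0, and out-of-range values for the quantized types.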
static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
    switch (type) {
        case OperandType::FLOAT16:
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
        case OperandType::TENSOR_INT32:
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
        case OperandType::SUBGRAPH:
            return {1};
        case OperandType::TENSOR_QUANT8_ASYMM:
            return {-1, 256};
        case OperandType::TENSOR_QUANT8_SYMM:
            return {-129, -1, 1, 128};
        case OperandType::TENSOR_QUANT16_ASYMM:
            return {-1, 65536};
        case OperandType::TENSOR_QUANT16_SYMM:
            return {-32769, -1, 1, 32768};
        default:
            return {};
    }
}

static void mutateOperandZeroPointTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        const std::vector<int32_t> invalidZeroPoints =
                getInvalidZeroPoints(model.main.operands[operand].type);
        for (int32_t invalidZeroPoint : invalidZeroPoints) {
            const std::string message = "mutateOperandZeroPointTest: operand " +
                                        std::to_string(operand) + " has zero point of " +
                                        std::to_string(invalidZeroPoint);
            validate(device, message, model,
                     [operand, invalidZeroPoint](Model* model, ExecutionPreference*, Priority*) {
                         model->main.operands[operand].zeroPoint = invalidZeroPoint;
                     });
        }
    }
}

///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////

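// Returns lifetimes that are invalid for the given operand, primarily by flipping whether the
// lifetime implies the operand has a writer. CONSTANT_COPY is dropped from the result if the
// operand size is unknown or would push the model over the binder size limit.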
static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
                                                        const Operand& operand) {
    // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
    // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime

    // Ways to get an invalid lifetime:
    // - change whether a lifetime means an operand should have a writer
    std::vector<OperandLifeTime> ret;
    switch (operand.lifetime) {
        case OperandLifeTime::SUBGRAPH_OUTPUT:
        case OperandLifeTime::TEMPORARY_VARIABLE:
            ret = {
                    OperandLifeTime::SUBGRAPH_INPUT,
                    OperandLifeTime::CONSTANT_COPY,
            };
            break;
        case OperandLifeTime::CONSTANT_COPY:
        case OperandLifeTime::CONSTANT_POOL:
        case OperandLifeTime::SUBGRAPH_INPUT:
            ret = {
                    OperandLifeTime::TEMPORARY_VARIABLE,
                    OperandLifeTime::SUBGRAPH_OUTPUT,
            };
            break;
        case OperandLifeTime::NO_VALUE:
            // Not enough information to know whether
            // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
            // is this operand written (then CONSTANT_COPY would be
            // invalid) or not (then TEMPORARY_VARIABLE would be
            // invalid)?
            break;
        case OperandLifeTime::SUBGRAPH:
            break;
        default:
            ADD_FAILURE();
            break;
    }

    const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
    if (!operandSize ||
        exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
        // Unknown size or too-large size
        ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
    }

    return ret;
}

static void mutateOperandLifeTimeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    const size_t modelSize = sizeForBinder(model);
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        const std::vector<OperandLifeTime> invalidLifeTimes =
                getInvalidLifeTimes(model, modelSize, model.main.operands[operand]);
        for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
            const std::string message = "mutateOperandLifetimeTest: operand " +
                                        std::to_string(operand) + " has lifetime " +
                                        toString(invalidLifeTime) + " instead of lifetime " +
                                        toString(model.main.operands[operand].lifetime);
            validate(device, message, model,
                     [operand, invalidLifeTime](Model* model, ExecutionPreference*, Priority*) {
                         static const DataLocation kZeroDataLocation = {};
                         Operand& operandObj = model->main.operands[operand];
                         switch (operandObj.lifetime) {
                             case OperandLifeTime::SUBGRAPH_INPUT: {
                                 auto& inputs = model->main.inputIndexes;
                                 inputs.erase(std::remove(inputs.begin(), inputs.end(), operand),
                                              inputs.end());
                                 break;
                             }
                             case OperandLifeTime::SUBGRAPH_OUTPUT: {
                                 auto& outputs = model->main.outputIndexes;
                                 outputs.erase(std::remove(outputs.begin(), outputs.end(), operand),
                                               outputs.end());
                                 break;
                             }
                             default:
                                 break;
                         }
                         operandObj.lifetime = invalidLifeTime;
                         operandObj.location = kZeroDataLocation;
                         switch (invalidLifeTime) {
                             case OperandLifeTime::CONSTANT_COPY: {
                                 becomeConstantCopy(model, &operandObj);
                                 break;
                             }
                             case OperandLifeTime::SUBGRAPH_INPUT:
                                 model->main.inputIndexes.push_back(operand);
                                 break;
                             case OperandLifeTime::SUBGRAPH_OUTPUT:
                                 model->main.outputIndexes.push_back(operand);
                                 break;
                             default:
                                 break;
                         }
                     });
        }
    }
}

///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////

static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
                                                             const Operand& operand) {
    // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
    // - change whether a lifetime means an operand is a model input, a model output, or neither
    // - preserve whether or not a lifetime means an operand should have a writer
    switch (operand.lifetime) {
        case OperandLifeTime::CONSTANT_COPY:
        case OperandLifeTime::CONSTANT_POOL:
            return OperandLifeTime::SUBGRAPH_INPUT;
        case OperandLifeTime::SUBGRAPH_INPUT: {
            const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
            if (!operandSize ||
                exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
                // Unknown size or too-large size
                break;
            }
            return OperandLifeTime::CONSTANT_COPY;
        }
        case OperandLifeTime::SUBGRAPH_OUTPUT:
            return OperandLifeTime::TEMPORARY_VARIABLE;
        case OperandLifeTime::TEMPORARY_VARIABLE:
            return OperandLifeTime::SUBGRAPH_OUTPUT;
        case OperandLifeTime::NO_VALUE:
            // Not enough information to know whether
            // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
            // appropriate choice -- is this operand written (then
            // TEMPORARY_VARIABLE would be appropriate) or not (then
            // CONSTANT_COPY would be appropriate)?
            break;
        case OperandLifeTime::SUBGRAPH:
            break;
        default:
            ADD_FAILURE();
            break;
    }

    return std::nullopt;
}

static void mutateOperandInputOutputTest(const std::shared_ptr<IDevice>& device,
                                         const Model& model) {
    const size_t modelSize = sizeForBinder(model);
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        const std::optional<OperandLifeTime> changedLifeTime =
                getInputOutputLifeTime(model, modelSize, model.main.operands[operand]);
        if (changedLifeTime) {
            const std::string message = "mutateOperandInputOutputTest: operand " +
                                        std::to_string(operand) + " has lifetime " +
                                        toString(*changedLifeTime) + " instead of lifetime " +
                                        toString(model.main.operands[operand].lifetime);
            validate(device, message, model,
                     [operand, changedLifeTime](Model* model, ExecutionPreference*, Priority*) {
                         static const DataLocation kZeroDataLocation = {};
                         Operand& operandObj = model->main.operands[operand];
                         operandObj.lifetime = *changedLifeTime;
                         operandObj.location = kZeroDataLocation;
                         if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
                             becomeConstantCopy(model, &operandObj);
                         }
                     });
        }
    }
}

///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////

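// For each output operand of each operation, duplicates the operation so that the operand is
// written twice, which is invalid.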
static void mutateOperandAddWriterTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        for (size_t badOutputNum = 0;
             badOutputNum < model.main.operations[operation].outputs.size(); ++badOutputNum) {
            const uint32_t outputOperandIndex =
                    model.main.operations[operation].outputs[badOutputNum];
            const std::string message = "mutateOperandAddWriterTest: operation " +
                                        std::to_string(operation) + " writes to " +
                                        std::to_string(outputOperandIndex);
            // We'll insert a copy of the operation, all of whose
            // OTHER output operands are newly-created -- i.e.,
            // there'll only be a duplicate write of ONE of that
            // operation's output operands.
            validate(device, message, model,
                     [operation, badOutputNum](Model* model, ExecutionPreference*, Priority*) {
                         Operation newOperation = model->main.operations[operation];
                         for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
                              ++outputNum) {
                             if (outputNum == badOutputNum) continue;

                             Operand operandValue =
                                     model->main.operands[newOperation.outputs[outputNum]];
                             if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
                                 operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
                             } else {
                                 ASSERT_EQ(operandValue.lifetime,
                                           OperandLifeTime::TEMPORARY_VARIABLE);
                             }
                             newOperation.outputs[outputNum] = model->main.operands.size();
                             model->main.operands.push_back(operandValue);
                         }
                         // Where do we insert the extra writer (a new
                         // operation)? It has to be later than all the
                         // writers of its inputs. The easiest thing to do
                         // is to insert it at the end of the operation
                         // sequence.
                         model->main.operations.push_back(newOperation);
                     });
        }
    }
}

///////////////////////// VALIDATE EXTRA ??? /////////////////////////

// TODO: Operand::location

///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////

static void mutateOperand(Operand* operand, OperandType type) {
    Operand newOperand = *operand;
    newOperand.type = type;
    switch (type) {
        case OperandType::FLOAT16:
        case OperandType::FLOAT32:
        case OperandType::INT32:
        case OperandType::UINT32:
        case OperandType::BOOL:
            newOperand.dimensions = {};
            newOperand.scale = 0.0f;
            newOperand.zeroPoint = 0;
            break;
        case OperandType::TENSOR_BOOL8:
        case OperandType::TENSOR_FLOAT16:
        case OperandType::TENSOR_FLOAT32:
            newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
                                                                   : std::vector<int32_t>({1});
            newOperand.scale = 0.0f;
            newOperand.zeroPoint = 0;
            break;
        case OperandType::TENSOR_INT32:
            newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
                                                                   : std::vector<int32_t>({1});
            newOperand.zeroPoint = 0;
            break;
        case OperandType::TENSOR_QUANT8_ASYMM:
        case OperandType::TENSOR_QUANT8_SYMM:
        case OperandType::TENSOR_QUANT16_ASYMM:
        case OperandType::TENSOR_QUANT16_SYMM:
            newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
                                                                   : std::vector<int32_t>({1});
            newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
            break;
        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
            newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
                                                                   : std::vector<int32_t>({1});
            newOperand.scale = 0.0f;
            newOperand.zeroPoint = 0;

            SymmPerChannelQuantParams channelQuant;
            channelQuant.channelDim = 0;
            channelQuant.scales = std::vector<float>(
                    operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0])
                                                   : 0);
            for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
                channelQuant.scales[i] = 1.0f;
            }
            newOperand.extraParams->set<OperandExtraParams::Tag::channelQuant>(
                    std::move(channelQuant));
        } break;
        default:
            break;
    }
    *operand = newOperand;
}

static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, const Model& model) {
    if (type == model.main.operands[operand].type) {
        return true;
    }
    for (const Operation& operation : model.main.operations) {
        // Skip mutateOperationOperandTypeTest for the following operations.
        // - LSH_PROJECTION's second argument is allowed to have any type.
        // - ARGMIN and ARGMAX's first argument can be any of
        //   TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
        // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - DEQUANTIZE input can be any of
        //   TENSOR_(QUANT8_ASYMM|QUANT8_ASYMM_SIGNED|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL),
        //   output can be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
        // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
        // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - TRANSPOSE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
        // - AXIS_ALIGNED_BBOX_TRANSFORM bounding boxes (arg 1) can be of
        //   TENSOR_QUANT8_ASYMM or TENSOR_QUANT8_ASYMM_SIGNED.
        // - RANK's input can have any TENSOR_* type.
        switch (operation.type) {
            case OperationType::LSH_PROJECTION: {
                if (operand == operation.inputs[1]) {
                    return true;
                }
            } break;
            case OperationType::CAST:
            case OperationType::ARGMAX:
            case OperationType::ARGMIN: {
                if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
                    type == OperandType::TENSOR_INT32 || type == OperandType::TENSOR_QUANT8_ASYMM ||
                    type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
                    return true;
                }
            } break;
            case OperationType::QUANTIZE: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
                if (operand == operation.outputs[0] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
                    return true;
                }
            } break;
            case OperationType::RANDOM_MULTINOMIAL: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
            case OperationType::DEQUANTIZE: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
                     type == OperandType::TENSOR_QUANT8_SYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
                    return true;
                }
                if (operand == operation.outputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
                    return true;
                }
            } break;
            case OperationType::TRANSPOSE_CONV_2D:
            case OperationType::GROUPED_CONV_2D:
            case OperationType::DEPTHWISE_CONV_2D:
            case OperationType::CONV_2D: {
                if (operand == operation.inputs[1] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
                    return true;
                }
            } break;
            case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM: {
                if (operand == operation.inputs[1] &&
                    (type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
                    return true;
                }
            } break;
            case OperationType::RANK: {
                if (operand == operation.inputs[0] &&
                    (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
                     type == OperandType::TENSOR_INT32 ||
                     type == OperandType::TENSOR_QUANT8_ASYMM ||
                     type == OperandType::TENSOR_QUANT16_SYMM ||
                     type == OperandType::TENSOR_BOOL8 ||
                     type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
                     type == OperandType::TENSOR_QUANT16_ASYMM ||
                     type == OperandType::TENSOR_QUANT8_SYMM ||
                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
                    return true;
                }
            } break;
            default:
                break;
        }
    }
    return false;
}

static void mutateOperationOperandTypeTest(const std::shared_ptr<IDevice>& device,
                                           const Model& model) {
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        for (OperandType invalidOperandType : ndk::enum_range<OperandType>()) {
            if (mutateOperationOperandTypeSkip(operand, invalidOperandType, model)) {
                continue;
            }
            const std::string message = "mutateOperationOperandTypeTest: operand " +
                                        std::to_string(operand) + " set to type " +
                                        toString(invalidOperandType);
            validate(device, message, model,
                     [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
                         mutateOperand(&model->main.operands[operand], invalidOperandType);
                     });
        }
    }
}

///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////

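// Operation type values that fall outside the valid OperationType enum range.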
static const int32_t invalidOperationTypes[] = {
        -1,
        static_cast<int32_t>(*(ndk::enum_range<OperationType>().end() - 1)) + 1,
};

static void mutateOperationTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        for (int32_t invalidOperationType : invalidOperationTypes) {
            const std::string message = "mutateOperationTypeTest: operation " +
                                        std::to_string(operation) + " set to value " +
                                        std::to_string(invalidOperationType);
            validate(device, message, model,
                     [operation, invalidOperationType](Model* model, ExecutionPreference*,
                                                       Priority*) {
                         model->main.operations[operation].type =
                                 static_cast<OperationType>(invalidOperationType);
                     });
        }
    }
}

///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////

static void mutateOperationInputOperandIndexTest(const std::shared_ptr<IDevice>& device,
                                                 const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        const uint32_t invalidOperand = model.main.operands.size();
        for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
            const std::string message = "mutateOperationInputOperandIndexTest: operation " +
                                        std::to_string(operation) + " input " +
                                        std::to_string(input);
            validate(device, message, model,
                     [operation, input, invalidOperand](Model* model, ExecutionPreference*,
                                                        Priority*) {
                         model->main.operations[operation].inputs[input] = invalidOperand;
                     });
        }
    }
}

///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////

static void mutateOperationOutputOperandIndexTest(const std::shared_ptr<IDevice>& device,
                                                  const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        const uint32_t invalidOperand = model.main.operands.size();
        for (size_t output = 0; output < model.main.operations[operation].outputs.size();
             ++output) {
            const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
                                        std::to_string(operation) + " output " +
                                        std::to_string(output);
            validate(device, message, model,
                     [operation, output, invalidOperand](Model* model, ExecutionPreference*,
                                                         Priority*) {
                         model->main.operations[operation].outputs[output] = invalidOperand;
                     });
        }
    }
}

///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////

static void mutateOperationRemoveWriteTest(const std::shared_ptr<IDevice>& device,
                                           const Model& model,
                                           const std::vector<uint32_t>& numberOfConsumers) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        for (size_t outputNum = 0; outputNum < model.main.operations[operation].outputs.size();
             ++outputNum) {
            const uint32_t outputOperandIndex = model.main.operations[operation].outputs[outputNum];
            if (numberOfConsumers[outputOperandIndex] > 0) {
                const std::string message = "mutateOperationRemoveWriteTest: operation " +
                                            std::to_string(operation) + " writes to " +
                                            std::to_string(outputOperandIndex);
                validate(device, message, model,
                         [operation, outputNum](Model* model, ExecutionPreference*, Priority*) {
                             int32_t& outputOperandIndex =
                                     model->main.operations[operation].outputs[outputNum];
                             Operand operandValue = model->main.operands[outputOperandIndex];
                             if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
                                 operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
                             } else {
                                 ASSERT_EQ(operandValue.lifetime,
                                           OperandLifeTime::TEMPORARY_VARIABLE);
                             }
                             outputOperandIndex = model->main.operands.size();
                             model->main.operands.push_back(operandValue);
                         });
            }
        }
    }
}

///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////

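// Removes all occurrences of "value" from the vector and shifts down the indexes that pointed
// past it, keeping operand references consistent after an operand is erased from the model.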
static void removeValueAndDecrementGreaterValues(std::vector<int32_t>* vec, uint32_t value) {
    if (vec) {
        // remove elements matching "value"
        vec->erase(std::remove(vec->begin(), vec->end(), value), vec->end());

        // decrement elements exceeding "value" (note: post-decrementing the by-value copy "v"
        // would leave the stored element unchanged, so compute "v - 1" instead)
        std::transform(vec->begin(), vec->end(), vec->begin(),
                       [value](uint32_t v) { return v > value ? v - 1 : v; });
    }
}

static void removeOperand(Model* model, uint32_t index) {
    model->main.operands.erase(model->main.operands.begin() + index);
    for (Operation& operation : model->main.operations) {
        removeValueAndDecrementGreaterValues(&operation.inputs, index);
        removeValueAndDecrementGreaterValues(&operation.outputs, index);
    }
    removeValueAndDecrementGreaterValues(&model->main.inputIndexes, index);
    removeValueAndDecrementGreaterValues(&model->main.outputIndexes, index);
}

static bool removeOperandSkip(size_t operandIndex, const Model& model,
                              const std::vector<uint32_t>& numberOfConsumers) {
    if (numberOfConsumers[operandIndex] == 0) {
        // Removing an unused operand has no effect.
        return true;
    }
    for (const Operation& operation : model.main.operations) {
        // Skip removeOperandTest for the following operations.
        // - SPLIT's outputs are not checked during prepareModel.
        if (operation.type == OperationType::SPLIT) {
            for (const size_t index : operation.outputs) {
                if (index == operandIndex) {
                    return true;
                }
            }
        }
        // BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN can have
        // either one, two, three or four outputs depending on their
        // mergeOutputs parameter and if state outputs are provided.
        // UNIDIRECTIONAL_SEQUENCE_LSTM and UNIDIRECTIONAL_SEQUENCE_RNN can have
        // either one or three outputs depending on whether state outputs are
        // provided.
        if (operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM ||
            operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_RNN ||
            operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_LSTM ||
            operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
            for (const size_t index : operation.outputs) {
                if (index == operandIndex) {
                    return true;
                }
            }
        }
    }
    return false;
}

static void removeOperandTest(const std::shared_ptr<IDevice>& device, const Model& model,
                              const std::vector<uint32_t>& numberOfConsumers) {
    for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
        if (removeOperandSkip(operand, model, numberOfConsumers)) {
            continue;
        }
        const std::string message = "removeOperandTest: operand " + std::to_string(operand);
        validate(device, message, model, [operand](Model* model, ExecutionPreference*, Priority*) {
            removeOperand(model, operand);
        });
    }
}

///////////////////////// REMOVE OPERATION /////////////////////////

static void removeOperation(Model* model, uint32_t index) {
    auto& operations = model->main.operations;
    operations.erase(operations.begin() + index);
}

static void removeOperationTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        const std::string message = "removeOperationTest: operation " + std::to_string(operation);
        validate(device, message, model,
                 [operation](Model* model, ExecutionPreference*, Priority*) {
                     removeOperation(model, operation);
                 });
    }
}

///////////////////////// REMOVE OPERATION INPUT /////////////////////////

static bool removeOperationInputSkip(const Operation& op, size_t input) {
    // Skip removeOperationInputTest for the following operations.
    // - CONCATENATION has at least 2 inputs, with the last element being INT32.
    // - CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D, AVERAGE_POOL_2D, L2_POOL_2D, RESIZE_BILINEAR,
    //   SPACE_TO_DEPTH, DEPTH_TO_SPACE, SPACE_TO_BATCH_ND, BATCH_TO_SPACE_ND can have an optional
    //   layout parameter.
    //   RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR can have optional
    //   align_corners and half_pixel_centers parameters.
    // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional axis
    //   parameter.
    switch (op.type) {
        case OperationType::CONCATENATION: {
            if (op.inputs.size() > 2 && input != op.inputs.size() - 1) {
                return true;
            }
        } break;
        case OperationType::DEPTHWISE_CONV_2D: {
            if ((op.inputs.size() == 12 && input == 11) || (op.inputs.size() == 9 && input == 8)) {
                return true;
            }
        } break;
        case OperationType::CONV_2D:
        case OperationType::AVERAGE_POOL_2D:
        case OperationType::MAX_POOL_2D:
        case OperationType::L2_POOL_2D: {
            if ((op.inputs.size() == 11 && input == 10) || (op.inputs.size() == 8 && input == 7)) {
                return true;
            }
        } break;
        case OperationType::RESIZE_BILINEAR: {
            if (op.inputs.size() >= 4 && input >= 3) {
                return true;
            }
        } break;
        case OperationType::RESIZE_NEAREST_NEIGHBOR: {
            if (op.inputs.size() >= 5 && input >= 3) {
                return true;
            }
        } break;
        case OperationType::SPACE_TO_DEPTH:
        case OperationType::DEPTH_TO_SPACE:
        case OperationType::BATCH_TO_SPACE_ND: {
            if (op.inputs.size() == 3 && input == 2) {
                return true;
            }
        } break;
        case OperationType::SPACE_TO_BATCH_ND: {
            if (op.inputs.size() == 4 && input == 3) {
                return true;
            }
        } break;
        case OperationType::L2_NORMALIZATION: {
            if (op.inputs.size() == 2 && input == 1) {
                return true;
            }
        } break;
        case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
            if (op.inputs.size() == 6 && input == 5) {
                return true;
            }
        } break;
        case OperationType::SOFTMAX: {
            if (op.inputs.size() == 3 && input == 2) {
                return true;
            }
        } break;
        default:
            break;
    }
    return false;
}

static void removeOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
            const Operation& op = model.main.operations[operation];
            if (removeOperationInputSkip(op, input)) {
                continue;
            }
            const std::string message = "removeOperationInputTest: operation " +
                                        std::to_string(operation) + ", input " +
                                        std::to_string(input);
            validate(device, message, model,
                     [operation, input](Model* model, ExecutionPreference*, Priority*) {
                         auto& inputs = model->main.operations[operation].inputs;
                         inputs.erase(inputs.begin() + input);
                     });
        }
    }
}

///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////

static void removeOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        for (size_t output = 0; output < model.main.operations[operation].outputs.size();
             ++output) {
            const std::string message = "removeOperationOutputTest: operation " +
                                        std::to_string(operation) + ", output " +
                                        std::to_string(output);
            validate(device, message, model,
                     [operation, output](Model* model, ExecutionPreference*, Priority*) {
                         auto& outputs = model->main.operations[operation].outputs;
                         outputs.erase(outputs.begin() + output);
                     });
        }
    }
}

///////////////////////// MODEL VALIDATION /////////////////////////

// TODO: remove model input
// TODO: remove model output
// TODO: add unused operation

///////////////////////// ADD OPERATION INPUT /////////////////////////

static bool addOperationInputSkip(const Operation& op) {
    // Skip addOperationInputTest for the following operations.
    // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional INT32 axis
    //   parameter.
    // - RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR can have optional align_corners and
    //   half_pixel_centers parameters.
    if ((op.type == OperationType::L2_NORMALIZATION && op.inputs.size() == 1) ||
        (op.type == OperationType::LOCAL_RESPONSE_NORMALIZATION && op.inputs.size() == 5) ||
        (op.type == OperationType::SOFTMAX && op.inputs.size() == 2) ||
        (op.type == OperationType::RESIZE_BILINEAR && op.inputs.size() < 6) ||
        (op.type == OperationType::RESIZE_NEAREST_NEIGHBOR && op.inputs.size() < 6)) {
        return true;
    }
    return false;
}

static void addOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        if (addOperationInputSkip(model.main.operations[operation])) {
            continue;
        }
        const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
        validate(device, message, model,
                 [operation](Model* model, ExecutionPreference*, Priority*) {
                     uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_INPUT);
                     model->main.operations[operation].inputs.push_back(index);
                     model->main.inputIndexes.push_back(index);
                 });
    }
}

///////////////////////// ADD OPERATION OUTPUT /////////////////////////

static void addOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
    for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
        const std::string message =
                "addOperationOutputTest: operation " + std::to_string(operation);
        validate(device, message, model,
                 [operation](Model* model, ExecutionPreference*, Priority*) {
                     uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_OUTPUT);
                     model->main.operations[operation].outputs.push_back(index);
                     model->main.outputIndexes.push_back(index);
                 });
    }
}

///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////

static const int32_t invalidExecutionPreferences[] = {
        static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
        static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
};

static void mutateExecutionPreferenceTest(const std::shared_ptr<IDevice>& device,
                                          const Model& model) {
    for (int32_t invalidPreference : invalidExecutionPreferences) {
        const std::string message =
                "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
        validate(device, message, model,
                 [invalidPreference](Model*, ExecutionPreference* preference, Priority*) {
                     *preference = static_cast<ExecutionPreference>(invalidPreference);
                 });
    }
}

///////////////////////// VALIDATE PRIORITY /////////////////////////

static const int32_t invalidPriorities[] = {
        static_cast<int32_t>(Priority::LOW) - 1,   // lower bound
        static_cast<int32_t>(Priority::HIGH) + 1,  // upper bound
};

static void mutateExecutionPriorityTest(const std::shared_ptr<IDevice>& device,
                                        const Model& model) {
    for (int32_t invalidPriority : invalidPriorities) {
        const std::string message =
                "mutatePriorityTest: priority " + std::to_string(invalidPriority);
        validate(device, message, model,
                 [invalidPriority](Model*, ExecutionPreference*, Priority* priority) {
                     *priority = static_cast<Priority>(invalidPriority);
                 });
    }
}

////////////////////////// ENTRY POINT //////////////////////////////

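// Runs every mutation test above against the given device and model.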
void validateModel(const std::shared_ptr<IDevice>& device, const Model& model) {
    const auto numberOfConsumers =
            nn::countNumberOfConsumers(model.main.operands.size(),
                                       nn::unvalidatedConvert(model.main.operations).value())
                    .value();
    mutateExecutionOrderTest(device, model, numberOfConsumers);
    mutateOperandTypeTest(device, model);
    mutateOperandRankTest(device, model);
    mutateOperandScaleTest(device, model);
    mutateOperandZeroPointTest(device, model);
    mutateOperandLifeTimeTest(device, model);
    mutateOperandInputOutputTest(device, model);
    mutateOperandAddWriterTest(device, model);
    mutateOperationOperandTypeTest(device, model);
    mutateOperationTypeTest(device, model);
    mutateOperationInputOperandIndexTest(device, model);
    mutateOperationOutputOperandIndexTest(device, model);
    mutateOperationRemoveWriteTest(device, model, numberOfConsumers);
    removeOperandTest(device, model, numberOfConsumers);
    removeOperationTest(device, model);
    removeOperationInputTest(device, model);
    removeOperationOutputTest(device, model);
    addOperationInputTest(device, model);
    addOperationOutputTest(device, model);
    mutateExecutionPreferenceTest(device, model);
    mutateExecutionPriorityTest(device, model);
}

}  // namespace aidl::android::hardware::neuralnetworks::vts::functional