1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_aidl_hal_test"
18 
19 #include <aidl/android/hardware/common/NativeHandle.h>
20 #include <android/binder_auto_utils.h>
21 #include <android/binder_enums.h>
22 #include <android/binder_interface_utils.h>
23 #include <nnapi/TypeUtils.h>
24 #include <nnapi/hal/aidl/Conversions.h>
25 #include <nnapi/hal/aidl/Utils.h>
26 
27 #include <optional>
28 #include <type_traits>
29 #include <utility>
30 
31 #include "Callbacks.h"
32 #include "GeneratedTestHarness.h"
33 #include "Utils.h"
34 #include "VtsHalNeuralnetworks.h"
35 
36 namespace aidl::android::hardware::neuralnetworks::vts::functional {
37 
38 using common::NativeHandle;
39 using implementation::PreparedModelCallback;
40 
41 using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*, Priority*)>;
42 
43 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
44 
validateGetSupportedOperations(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & model)45 static void validateGetSupportedOperations(const std::shared_ptr<IDevice>& device,
46                                            const std::string& message, const Model& model) {
47     SCOPED_TRACE(message + " [getSupportedOperations]");
48 
49     std::vector<bool> supported;
50     const auto retStatus = device->getSupportedOperations(model, &supported);
51 
52     ASSERT_FALSE(retStatus.isOk());
53     ASSERT_EQ(retStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
54     ASSERT_EQ(static_cast<ErrorStatus>(retStatus.getServiceSpecificError()),
55               ErrorStatus::INVALID_ARGUMENT);
56 }
57 
validatePrepareModel(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference,Priority priority)58 static void validatePrepareModel(const std::shared_ptr<IDevice>& device, const std::string& message,
59                                  const Model& model, ExecutionPreference preference,
60                                  Priority priority) {
61     SCOPED_TRACE(message + " [prepareModel]");
62 
63     std::shared_ptr<PreparedModelCallback> preparedModelCallback =
64             ndk::SharedRefBase::make<PreparedModelCallback>();
65     const auto prepareLaunchStatus =
66             device->prepareModel(model, preference, priority, kNoDeadline, {}, {}, kEmptyCacheToken,
67                                  preparedModelCallback);
68     ASSERT_FALSE(prepareLaunchStatus.isOk());
69     ASSERT_EQ(prepareLaunchStatus.getExceptionCode(), EX_SERVICE_SPECIFIC);
70     ASSERT_EQ(static_cast<ErrorStatus>(prepareLaunchStatus.getServiceSpecificError()),
71               ErrorStatus::INVALID_ARGUMENT);
72 
73     preparedModelCallback->wait();
74     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
75     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
76     std::shared_ptr<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
77     ASSERT_EQ(nullptr, preparedModel.get());
78 }
79 
validExecutionPreference(ExecutionPreference preference)80 static bool validExecutionPreference(ExecutionPreference preference) {
81     return preference == ExecutionPreference::LOW_POWER ||
82            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
83            preference == ExecutionPreference::SUSTAINED_SPEED;
84 }
85 
validExecutionPriority(Priority priority)86 static bool validExecutionPriority(Priority priority) {
87     return priority == Priority::LOW || priority == Priority::MEDIUM || priority == Priority::HIGH;
88 }
89 
90 // Primary validation function. This function will take a valid model, apply a
91 // mutation to invalidate the model, the execution preference, or the priority,
92 // then pass these to supportedOperations and/or prepareModel if that method is
93 // called with an invalid argument.
validate(const std::shared_ptr<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)94 static void validate(const std::shared_ptr<IDevice>& device, const std::string& message,
95                      const Model& originalModel, const PrepareModelMutation& mutate) {
96     Model model = utils::clone(originalModel).value();
97     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
98     Priority priority = kDefaultPriority;
99     mutate(&model, &preference, &priority);
100 
101     if (validExecutionPreference(preference) && validExecutionPriority(priority)) {
102         validateGetSupportedOperations(device, message, model);
103     }
104 
105     validatePrepareModel(device, message, model, preference, priority);
106 }
107 
addOperand(Model * model)108 static uint32_t addOperand(Model* model) {
109     model->main.operands.push_back({
110             .type = OperandType::INT32,
111             .dimensions = {},
112             .scale = 0.0f,
113             .zeroPoint = 0,
114             .lifetime = OperandLifeTime::SUBGRAPH_INPUT,
115             .location = {.poolIndex = 0, .offset = 0, .length = 0},
116     });
117     return model->main.operands.size() - 1;
118 }
119 
addOperand(Model * model,OperandLifeTime lifetime)120 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
121     uint32_t index = addOperand(model);
122     model->main.operands[index].lifetime = lifetime;
123     return index;
124 }
125 
126 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
127 // how much will this increase the size of the model?  This assumes
128 // that we can (re)use all of model.operandValues for the operand
129 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)130 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
131     const size_t operandValuesSize = model.operandValues.size();
132     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
133 }
134 
135 // Highly specialized utility routine for converting an operand to
136 // CONSTANT_COPY lifetime.
137 //
138 // Expects that:
139 // - operand has a known size
140 // - operand->lifetime has already been set to CONSTANT_COPY
141 // - operand->location has been zeroed out
142 //
143 // Does the following:
144 // - initializes operand->location to point to the beginning of model->operandValues
145 // - resizes model->operandValues (if necessary) to be large enough for the operand
146 //   value, padding it with zeroes on the end
147 //
148 // Potential problem:
149 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
150 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
151 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
152 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
153 // value). For now, this should be fine because it just means we're not testing what we think we're
154 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
155 // exercise the validation code over the span of the entire test set and operand space.
156 //
157 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)158 static void becomeConstantCopy(Model* model, Operand* operand) {
159     // sizeOfData will abort if the specified type is an extension type or OEM type.
160     const size_t sizeOfOperand = sizeOfData(*operand);
161     EXPECT_NE(sizeOfOperand, size_t(0));
162     operand->location.poolIndex = 0;
163     operand->location.offset = 0;
164     operand->location.length = sizeOfOperand;
165     if (model->operandValues.size() < sizeOfOperand) {
166         model->operandValues.resize(sizeOfOperand);
167     }
168 }
169 
170 // The sizeForBinder() functions estimate the size of the
171 // representation of a value when sent to binder.  It's probably a bit
172 // of an under-estimate, because we don't know the size of the
173 // metadata in the binder format (e.g., representation of the size of
174 // a vector); but at least it adds up "big" things like vector
175 // contents.  However, it doesn't treat inter-field or end-of-struct
176 // padding in a methodical way -- there's no attempt to be consistent
177 // in whether or not padding in the native (C++) representation
178 // contributes to the estimated size for the binder representation;
179 // and there's no attempt to understand what padding (if any) is
180 // needed in the binder representation.
181 //
182 // This assumes that non-metadata uses a fixed length encoding (e.g.,
183 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
184 // using an encoding whose length is related to the magnitude of the
185 // encoded value).
186 
// Estimated binder size of a trivially copyable value: simply its in-memory size. Non-trivially
// copyable types need their own explicit specialization (see below).
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    static_assert(std::is_trivially_copyable_v<std::remove_reference_t<Type>>,
                  "expected a trivially copyable type");
    return sizeof(val);
}

// Estimated binder size of a vector: the sum of the estimated sizes of its elements. The vector's
// own length metadata is not counted.
//
// The accumulator is seeded with size_t{0} rather than the int literal 0. std::accumulate keeps
// its running total in the type of the initial value, so an int seed would narrow every partial
// sum to int and could overflow or truncate for large vectors.
template <typename Type>
static size_t sizeForBinder(const std::vector<Type>& vec) {
    return std::accumulate(vec.begin(), vec.end(), size_t{0},
                           [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
}
199 
200 template <>
sizeForBinder(const SymmPerChannelQuantParams & symmPerChannelQuantParams)201 size_t sizeForBinder(const SymmPerChannelQuantParams& symmPerChannelQuantParams) {
202     size_t size = 0;
203 
204     size += sizeForBinder(symmPerChannelQuantParams.scales);
205     size += sizeForBinder(symmPerChannelQuantParams.channelDim);
206 
207     return size;
208 }
209 
210 template <>
sizeForBinder(const std::optional<OperandExtraParams> & optionalExtraParams)211 size_t sizeForBinder(const std::optional<OperandExtraParams>& optionalExtraParams) {
212     if (!optionalExtraParams.has_value()) {
213         return 0;
214     }
215     const auto& extraParams = optionalExtraParams.value();
216     using Tag = OperandExtraParams::Tag;
217     switch (extraParams.getTag()) {
218         case Tag::channelQuant:
219             return sizeForBinder(extraParams.get<Tag::channelQuant>());
220         case Tag::extension:
221             return sizeForBinder(extraParams.get<Tag::extension>());
222     }
223     LOG(FATAL) << "Unrecognized extraParams tag: " << static_cast<int>(extraParams.getTag());
224     return 0;
225 }
226 
227 template <>
sizeForBinder(const Operand & operand)228 size_t sizeForBinder(const Operand& operand) {
229     size_t size = 0;
230 
231     size += sizeForBinder(operand.type);
232     size += sizeForBinder(operand.dimensions);
233     size += sizeForBinder(operand.scale);
234     size += sizeForBinder(operand.zeroPoint);
235     size += sizeForBinder(operand.lifetime);
236     size += sizeForBinder(operand.location);
237     size += sizeForBinder(operand.extraParams);
238 
239     return size;
240 }
241 
242 template <>
sizeForBinder(const Operation & operation)243 size_t sizeForBinder(const Operation& operation) {
244     size_t size = 0;
245 
246     size += sizeForBinder(operation.type);
247     size += sizeForBinder(operation.inputs);
248     size += sizeForBinder(operation.outputs);
249 
250     return size;
251 }
252 
253 template <>
sizeForBinder(const std::string & name)254 size_t sizeForBinder(const std::string& name) {
255     return name.size();
256 }
257 
258 template <>
sizeForBinder(const Memory & memory)259 size_t sizeForBinder(const Memory& memory) {
260     // This is just a guess.
261 
262     size_t size = sizeof(Memory);
263 
264     // Only hardwareBuffer type memory has dynamic memory that needs to be accounted for (in the
265     // form of a NativeHandle type). The other other types of memory (MappableFile, Ashmem) use a
266     // single file descriptor (with metadata) instead.
267     if (memory.getTag() == Memory::Tag::hardwareBuffer) {
268         const NativeHandle& handle = memory.get<Memory::Tag::hardwareBuffer>().handle;
269         size += sizeof(decltype(handle.fds)::value_type) * handle.fds.size();
270         size += sizeof(decltype(handle.ints)::value_type) * handle.ints.size();
271     }
272 
273     return size;
274 }
275 
276 template <>
sizeForBinder(const Subgraph & subgraph)277 size_t sizeForBinder(const Subgraph& subgraph) {
278     size_t size = 0;
279 
280     size += sizeForBinder(subgraph.operands);
281     size += sizeForBinder(subgraph.operations);
282     size += sizeForBinder(subgraph.inputIndexes);
283     size += sizeForBinder(subgraph.outputIndexes);
284 
285     return size;
286 }
287 
288 template <>
sizeForBinder(const ExtensionNameAndPrefix & extensionNameToPrefix)289 size_t sizeForBinder(const ExtensionNameAndPrefix& extensionNameToPrefix) {
290     size_t size = 0;
291 
292     size += sizeForBinder(extensionNameToPrefix.name);
293     size += sizeForBinder(extensionNameToPrefix.prefix);
294 
295     return size;
296 }
297 
298 template <>
sizeForBinder(const Model & model)299 size_t sizeForBinder(const Model& model) {
300     size_t size = 0;
301 
302     size += sizeForBinder(model.main);
303     size += sizeForBinder(model.referenced);
304     size += sizeForBinder(model.operandValues);
305     size += sizeForBinder(model.pools);
306     size += sizeForBinder(model.relaxComputationFloat32toFloat16);
307     size += sizeForBinder(model.extensionNameToPrefix);
308 
309     return size;
310 }
311 
// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Will our representation fit under this limit? There are two complications:
// - Our representation size is just approximate (see sizeForBinder()).
// - This object may not be the only occupant of the Binder transaction buffer.
// So we are very conservative: the representation size must be no larger than half the
// transaction buffer size.
//
// If our representation grows large enough that it still fits within the transaction buffer, but
// exceeds the buffer size when combined with other transactions, we may see intermittent HAL
// transport errors.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // Instead of this fixed buffer size, we might be able to use ProcessState::self()->getMmapSize().
    // However, the binder/mmap size of the current process does not necessarily indicate the
    // binder/mmap size of the service (i.e., the other process). It is only a good indication if
    // both the current process and the service use the default size.
    constexpr size_t kTransactionBufferSize = 1024 * 1024;
    constexpr size_t kHalfBufferSize = kTransactionBufferSize / 2;

    return representationSize > kHalfBufferSize;
}
339 
340 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
341 
mutateExecutionOrderTest(const std::shared_ptr<IDevice> & device,const Model & model,const std::vector<uint32_t> & numberOfConsumers)342 static void mutateExecutionOrderTest(const std::shared_ptr<IDevice>& device, const Model& model,
343                                      const std::vector<uint32_t>& numberOfConsumers) {
344     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
345         const Operation& operationObj = model.main.operations[operation];
346         for (uint32_t input : operationObj.inputs) {
347             if (model.main.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
348                 model.main.operands[input].lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
349                 // This operation reads an operand written by some
350                 // other operation.  Move this operation to the
351                 // beginning of the sequence, ensuring that it reads
352                 // the operand before that operand is written, thereby
353                 // violating execution order rules.
354                 const std::string message = "mutateExecutionOrderTest: operation " +
355                                             std::to_string(operation) + " is a reader";
356                 validate(device, message, model,
357                          [operation](Model* model, ExecutionPreference*, Priority*) {
358                              auto& operations = model->main.operations;
359                              std::rotate(operations.begin(), operations.begin() + operation,
360                                          operations.begin() + operation + 1);
361                          });
362                 break;  // only need to do this once per operation
363             }
364         }
365         for (uint32_t output : operationObj.outputs) {
366             if (numberOfConsumers[output] > 0) {
367                 // This operation writes an operand read by some other
368                 // operation.  Move this operation to the end of the
369                 // sequence, ensuring that it writes the operand after
370                 // that operand is read, thereby violating execution
371                 // order rules.
372                 const std::string message = "mutateExecutionOrderTest: operation " +
373                                             std::to_string(operation) + " is a writer";
374                 validate(device, message, model,
375                          [operation](Model* model, ExecutionPreference*, Priority*) {
376                              auto& operations = model->main.operations;
377                              std::rotate(operations.begin() + operation,
378                                          operations.begin() + operation + 1, operations.end());
379                          });
380                 break;  // only need to do this once per operation
381             }
382         }
383     }
384 }
385 
386 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
387 
388 static const int32_t invalidOperandTypes[] = {
389         -1,
390         static_cast<int32_t>(*(ndk::enum_range<OperandType>().end() - 1)) + 1,
391 };
392 
mutateOperandTypeTest(const std::shared_ptr<IDevice> & device,const Model & model)393 static void mutateOperandTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
394     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
395         for (int32_t invalidOperandType : invalidOperandTypes) {
396             const std::string message = "mutateOperandTypeTest: operand " +
397                                         std::to_string(operand) + " set to value " +
398                                         std::to_string(invalidOperandType);
399             validate(device, message, model,
400                      [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
401                          model->main.operands[operand].type =
402                                  static_cast<OperandType>(invalidOperandType);
403                      });
404         }
405     }
406 }
407 
408 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
409 
getInvalidRank(OperandType type)410 static uint32_t getInvalidRank(OperandType type) {
411     switch (type) {
412         case OperandType::FLOAT16:
413         case OperandType::FLOAT32:
414         case OperandType::INT32:
415         case OperandType::UINT32:
416         case OperandType::BOOL:
417             return 1;
418         case OperandType::TENSOR_BOOL8:
419         case OperandType::TENSOR_FLOAT16:
420         case OperandType::TENSOR_FLOAT32:
421         case OperandType::TENSOR_INT32:
422         case OperandType::TENSOR_QUANT8_ASYMM:
423         case OperandType::TENSOR_QUANT8_SYMM:
424         case OperandType::TENSOR_QUANT16_ASYMM:
425         case OperandType::TENSOR_QUANT16_SYMM:
426         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
427             return 0;
428         default:
429             return 0;
430     }
431 }
432 
mutateOperandRankTest(const std::shared_ptr<IDevice> & device,const Model & model)433 static void mutateOperandRankTest(const std::shared_ptr<IDevice>& device, const Model& model) {
434     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
435         const uint32_t invalidRank = getInvalidRank(model.main.operands[operand].type);
436         if (invalidRank == 0) {
437             continue;
438         }
439         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
440                                     " has rank of " + std::to_string(invalidRank);
441         validate(device, message, model,
442                  [operand, invalidRank](Model* model, ExecutionPreference*, Priority*) {
443                      model->main.operands[operand].dimensions =
444                              std::vector<int32_t>(invalidRank, 0);
445                  });
446     }
447 }
448 
449 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
450 
getInvalidScale(OperandType type)451 static float getInvalidScale(OperandType type) {
452     switch (type) {
453         case OperandType::FLOAT16:
454         case OperandType::FLOAT32:
455         case OperandType::INT32:
456         case OperandType::UINT32:
457         case OperandType::BOOL:
458         case OperandType::TENSOR_BOOL8:
459         case OperandType::TENSOR_FLOAT16:
460         case OperandType::TENSOR_FLOAT32:
461         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
462         case OperandType::SUBGRAPH:
463             return 1.0f;
464         case OperandType::TENSOR_INT32:
465             return -1.0f;
466         case OperandType::TENSOR_QUANT8_SYMM:
467         case OperandType::TENSOR_QUANT8_ASYMM:
468         case OperandType::TENSOR_QUANT16_ASYMM:
469         case OperandType::TENSOR_QUANT16_SYMM:
470             return 0.0f;
471         default:
472             return 0.0f;
473     }
474 }
475 
mutateOperandScaleTest(const std::shared_ptr<IDevice> & device,const Model & model)476 static void mutateOperandScaleTest(const std::shared_ptr<IDevice>& device, const Model& model) {
477     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
478         const float invalidScale = getInvalidScale(model.main.operands[operand].type);
479         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
480                                     " has scale of " + std::to_string(invalidScale);
481         validate(device, message, model,
482                  [operand, invalidScale](Model* model, ExecutionPreference*, Priority*) {
483                      model->main.operands[operand].scale = invalidScale;
484                  });
485     }
486 }
487 
488 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
489 
getInvalidZeroPoints(OperandType type)490 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
491     switch (type) {
492         case OperandType::FLOAT16:
493         case OperandType::FLOAT32:
494         case OperandType::INT32:
495         case OperandType::UINT32:
496         case OperandType::BOOL:
497         case OperandType::TENSOR_BOOL8:
498         case OperandType::TENSOR_FLOAT16:
499         case OperandType::TENSOR_FLOAT32:
500         case OperandType::TENSOR_INT32:
501         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
502         case OperandType::SUBGRAPH:
503             return {1};
504         case OperandType::TENSOR_QUANT8_ASYMM:
505             return {-1, 256};
506         case OperandType::TENSOR_QUANT8_SYMM:
507             return {-129, -1, 1, 128};
508         case OperandType::TENSOR_QUANT16_ASYMM:
509             return {-1, 65536};
510         case OperandType::TENSOR_QUANT16_SYMM:
511             return {-32769, -1, 1, 32768};
512         default:
513             return {};
514     }
515 }
516 
mutateOperandZeroPointTest(const std::shared_ptr<IDevice> & device,const Model & model)517 static void mutateOperandZeroPointTest(const std::shared_ptr<IDevice>& device, const Model& model) {
518     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
519         const std::vector<int32_t> invalidZeroPoints =
520                 getInvalidZeroPoints(model.main.operands[operand].type);
521         for (int32_t invalidZeroPoint : invalidZeroPoints) {
522             const std::string message = "mutateOperandZeroPointTest: operand " +
523                                         std::to_string(operand) + " has zero point of " +
524                                         std::to_string(invalidZeroPoint);
525             validate(device, message, model,
526                      [operand, invalidZeroPoint](Model* model, ExecutionPreference*, Priority*) {
527                          model->main.operands[operand].zeroPoint = invalidZeroPoint;
528                      });
529         }
530     }
531 }
532 
533 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
534 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)535 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
536                                                         const Operand& operand) {
537     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
538     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
539 
540     // Ways to get an invalid lifetime:
541     // - change whether a lifetime means an operand should have a writer
542     std::vector<OperandLifeTime> ret;
543     switch (operand.lifetime) {
544         case OperandLifeTime::SUBGRAPH_OUTPUT:
545         case OperandLifeTime::TEMPORARY_VARIABLE:
546             ret = {
547                     OperandLifeTime::SUBGRAPH_INPUT,
548                     OperandLifeTime::CONSTANT_COPY,
549             };
550             break;
551         case OperandLifeTime::CONSTANT_COPY:
552         case OperandLifeTime::CONSTANT_POOL:
553         case OperandLifeTime::SUBGRAPH_INPUT:
554             ret = {
555                     OperandLifeTime::TEMPORARY_VARIABLE,
556                     OperandLifeTime::SUBGRAPH_OUTPUT,
557             };
558             break;
559         case OperandLifeTime::NO_VALUE:
560             // Not enough information to know whether
561             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
562             // is this operand written (then CONSTANT_COPY would be
563             // invalid) or not (then TEMPORARY_VARIABLE would be
564             // invalid)?
565             break;
566         case OperandLifeTime::SUBGRAPH:
567             break;
568         default:
569             ADD_FAILURE();
570             break;
571     }
572 
573     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
574     if (!operandSize ||
575         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
576         // Unknown size or too-large size
577         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
578     }
579 
580     return ret;
581 }
582 
mutateOperandLifeTimeTest(const std::shared_ptr<IDevice> & device,const Model & model)583 static void mutateOperandLifeTimeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
584     const size_t modelSize = sizeForBinder(model);
585     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
586         const std::vector<OperandLifeTime> invalidLifeTimes =
587                 getInvalidLifeTimes(model, modelSize, model.main.operands[operand]);
588         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
589             const std::string message = "mutateOperandLifetimeTest: operand " +
590                                         std::to_string(operand) + " has lifetime " +
591                                         toString(invalidLifeTime) + " instead of lifetime " +
592                                         toString(model.main.operands[operand].lifetime);
593             validate(device, message, model,
594                      [operand, invalidLifeTime](Model* model, ExecutionPreference*, Priority*) {
595                          static const DataLocation kZeroDataLocation = {};
596                          Operand& operandObj = model->main.operands[operand];
597                          switch (operandObj.lifetime) {
598                              case OperandLifeTime::SUBGRAPH_INPUT: {
599                                  auto& inputs = model->main.inputIndexes;
600                                  inputs.erase(std::remove(inputs.begin(), inputs.end(), operand),
601                                               inputs.end());
602                                  break;
603                              }
604                              case OperandLifeTime::SUBGRAPH_OUTPUT: {
605                                  auto& outputs = model->main.outputIndexes;
606                                  outputs.erase(std::remove(outputs.begin(), outputs.end(), operand),
607                                                outputs.end());
608                                  break;
609                              }
610                              default:
611                                  break;
612                          }
613                          operandObj.lifetime = invalidLifeTime;
614                          operandObj.location = kZeroDataLocation;
615                          switch (invalidLifeTime) {
616                              case OperandLifeTime::CONSTANT_COPY: {
617                                  becomeConstantCopy(model, &operandObj);
618                                  break;
619                              }
620                              case OperandLifeTime::SUBGRAPH_INPUT:
621                                  model->main.inputIndexes.push_back(operand);
622                                  break;
623                              case OperandLifeTime::SUBGRAPH_OUTPUT:
624                                  model->main.outputIndexes.push_back(operand);
625                                  break;
626                              default:
627                                  break;
628                          }
629                      });
630         }
631     }
632 }
633 
634 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
635 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)636 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
637                                                              const Operand& operand) {
638     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
639     // - change whether a lifetime means an operand is a model input, a model output, or neither
640     // - preserve whether or not a lifetime means an operand should have a writer
641     switch (operand.lifetime) {
642         case OperandLifeTime::CONSTANT_COPY:
643         case OperandLifeTime::CONSTANT_POOL:
644             return OperandLifeTime::SUBGRAPH_INPUT;
645         case OperandLifeTime::SUBGRAPH_INPUT: {
646             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
647             if (!operandSize ||
648                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
649                 // Unknown size or too-large size
650                 break;
651             }
652             return OperandLifeTime::CONSTANT_COPY;
653         }
654         case OperandLifeTime::SUBGRAPH_OUTPUT:
655             return OperandLifeTime::TEMPORARY_VARIABLE;
656         case OperandLifeTime::TEMPORARY_VARIABLE:
657             return OperandLifeTime::SUBGRAPH_OUTPUT;
658         case OperandLifeTime::NO_VALUE:
659             // Not enough information to know whether
660             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
661             // appropriate choice -- is this operand written (then
662             // TEMPORARY_VARIABLE would be appropriate) or not (then
663             // CONSTANT_COPY would be appropriate)?
664             break;
665         case OperandLifeTime::SUBGRAPH:
666             break;
667         default:
668             ADD_FAILURE();
669             break;
670     }
671 
672     return std::nullopt;
673 }
674 
mutateOperandInputOutputTest(const std::shared_ptr<IDevice> & device,const Model & model)675 static void mutateOperandInputOutputTest(const std::shared_ptr<IDevice>& device,
676                                          const Model& model) {
677     const size_t modelSize = sizeForBinder(model);
678     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
679         const std::optional<OperandLifeTime> changedLifeTime =
680                 getInputOutputLifeTime(model, modelSize, model.main.operands[operand]);
681         if (changedLifeTime) {
682             const std::string message = "mutateOperandInputOutputTest: operand " +
683                                         std::to_string(operand) + " has lifetime " +
684                                         toString(*changedLifeTime) + " instead of lifetime " +
685                                         toString(model.main.operands[operand].lifetime);
686             validate(device, message, model,
687                      [operand, changedLifeTime](Model* model, ExecutionPreference*, Priority*) {
688                          static const DataLocation kZeroDataLocation = {};
689                          Operand& operandObj = model->main.operands[operand];
690                          operandObj.lifetime = *changedLifeTime;
691                          operandObj.location = kZeroDataLocation;
692                          if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
693                              becomeConstantCopy(model, &operandObj);
694                          }
695                      });
696         }
697     }
698 }
699 
700 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
701 
mutateOperandAddWriterTest(const std::shared_ptr<IDevice> & device,const Model & model)702 static void mutateOperandAddWriterTest(const std::shared_ptr<IDevice>& device, const Model& model) {
703     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
704         for (size_t badOutputNum = 0;
705              badOutputNum < model.main.operations[operation].outputs.size(); ++badOutputNum) {
706             const uint32_t outputOperandIndex =
707                     model.main.operations[operation].outputs[badOutputNum];
708             const std::string message = "mutateOperandAddWriterTest: operation " +
709                                         std::to_string(operation) + " writes to " +
710                                         std::to_string(outputOperandIndex);
711             // We'll insert a copy of the operation, all of whose
712             // OTHER output operands are newly-created -- i.e.,
713             // there'll only be a duplicate write of ONE of that
714             // operation's output operands.
715             validate(device, message, model,
716                      [operation, badOutputNum](Model* model, ExecutionPreference*, Priority*) {
717                          Operation newOperation = model->main.operations[operation];
718                          for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
719                               ++outputNum) {
720                              if (outputNum == badOutputNum) continue;
721 
722                              Operand operandValue =
723                                      model->main.operands[newOperation.outputs[outputNum]];
724                              if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
725                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
726                              } else {
727                                  ASSERT_EQ(operandValue.lifetime,
728                                            OperandLifeTime::TEMPORARY_VARIABLE);
729                              }
730                              newOperation.outputs[outputNum] = model->main.operands.size();
731                              model->main.operands.push_back(operandValue);
732                          }
733                          // Where do we insert the extra writer (a new
734                          // operation)?  It has to be later than all the
735                          // writers of its inputs.  The easiest thing to do
736                          // is to insert it at the end of the operation
737                          // sequence.
738                          model->main.operations.push_back(newOperation);
739                      });
740         }
741     }
742 }
743 
744 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
745 
746 // TODO: Operand::location
747 
748 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
749 
mutateOperand(Operand * operand,OperandType type)750 static void mutateOperand(Operand* operand, OperandType type) {
751     Operand newOperand = *operand;
752     newOperand.type = type;
753     switch (type) {
754         case OperandType::FLOAT16:
755         case OperandType::FLOAT32:
756         case OperandType::INT32:
757         case OperandType::UINT32:
758         case OperandType::BOOL:
759             newOperand.dimensions = {};
760             newOperand.scale = 0.0f;
761             newOperand.zeroPoint = 0;
762             break;
763         case OperandType::TENSOR_BOOL8:
764         case OperandType::TENSOR_FLOAT16:
765         case OperandType::TENSOR_FLOAT32:
766             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
767                                                                    : std::vector<int32_t>({1});
768             newOperand.scale = 0.0f;
769             newOperand.zeroPoint = 0;
770             break;
771         case OperandType::TENSOR_INT32:
772             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
773                                                                    : std::vector<int32_t>({1});
774             newOperand.zeroPoint = 0;
775             break;
776         case OperandType::TENSOR_QUANT8_ASYMM:
777         case OperandType::TENSOR_QUANT8_SYMM:
778         case OperandType::TENSOR_QUANT16_ASYMM:
779         case OperandType::TENSOR_QUANT16_SYMM:
780             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
781                                                                    : std::vector<int32_t>({1});
782             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
783             break;
784         case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
785             newOperand.dimensions = operand->dimensions.size() > 0 ? operand->dimensions
786                                                                    : std::vector<int32_t>({1});
787             newOperand.scale = 0.0f;
788             newOperand.zeroPoint = 0;
789 
790             SymmPerChannelQuantParams channelQuant;
791             channelQuant.channelDim = 0;
792             channelQuant.scales = std::vector<float>(
793                     operand->dimensions.size() > 0 ? static_cast<size_t>(operand->dimensions[0])
794                                                    : 0);
795             for (size_t i = 0; i < channelQuant.scales.size(); ++i) {
796                 channelQuant.scales[i] = 1.0f;
797             }
798             newOperand.extraParams->set<OperandExtraParams::Tag::channelQuant>(
799                     std::move(channelQuant));
800         } break;
801         default:
802             break;
803     }
804     *operand = newOperand;
805 }
806 
mutateOperationOperandTypeSkip(size_t operand,OperandType type,const Model & model)807 static bool mutateOperationOperandTypeSkip(size_t operand, OperandType type, const Model& model) {
808     if (type == model.main.operands[operand].type) {
809         return true;
810     }
811     for (const Operation& operation : model.main.operations) {
812         // Skip mutateOperationOperandTypeTest for the following operations.
813         // - LSH_PROJECTION's second argument is allowed to have any type.
814         // - ARGMIN and ARGMAX's first argument can be any of
815         // TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
816         // - CAST's argument can be any of TENSOR_(FLOAT16|FLOAT32|INT32|QUANT8_ASYMM).
817         // - RANDOM_MULTINOMIAL's argument can be either TENSOR_FLOAT16 or TENSOR_FLOAT32.
818         // - DEQUANTIZE input can be any of
819         // TENSOR_(QUANT8_ASYMM|QUANT8_ASYMM_SIGNED|QUANT8_SYMM|QUANT8_SYMM_PER_CHANNEL),
820         // output can be of either TENSOR_FLOAT16 or TENSOR_FLOAT32.
821         // - QUANTIZE input can be either TENSOR_FLOAT16 or TENSOR_FLOAT32
822         // - CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
823         // - DEPTHWISE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
824         // - GROUPED_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
825         // - TRANSPOSE_CONV_2D filter type (arg 1) can be QUANT8_ASYMM or QUANT8_SYMM_PER_CHANNEL
826         // - AXIS_ALIGNED_BBOX_TRANSFORM bounding boxes (arg 1) can be of
827         //     TENSOR_QUANT8_ASYMM or TENSOR_QUANT8_ASYMM_SIGNED.
828         // - RANK's input can have any TENSOR_* type.
829         switch (operation.type) {
830             case OperationType::LSH_PROJECTION: {
831                 if (operand == operation.inputs[1]) {
832                     return true;
833                 }
834             } break;
835             case OperationType::CAST:
836             case OperationType::ARGMAX:
837             case OperationType::ARGMIN: {
838                 if (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
839                     type == OperandType::TENSOR_INT32 || type == OperandType::TENSOR_QUANT8_ASYMM ||
840                     type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
841                     return true;
842                 }
843             } break;
844             case OperationType::QUANTIZE: {
845                 if (operand == operation.inputs[0] &&
846                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
847                     return true;
848                 }
849                 if (operand == operation.outputs[0] &&
850                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
851                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
852                     return true;
853                 }
854             } break;
855             case OperationType::RANDOM_MULTINOMIAL: {
856                 if (operand == operation.inputs[0] &&
857                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
858                     return true;
859                 }
860             } break;
861             case OperationType::DEQUANTIZE: {
862                 if (operand == operation.inputs[0] &&
863                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
864                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
865                      type == OperandType::TENSOR_QUANT8_SYMM ||
866                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
867                     return true;
868                 }
869                 if (operand == operation.outputs[0] &&
870                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32)) {
871                     return true;
872                 }
873             } break;
874             case OperationType::TRANSPOSE_CONV_2D:
875             case OperationType::GROUPED_CONV_2D:
876             case OperationType::DEPTHWISE_CONV_2D:
877             case OperationType::CONV_2D: {
878                 if (operand == operation.inputs[1] &&
879                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
880                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL)) {
881                     return true;
882                 }
883             } break;
884             case OperationType::AXIS_ALIGNED_BBOX_TRANSFORM: {
885                 if (operand == operation.inputs[1] &&
886                     (type == OperandType::TENSOR_QUANT8_ASYMM ||
887                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
888                     return true;
889                 }
890             } break;
891             case OperationType::RANK: {
892                 if (operand == operation.inputs[0] &&
893                     (type == OperandType::TENSOR_FLOAT16 || type == OperandType::TENSOR_FLOAT32 ||
894                      type == OperandType::TENSOR_INT32 ||
895                      type == OperandType::TENSOR_QUANT8_ASYMM ||
896                      type == OperandType::TENSOR_QUANT16_SYMM ||
897                      type == OperandType::TENSOR_BOOL8 ||
898                      type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
899                      type == OperandType::TENSOR_QUANT16_ASYMM ||
900                      type == OperandType::TENSOR_QUANT8_SYMM ||
901                      type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)) {
902                     return true;
903                 }
904             } break;
905             default:
906                 break;
907         }
908     }
909     return false;
910 }
911 
mutateOperationOperandTypeTest(const std::shared_ptr<IDevice> & device,const Model & model)912 static void mutateOperationOperandTypeTest(const std::shared_ptr<IDevice>& device,
913                                            const Model& model) {
914     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
915         for (OperandType invalidOperandType : ndk::enum_range<OperandType>()) {
916             if (mutateOperationOperandTypeSkip(operand, invalidOperandType, model)) {
917                 continue;
918             }
919             const std::string message = "mutateOperationOperandTypeTest: operand " +
920                                         std::to_string(operand) + " set to type " +
921                                         toString(invalidOperandType);
922             validate(device, message, model,
923                      [operand, invalidOperandType](Model* model, ExecutionPreference*, Priority*) {
924                          mutateOperand(&model->main.operands[operand], invalidOperandType);
925                      });
926         }
927     }
928 }
929 
930 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
931 
932 static const int32_t invalidOperationTypes[] = {
933         -1,
934         static_cast<int32_t>(*(ndk::enum_range<OperationType>().end() - 1)) + 1,
935 };
936 
mutateOperationTypeTest(const std::shared_ptr<IDevice> & device,const Model & model)937 static void mutateOperationTypeTest(const std::shared_ptr<IDevice>& device, const Model& model) {
938     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
939         for (int32_t invalidOperationType : invalidOperationTypes) {
940             const std::string message = "mutateOperationTypeTest: operation " +
941                                         std::to_string(operation) + " set to value " +
942                                         std::to_string(invalidOperationType);
943             validate(device, message, model,
944                      [operation, invalidOperationType](Model* model, ExecutionPreference*,
945                                                        Priority*) {
946                          model->main.operations[operation].type =
947                                  static_cast<OperationType>(invalidOperationType);
948                      });
949         }
950     }
951 }
952 
953 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
954 
mutateOperationInputOperandIndexTest(const std::shared_ptr<IDevice> & device,const Model & model)955 static void mutateOperationInputOperandIndexTest(const std::shared_ptr<IDevice>& device,
956                                                  const Model& model) {
957     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
958         const uint32_t invalidOperand = model.main.operands.size();
959         for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
960             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
961                                         std::to_string(operation) + " input " +
962                                         std::to_string(input);
963             validate(device, message, model,
964                      [operation, input, invalidOperand](Model* model, ExecutionPreference*,
965                                                         Priority*) {
966                          model->main.operations[operation].inputs[input] = invalidOperand;
967                      });
968         }
969     }
970 }
971 
972 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
973 
mutateOperationOutputOperandIndexTest(const std::shared_ptr<IDevice> & device,const Model & model)974 static void mutateOperationOutputOperandIndexTest(const std::shared_ptr<IDevice>& device,
975                                                   const Model& model) {
976     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
977         const uint32_t invalidOperand = model.main.operands.size();
978         for (size_t output = 0; output < model.main.operations[operation].outputs.size();
979              ++output) {
980             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
981                                         std::to_string(operation) + " output " +
982                                         std::to_string(output);
983             validate(device, message, model,
984                      [operation, output, invalidOperand](Model* model, ExecutionPreference*,
985                                                          Priority*) {
986                          model->main.operations[operation].outputs[output] = invalidOperand;
987                      });
988         }
989     }
990 }
991 
992 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
993 
mutateOperationRemoveWriteTest(const std::shared_ptr<IDevice> & device,const Model & model,const std::vector<uint32_t> & numberOfConsumers)994 static void mutateOperationRemoveWriteTest(const std::shared_ptr<IDevice>& device,
995                                            const Model& model,
996                                            const std::vector<uint32_t>& numberOfConsumers) {
997     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
998         for (size_t outputNum = 0; outputNum < model.main.operations[operation].outputs.size();
999              ++outputNum) {
1000             const uint32_t outputOperandIndex = model.main.operations[operation].outputs[outputNum];
1001             if (numberOfConsumers[outputOperandIndex] > 0) {
1002                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
1003                                             std::to_string(operation) + " writes to " +
1004                                             std::to_string(outputOperandIndex);
1005                 validate(device, message, model,
1006                          [operation, outputNum](Model* model, ExecutionPreference*, Priority*) {
1007                              int32_t& outputOperandIndex =
1008                                      model->main.operations[operation].outputs[outputNum];
1009                              Operand operandValue = model->main.operands[outputOperandIndex];
1010                              if (operandValue.lifetime == OperandLifeTime::SUBGRAPH_OUTPUT) {
1011                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
1012                              } else {
1013                                  ASSERT_EQ(operandValue.lifetime,
1014                                            OperandLifeTime::TEMPORARY_VARIABLE);
1015                              }
1016                              outputOperandIndex = model->main.operands.size();
1017                              model->main.operands.push_back(operandValue);
1018                          });
1019             }
1020         }
1021     }
1022 }
1023 
1024 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
1025 
// Removes every element equal to `value` from *vec. It then decrements every remaining element
// greater than `value`, so a list of indexes stays consistent after entry `value` is erased
// from the container it indexes. Does nothing if `vec` is null.
static void removeValueAndDecrementGreaterValues(std::vector<int32_t>* vec, uint32_t value) {
    if (vec == nullptr) {
        return;
    }
    // Compare in the element type to avoid mixing signed and unsigned values.
    const int32_t target = static_cast<int32_t>(value);

    // Remove elements matching `value`.
    vec->erase(std::remove(vec->begin(), vec->end(), target), vec->end());

    // Decrement elements exceeding `value`. This must be `v - 1`: a post-decrement (`v--`)
    // evaluates to the old value, which would leave every element unchanged.
    std::transform(vec->begin(), vec->end(), vec->begin(),
                   [target](int32_t v) { return v > target ? v - 1 : v; });
}
1036 
removeOperand(Model * model,uint32_t index)1037 static void removeOperand(Model* model, uint32_t index) {
1038     model->main.operands.erase(model->main.operands.begin() + index);
1039     for (Operation& operation : model->main.operations) {
1040         removeValueAndDecrementGreaterValues(&operation.inputs, index);
1041         removeValueAndDecrementGreaterValues(&operation.outputs, index);
1042     }
1043     removeValueAndDecrementGreaterValues(&model->main.inputIndexes, index);
1044     removeValueAndDecrementGreaterValues(&model->main.outputIndexes, index);
1045 }
1046 
removeOperandSkip(size_t operandIndex,const Model & model,const std::vector<uint32_t> & numberOfConsumers)1047 static bool removeOperandSkip(size_t operandIndex, const Model& model,
1048                               const std::vector<uint32_t>& numberOfConsumers) {
1049     if (numberOfConsumers[operandIndex] == 0) {
1050         // Removing an unused operand has no effect.
1051         return true;
1052     }
1053     for (const Operation& operation : model.main.operations) {
1054         // Skip removeOperandTest for the following operations.
1055         // - SPLIT's outputs are not checked during prepareModel.
1056         if (operation.type == OperationType::SPLIT) {
1057             for (const size_t index : operation.outputs) {
1058                 if (index == operandIndex) {
1059                     return true;
1060                 }
1061             }
1062         }
1063         // BIDIRECTIONAL_SEQUENCE_LSTM and BIDIRECTIONAL_SEQUENCE_RNN can have
1064         // either one, two, three or four outputs depending on their
1065         // mergeOutputs parameter and if state outputs are provided.
1066         // UNIDIRECTIONAL_SEQUENCE_LSTM and UNIDIRECTIONAL_SEQUENCE_RNN can have
1067         // either one or three outputs depending on whether state outputs are
1068         // provided.
1069         if (operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_LSTM ||
1070             operation.type == OperationType::UNIDIRECTIONAL_SEQUENCE_RNN ||
1071             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_LSTM ||
1072             operation.type == OperationType::BIDIRECTIONAL_SEQUENCE_RNN) {
1073             for (const size_t index : operation.outputs) {
1074                 if (index == operandIndex) {
1075                     return true;
1076                 }
1077             }
1078         }
1079     }
1080     return false;
1081 }
1082 
removeOperandTest(const std::shared_ptr<IDevice> & device,const Model & model,const std::vector<uint32_t> & numberOfConsumers)1083 static void removeOperandTest(const std::shared_ptr<IDevice>& device, const Model& model,
1084                               const std::vector<uint32_t>& numberOfConsumers) {
1085     for (size_t operand = 0; operand < model.main.operands.size(); ++operand) {
1086         if (removeOperandSkip(operand, model, numberOfConsumers)) {
1087             continue;
1088         }
1089         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
1090         validate(device, message, model, [operand](Model* model, ExecutionPreference*, Priority*) {
1091             removeOperand(model, operand);
1092         });
1093     }
1094 }
1095 
1096 ///////////////////////// REMOVE OPERATION /////////////////////////
1097 
removeOperation(Model * model,uint32_t index)1098 static void removeOperation(Model* model, uint32_t index) {
1099     auto& operations = model->main.operations;
1100     operations.erase(operations.begin() + index);
1101 }
1102 
removeOperationTest(const std::shared_ptr<IDevice> & device,const Model & model)1103 static void removeOperationTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1104     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1105         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
1106         validate(device, message, model,
1107                  [operation](Model* model, ExecutionPreference*, Priority*) {
1108                      removeOperation(model, operation);
1109                  });
1110     }
1111 }
1112 
1113 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
1114 
removeOperationInputSkip(const Operation & op,size_t input)1115 static bool removeOperationInputSkip(const Operation& op, size_t input) {
1116     // Skip removeOperationInputTest for the following operations.
1117     // - CONCATENATION has at least 2 inputs, with the last element being INT32.
1118     // - CONV_2D, DEPTHWISE_CONV_2D, MAX_POOL_2D, AVERAGE_POOL_2D, L2_POOL_2D, RESIZE_BILINEAR,
1119     //   SPACE_TO_DEPTH, SPACE_TO_DEPTH, SPACE_TO_BATCH_ND, BATCH_TO_SPACE_ND can have an optional
1120     //   layout parameter.
1121     //   RESIZE_BILINEAR and RESIZE_NEAREST_NEIGHBOR can have optional
1122     //   align_corners and half_pixel_centers parameters.
1123     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional axis
1124     //   parameter.
1125     switch (op.type) {
1126         case OperationType::CONCATENATION: {
1127             if (op.inputs.size() > 2 && input != op.inputs.size() - 1) {
1128                 return true;
1129             }
1130         } break;
1131         case OperationType::DEPTHWISE_CONV_2D: {
1132             if ((op.inputs.size() == 12 && input == 11) || (op.inputs.size() == 9 && input == 8)) {
1133                 return true;
1134             }
1135         } break;
1136         case OperationType::CONV_2D:
1137         case OperationType::AVERAGE_POOL_2D:
1138         case OperationType::MAX_POOL_2D:
1139         case OperationType::L2_POOL_2D: {
1140             if ((op.inputs.size() == 11 && input == 10) || (op.inputs.size() == 8 && input == 7)) {
1141                 return true;
1142             }
1143         } break;
1144         case OperationType::RESIZE_BILINEAR: {
1145             if (op.inputs.size() >= 4 && input >= 3) {
1146                 return true;
1147             }
1148         } break;
1149         case OperationType::RESIZE_NEAREST_NEIGHBOR: {
1150             if (op.inputs.size() >= 5 && input >= 3) {
1151                 return true;
1152             }
1153         } break;
1154         case OperationType::SPACE_TO_DEPTH:
1155         case OperationType::DEPTH_TO_SPACE:
1156         case OperationType::BATCH_TO_SPACE_ND: {
1157             if (op.inputs.size() == 3 && input == 2) {
1158                 return true;
1159             }
1160         } break;
1161         case OperationType::SPACE_TO_BATCH_ND: {
1162             if (op.inputs.size() == 4 && input == 3) {
1163                 return true;
1164             }
1165         } break;
1166         case OperationType::L2_NORMALIZATION: {
1167             if (op.inputs.size() == 2 && input == 1) {
1168                 return true;
1169             }
1170         } break;
1171         case OperationType::LOCAL_RESPONSE_NORMALIZATION: {
1172             if (op.inputs.size() == 6 && input == 5) {
1173                 return true;
1174             }
1175         } break;
1176         case OperationType::SOFTMAX: {
1177             if (op.inputs.size() == 3 && input == 2) {
1178                 return true;
1179             }
1180         } break;
1181         default:
1182             break;
1183     }
1184     return false;
1185 }
1186 
removeOperationInputTest(const std::shared_ptr<IDevice> & device,const Model & model)1187 static void removeOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1188     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1189         for (size_t input = 0; input < model.main.operations[operation].inputs.size(); ++input) {
1190             const Operation& op = model.main.operations[operation];
1191             if (removeOperationInputSkip(op, input)) {
1192                 continue;
1193             }
1194             const std::string message = "removeOperationInputTest: operation " +
1195                                         std::to_string(operation) + ", input " +
1196                                         std::to_string(input);
1197             validate(device, message, model,
1198                      [operation, input](Model* model, ExecutionPreference*, Priority*) {
1199                          auto& inputs = model->main.operations[operation].inputs;
1200                          inputs.erase(inputs.begin() + input);
1201                      });
1202         }
1203     }
1204 }
1205 
1206 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
1207 
removeOperationOutputTest(const std::shared_ptr<IDevice> & device,const Model & model)1208 static void removeOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1209     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1210         for (size_t output = 0; output < model.main.operations[operation].outputs.size();
1211              ++output) {
1212             const std::string message = "removeOperationOutputTest: operation " +
1213                                         std::to_string(operation) + ", output " +
1214                                         std::to_string(output);
1215             validate(device, message, model,
1216                      [operation, output](Model* model, ExecutionPreference*, Priority*) {
1217                          auto& outputs = model->main.operations[operation].outputs;
1218                          outputs.erase(outputs.begin() + output);
1219                      });
1220         }
1221     }
1222 }
1223 
1224 ///////////////////////// MODEL VALIDATION /////////////////////////
1225 
1226 // TODO: remove model input
1227 // TODO: remove model output
1228 // TODO: add unused operation
1229 
1230 ///////////////////////// ADD OPERATION INPUT /////////////////////////
1231 
addOperationInputSkip(const Operation & op)1232 static bool addOperationInputSkip(const Operation& op) {
1233     // Skip addOperationInputTest for the following operations.
1234     // - L2_NORMALIZATION, LOCAL_RESPONSE_NORMALIZATION, SOFTMAX can have an optional INT32 axis
1235     //   parameter.
1236     if ((op.type == OperationType::L2_NORMALIZATION && op.inputs.size() == 1) ||
1237         (op.type == OperationType::LOCAL_RESPONSE_NORMALIZATION && op.inputs.size() == 5) ||
1238         (op.type == OperationType::SOFTMAX && op.inputs.size() == 2) ||
1239         (op.type == OperationType::RESIZE_BILINEAR && op.inputs.size() < 6) ||
1240         (op.type == OperationType::RESIZE_NEAREST_NEIGHBOR && op.inputs.size() < 6)) {
1241         return true;
1242     }
1243     return false;
1244 }
1245 
addOperationInputTest(const std::shared_ptr<IDevice> & device,const Model & model)1246 static void addOperationInputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1247     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1248         if (addOperationInputSkip(model.main.operations[operation])) {
1249             continue;
1250         }
1251         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
1252         validate(device, message, model,
1253                  [operation](Model* model, ExecutionPreference*, Priority*) {
1254                      uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_INPUT);
1255                      model->main.operations[operation].inputs.push_back(index);
1256                      model->main.inputIndexes.push_back(index);
1257                  });
1258     }
1259 }
1260 
1261 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
1262 
addOperationOutputTest(const std::shared_ptr<IDevice> & device,const Model & model)1263 static void addOperationOutputTest(const std::shared_ptr<IDevice>& device, const Model& model) {
1264     for (size_t operation = 0; operation < model.main.operations.size(); ++operation) {
1265         const std::string message =
1266                 "addOperationOutputTest: operation " + std::to_string(operation);
1267         validate(device, message, model,
1268                  [operation](Model* model, ExecutionPreference*, Priority*) {
1269                      uint32_t index = addOperand(model, OperandLifeTime::SUBGRAPH_OUTPUT);
1270                      model->main.operations[operation].outputs.push_back(index);
1271                      model->main.outputIndexes.push_back(index);
1272                  });
1273     }
1274 }
1275 
1276 ///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
1277 
1278 static const int32_t invalidExecutionPreferences[] = {
1279         static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
1280         static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
1281 };
1282 
mutateExecutionPreferenceTest(const std::shared_ptr<IDevice> & device,const Model & model)1283 static void mutateExecutionPreferenceTest(const std::shared_ptr<IDevice>& device,
1284                                           const Model& model) {
1285     for (int32_t invalidPreference : invalidExecutionPreferences) {
1286         const std::string message =
1287                 "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
1288         validate(device, message, model,
1289                  [invalidPreference](Model*, ExecutionPreference* preference, Priority*) {
1290                      *preference = static_cast<ExecutionPreference>(invalidPreference);
1291                  });
1292     }
1293 }
1294 
1295 ///////////////////////// VALIDATE PRIORITY /////////////////////////
1296 
1297 static const int32_t invalidPriorities[] = {
1298         static_cast<int32_t>(Priority::LOW) - 1,   // lower bound
1299         static_cast<int32_t>(Priority::HIGH) + 1,  // upper bound
1300 };
1301 
mutateExecutionPriorityTest(const std::shared_ptr<IDevice> & device,const Model & model)1302 static void mutateExecutionPriorityTest(const std::shared_ptr<IDevice>& device,
1303                                         const Model& model) {
1304     for (int32_t invalidPriority : invalidPriorities) {
1305         const std::string message =
1306                 "mutatePriorityTest: priority " + std::to_string(invalidPriority);
1307         validate(device, message, model,
1308                  [invalidPriority](Model*, ExecutionPreference*, Priority* priority) {
1309                      *priority = static_cast<Priority>(invalidPriority);
1310                  });
1311     }
1312 }
1313 
1314 ////////////////////////// ENTRY POINT //////////////////////////////
1315 
validateModel(const std::shared_ptr<IDevice> & device,const Model & model)1316 void validateModel(const std::shared_ptr<IDevice>& device, const Model& model) {
1317     const auto numberOfConsumers =
1318             nn::countNumberOfConsumers(model.main.operands.size(),
1319                                        nn::unvalidatedConvert(model.main.operations).value())
1320                     .value();
1321     mutateExecutionOrderTest(device, model, numberOfConsumers);
1322     mutateOperandTypeTest(device, model);
1323     mutateOperandRankTest(device, model);
1324     mutateOperandScaleTest(device, model);
1325     mutateOperandZeroPointTest(device, model);
1326     mutateOperandLifeTimeTest(device, model);
1327     mutateOperandInputOutputTest(device, model);
1328     mutateOperandAddWriterTest(device, model);
1329     mutateOperationOperandTypeTest(device, model);
1330     mutateOperationTypeTest(device, model);
1331     mutateOperationInputOperandIndexTest(device, model);
1332     mutateOperationOutputOperandIndexTest(device, model);
1333     mutateOperationRemoveWriteTest(device, model, numberOfConsumers);
1334     removeOperandTest(device, model, numberOfConsumers);
1335     removeOperationTest(device, model);
1336     removeOperationInputTest(device, model);
1337     removeOperationOutputTest(device, model);
1338     addOperationInputTest(device, model);
1339     addOperationOutputTest(device, model);
1340     mutateExecutionPreferenceTest(device, model);
1341     mutateExecutionPriorityTest(device, model);
1342 }
1343 
1344 }  // namespace aidl::android::hardware::neuralnetworks::vts::functional
1345