1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
#include <android/hardware/neuralnetworks/1.1/types.h>
#include "1.0/Callbacks.h"
#include "1.0/Utils.h"
#include "GeneratedTestHarness.h"
#include "VtsHalNeuralnetworks.h"

#include <algorithm>
#include <functional>
#include <numeric>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
28 
29 namespace android::hardware::neuralnetworks::V1_1::vts::functional {
30 
31 using V1_0::DataLocation;
32 using V1_0::ErrorStatus;
33 using V1_0::IPreparedModel;
34 using V1_0::Operand;
35 using V1_0::OperandLifeTime;
36 using V1_0::OperandType;
37 using V1_0::implementation::PreparedModelCallback;
38 
39 using PrepareModelMutation = std::function<void(Model*, ExecutionPreference*)>;
40 
41 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
42 
validateGetSupportedOperations(const sp<IDevice> & device,const std::string & message,const Model & model)43 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
44                                            const Model& model) {
45     SCOPED_TRACE(message + " [getSupportedOperations_1_1]");
46 
47     Return<void> ret = device->getSupportedOperations_1_1(
48             model, [&](ErrorStatus status, const hidl_vec<bool>&) {
49                 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
50             });
51     EXPECT_TRUE(ret.isOk());
52 }
53 
validatePrepareModel(const sp<IDevice> & device,const std::string & message,const Model & model,ExecutionPreference preference)54 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
55                                  const Model& model, ExecutionPreference preference) {
56     SCOPED_TRACE(message + " [prepareModel_1_1]");
57 
58     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
59     Return<ErrorStatus> prepareLaunchStatus =
60             device->prepareModel_1_1(model, preference, preparedModelCallback);
61     ASSERT_TRUE(prepareLaunchStatus.isOk());
62     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
63 
64     preparedModelCallback->wait();
65     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
66     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
67     sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
68     ASSERT_EQ(nullptr, preparedModel.get());
69 }
70 
validExecutionPreference(ExecutionPreference preference)71 static bool validExecutionPreference(ExecutionPreference preference) {
72     return preference == ExecutionPreference::LOW_POWER ||
73            preference == ExecutionPreference::FAST_SINGLE_ANSWER ||
74            preference == ExecutionPreference::SUSTAINED_SPEED;
75 }
76 
77 // Primary validation function. This function will take a valid model, apply a
78 // mutation to invalidate either the model or the execution preference, then
79 // pass these to supportedOperations and/or prepareModel if that method is
80 // called with an invalid argument.
validate(const sp<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)81 static void validate(const sp<IDevice>& device, const std::string& message,
82                      const Model& originalModel, const PrepareModelMutation& mutate) {
83     Model model = originalModel;
84     ExecutionPreference preference = ExecutionPreference::FAST_SINGLE_ANSWER;
85     mutate(&model, &preference);
86 
87     if (validExecutionPreference(preference)) {
88         validateGetSupportedOperations(device, message, model);
89     }
90 
91     validatePrepareModel(device, message, model, preference);
92 }
93 
addOperand(Model * model)94 static uint32_t addOperand(Model* model) {
95     return hidl_vec_push_back(&model->operands,
96                               {
97                                       .type = OperandType::INT32,
98                                       .dimensions = {},
99                                       .numberOfConsumers = 0,
100                                       .scale = 0.0f,
101                                       .zeroPoint = 0,
102                                       .lifetime = OperandLifeTime::MODEL_INPUT,
103                                       .location = {.poolIndex = 0, .offset = 0, .length = 0},
104                               });
105 }
106 
addOperand(Model * model,OperandLifeTime lifetime)107 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
108     uint32_t index = addOperand(model);
109     model->operands[index].numberOfConsumers = 1;
110     model->operands[index].lifetime = lifetime;
111     return index;
112 }
113 
114 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
115 // how much will this increase the size of the model?  This assumes
116 // that we can (re)use all of model.operandValues for the operand
117 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)118 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
119     const size_t operandValuesSize = model.operandValues.size();
120     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
121 }
122 
123 // Highly specialized utility routine for converting an operand to
124 // CONSTANT_COPY lifetime.
125 //
126 // Expects that:
127 // - operand has a known size
128 // - operand->lifetime has already been set to CONSTANT_COPY
129 // - operand->location has been zeroed out
130 //
131 // Does the following:
132 // - initializes operand->location to point to the beginning of model->operandValues
133 // - resizes model->operandValues (if necessary) to be large enough for the operand
134 //   value, padding it with zeroes on the end
135 //
136 // Potential problem:
137 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
138 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
139 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
140 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
141 // value). For now, this should be fine because it just means we're not testing what we think we're
142 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
143 // exercise the validation code over the span of the entire test set and operand space.
144 //
145 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)146 static void becomeConstantCopy(Model* model, Operand* operand) {
147     // sizeOfData will abort if the specified type is an extension type or OEM type.
148     const size_t sizeOfOperand = sizeOfData(*operand);
149     EXPECT_NE(sizeOfOperand, size_t(0));
150     operand->location.poolIndex = 0;
151     operand->location.offset = 0;
152     operand->location.length = sizeOfOperand;
153     if (model->operandValues.size() < sizeOfOperand) {
154         model->operandValues.resize(sizeOfOperand);
155     }
156 }
157 
158 // The sizeForBinder() functions estimate the size of the
159 // representation of a value when sent to binder.  It's probably a bit
160 // of an under-estimate, because we don't know the size of the
161 // metadata in the binder format (e.g., representation of the size of
162 // a vector); but at least it adds up "big" things like vector
163 // contents.  However, it doesn't treat inter-field or end-of-struct
164 // padding in a methodical way -- there's no attempt to be consistent
165 // in whether or not padding in the native (C++) representation
166 // contributes to the estimated size for the binder representation;
167 // and there's no attempt to understand what padding (if any) is
168 // needed in the binder representation.
169 //
170 // This assumes that non-metadata uses a fixed length encoding (e.g.,
171 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
172 // using an encoding whose length is related to the magnitude of the
173 // encoded value).
174 
// Base case: a trivially copyable value is assumed to be sent as its raw
// bytes, so its estimated binder size is simply its in-memory size.
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    using ValueType = std::remove_reference_t<Type>;
    static_assert(std::is_trivially_copyable<ValueType>::value,
                  "expected a trivially copyable type");
    return sizeof val;
}
181 
182 template <typename Type>
sizeForBinder(const hidl_vec<Type> & vec)183 static size_t sizeForBinder(const hidl_vec<Type>& vec) {
184     return std::accumulate(vec.begin(), vec.end(), 0,
185                            [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
186 }
187 
188 template <>
sizeForBinder(const Operand & operand)189 size_t sizeForBinder(const Operand& operand) {
190     size_t size = 0;
191 
192     size += sizeForBinder(operand.type);
193     size += sizeForBinder(operand.dimensions);
194     size += sizeForBinder(operand.numberOfConsumers);
195     size += sizeForBinder(operand.scale);
196     size += sizeForBinder(operand.zeroPoint);
197     size += sizeForBinder(operand.lifetime);
198     size += sizeForBinder(operand.location);
199 
200     return size;
201 }
202 
203 template <>
sizeForBinder(const Operation & operation)204 size_t sizeForBinder(const Operation& operation) {
205     size_t size = 0;
206 
207     size += sizeForBinder(operation.type);
208     size += sizeForBinder(operation.inputs);
209     size += sizeForBinder(operation.outputs);
210 
211     return size;
212 }
213 
214 template <>
sizeForBinder(const hidl_string & name)215 size_t sizeForBinder(const hidl_string& name) {
216     return name.size();
217 }
218 
219 template <>
sizeForBinder(const hidl_memory & memory)220 size_t sizeForBinder(const hidl_memory& memory) {
221     // This is just a guess.
222 
223     size_t size = 0;
224 
225     if (const native_handle_t* handle = memory.handle()) {
226         size += sizeof(*handle);
227         size += sizeof(handle->data[0] * (handle->numFds + handle->numInts));
228     }
229     size += sizeForBinder(memory.name());
230 
231     return size;
232 }
233 
234 template <>
sizeForBinder(const Model & model)235 size_t sizeForBinder(const Model& model) {
236     size_t size = 0;
237 
238     size += sizeForBinder(model.operands);
239     size += sizeForBinder(model.operations);
240     size += sizeForBinder(model.inputIndexes);
241     size += sizeForBinder(model.outputIndexes);
242     size += sizeForBinder(model.operandValues);
243     size += sizeForBinder(model.pools);
244     size += sizeForBinder(model.relaxComputationFloat32toFloat16);
245 
246     return size;
247 }
248 
// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Would a representation of |representationSize| bytes fit under that limit?
// Two things make this uncertain:
// - the size is only an estimate (see sizeForBinder()), and
// - other transactions may be using the buffer at the same time.
// To stay conservative, anything larger than half the buffer counts as "too big".
//
// A representation that fits the buffer on its own but not alongside other
// in-flight transactions could cause intermittent HAL transport errors.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // ProcessState::self()->getMmapSize() could replace this constant, but it
    // describes the current process's binder/mmap size, not necessarily the
    // service's. The two only agree when both sides use the default size.
    constexpr size_t kBinderBufferSize = 1024 * 1024;
    constexpr size_t kHalfBufferSize = kBinderBufferSize / 2;

    return representationSize > kHalfBufferSize;
}
276 
277 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
278 
mutateExecutionOrderTest(const sp<IDevice> & device,const V1_1::Model & model)279 static void mutateExecutionOrderTest(const sp<IDevice>& device, const V1_1::Model& model) {
280     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
281         const Operation& operationObj = model.operations[operation];
282         for (uint32_t input : operationObj.inputs) {
283             if (model.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
284                 model.operands[input].lifetime == OperandLifeTime::MODEL_OUTPUT) {
285                 // This operation reads an operand written by some
286                 // other operation.  Move this operation to the
287                 // beginning of the sequence, ensuring that it reads
288                 // the operand before that operand is written, thereby
289                 // violating execution order rules.
290                 const std::string message = "mutateExecutionOrderTest: operation " +
291                                             std::to_string(operation) + " is a reader";
292                 validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
293                     auto& operations = model->operations;
294                     std::rotate(operations.begin(), operations.begin() + operation,
295                                 operations.begin() + operation + 1);
296                 });
297                 break;  // only need to do this once per operation
298             }
299         }
300         for (uint32_t output : operationObj.outputs) {
301             if (model.operands[output].numberOfConsumers > 0) {
302                 // This operation writes an operand read by some other
303                 // operation.  Move this operation to the end of the
304                 // sequence, ensuring that it writes the operand after
305                 // that operand is read, thereby violating execution
306                 // order rules.
307                 const std::string message = "mutateExecutionOrderTest: operation " +
308                                             std::to_string(operation) + " is a writer";
309                 validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
310                     auto& operations = model->operations;
311                     std::rotate(operations.begin() + operation, operations.begin() + operation + 1,
312                                 operations.end());
313                 });
314                 break;  // only need to do this once per operation
315             }
316         }
317     }
318 }
319 
320 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
321 
322 static const int32_t invalidOperandTypes[] = {
323         static_cast<int32_t>(OperandType::FLOAT32) - 1,              // lower bound fundamental
324         static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) + 1,  // upper bound fundamental
325         static_cast<int32_t>(OperandType::OEM) - 1,                  // lower bound OEM
326         static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) + 1,      // upper bound OEM
327 };
328 
mutateOperandTypeTest(const sp<IDevice> & device,const Model & model)329 static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model) {
330     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
331         for (int32_t invalidOperandType : invalidOperandTypes) {
332             const std::string message = "mutateOperandTypeTest: operand " +
333                                         std::to_string(operand) + " set to value " +
334                                         std::to_string(invalidOperandType);
335             validate(device, message, model,
336                      [operand, invalidOperandType](Model* model, ExecutionPreference*) {
337                          model->operands[operand].type =
338                                  static_cast<OperandType>(invalidOperandType);
339                      });
340         }
341     }
342 }
343 
344 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
345 
getInvalidRank(OperandType type)346 static uint32_t getInvalidRank(OperandType type) {
347     switch (type) {
348         case OperandType::FLOAT32:
349         case OperandType::INT32:
350         case OperandType::UINT32:
351             return 1;
352         case OperandType::TENSOR_FLOAT32:
353         case OperandType::TENSOR_INT32:
354         case OperandType::TENSOR_QUANT8_ASYMM:
355             return 0;
356         default:
357             return 0;
358     }
359 }
360 
mutateOperandRankTest(const sp<IDevice> & device,const Model & model)361 static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model) {
362     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
363         const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
364         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
365                                     " has rank of " + std::to_string(invalidRank);
366         validate(device, message, model,
367                  [operand, invalidRank](Model* model, ExecutionPreference*) {
368                      model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
369                  });
370     }
371 }
372 
373 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
374 
getInvalidScale(OperandType type)375 static float getInvalidScale(OperandType type) {
376     switch (type) {
377         case OperandType::FLOAT32:
378         case OperandType::INT32:
379         case OperandType::UINT32:
380         case OperandType::TENSOR_FLOAT32:
381             return 1.0f;
382         case OperandType::TENSOR_INT32:
383             return -1.0f;
384         case OperandType::TENSOR_QUANT8_ASYMM:
385             return 0.0f;
386         default:
387             return 0.0f;
388     }
389 }
390 
mutateOperandScaleTest(const sp<IDevice> & device,const Model & model)391 static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model) {
392     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
393         const float invalidScale = getInvalidScale(model.operands[operand].type);
394         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
395                                     " has scale of " + std::to_string(invalidScale);
396         validate(device, message, model,
397                  [operand, invalidScale](Model* model, ExecutionPreference*) {
398                      model->operands[operand].scale = invalidScale;
399                  });
400     }
401 }
402 
403 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
404 
getInvalidZeroPoints(OperandType type)405 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
406     switch (type) {
407         case OperandType::FLOAT32:
408         case OperandType::INT32:
409         case OperandType::UINT32:
410         case OperandType::TENSOR_FLOAT32:
411         case OperandType::TENSOR_INT32:
412             return {1};
413         case OperandType::TENSOR_QUANT8_ASYMM:
414             return {-1, 256};
415         default:
416             return {};
417     }
418 }
419 
mutateOperandZeroPointTest(const sp<IDevice> & device,const Model & model)420 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& model) {
421     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
422         const std::vector<int32_t> invalidZeroPoints =
423                 getInvalidZeroPoints(model.operands[operand].type);
424         for (int32_t invalidZeroPoint : invalidZeroPoints) {
425             const std::string message = "mutateOperandZeroPointTest: operand " +
426                                         std::to_string(operand) + " has zero point of " +
427                                         std::to_string(invalidZeroPoint);
428             validate(device, message, model,
429                      [operand, invalidZeroPoint](Model* model, ExecutionPreference*) {
430                          model->operands[operand].zeroPoint = invalidZeroPoint;
431                      });
432         }
433     }
434 }
435 
436 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
437 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)438 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
439                                                         const Operand& operand) {
440     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
441     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
442 
443     // Ways to get an invalid lifetime:
444     // - change whether a lifetime means an operand should have a writer
445     std::vector<OperandLifeTime> ret;
446     switch (operand.lifetime) {
447         case OperandLifeTime::MODEL_OUTPUT:
448         case OperandLifeTime::TEMPORARY_VARIABLE:
449             ret = {
450                     OperandLifeTime::MODEL_INPUT,
451                     OperandLifeTime::CONSTANT_COPY,
452             };
453             break;
454         case OperandLifeTime::CONSTANT_COPY:
455         case OperandLifeTime::CONSTANT_REFERENCE:
456         case OperandLifeTime::MODEL_INPUT:
457             ret = {
458                     OperandLifeTime::TEMPORARY_VARIABLE,
459                     OperandLifeTime::MODEL_OUTPUT,
460             };
461             break;
462         case OperandLifeTime::NO_VALUE:
463             // Not enough information to know whether
464             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
465             // is this operand written (then CONSTANT_COPY would be
466             // invalid) or not (then TEMPORARY_VARIABLE would be
467             // invalid)?
468             break;
469         default:
470             ADD_FAILURE();
471             break;
472     }
473 
474     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
475     if (!operandSize ||
476         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
477         // Unknown size or too-large size
478         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
479     }
480 
481     return ret;
482 }
483 
mutateOperandLifeTimeTest(const sp<IDevice> & device,const V1_1::Model & model)484 static void mutateOperandLifeTimeTest(const sp<IDevice>& device, const V1_1::Model& model) {
485     const size_t modelSize = sizeForBinder(model);
486     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
487         const std::vector<OperandLifeTime> invalidLifeTimes =
488                 getInvalidLifeTimes(model, modelSize, model.operands[operand]);
489         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
490             const std::string message = "mutateOperandLifetimeTest: operand " +
491                                         std::to_string(operand) + " has lifetime " +
492                                         toString(invalidLifeTime) + " instead of lifetime " +
493                                         toString(model.operands[operand].lifetime);
494             validate(device, message, model,
495                      [operand, invalidLifeTime](Model* model, ExecutionPreference*) {
496                          static const DataLocation kZeroDataLocation = {};
497                          Operand& operandObj = model->operands[operand];
498                          switch (operandObj.lifetime) {
499                              case OperandLifeTime::MODEL_INPUT: {
500                                  hidl_vec_remove(&model->inputIndexes, uint32_t(operand));
501                                  break;
502                              }
503                              case OperandLifeTime::MODEL_OUTPUT: {
504                                  hidl_vec_remove(&model->outputIndexes, uint32_t(operand));
505                                  break;
506                              }
507                              default:
508                                  break;
509                          }
510                          operandObj.lifetime = invalidLifeTime;
511                          operandObj.location = kZeroDataLocation;
512                          switch (invalidLifeTime) {
513                              case OperandLifeTime::CONSTANT_COPY: {
514                                  becomeConstantCopy(model, &operandObj);
515                                  break;
516                              }
517                              case OperandLifeTime::MODEL_INPUT:
518                                  hidl_vec_push_back(&model->inputIndexes, uint32_t(operand));
519                                  break;
520                              case OperandLifeTime::MODEL_OUTPUT:
521                                  hidl_vec_push_back(&model->outputIndexes, uint32_t(operand));
522                                  break;
523                              default:
524                                  break;
525                          }
526                      });
527         }
528     }
529 }
530 
531 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
532 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)533 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
534                                                              const Operand& operand) {
535     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
536     // - change whether a lifetime means an operand is a model input, a model output, or neither
537     // - preserve whether or not a lifetime means an operand should have a writer
538     switch (operand.lifetime) {
539         case OperandLifeTime::CONSTANT_COPY:
540         case OperandLifeTime::CONSTANT_REFERENCE:
541             return OperandLifeTime::MODEL_INPUT;
542         case OperandLifeTime::MODEL_INPUT: {
543             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
544             if (!operandSize ||
545                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
546                 // Unknown size or too-large size
547                 break;
548             }
549             return OperandLifeTime::CONSTANT_COPY;
550         }
551         case OperandLifeTime::MODEL_OUTPUT:
552             return OperandLifeTime::TEMPORARY_VARIABLE;
553         case OperandLifeTime::TEMPORARY_VARIABLE:
554             return OperandLifeTime::MODEL_OUTPUT;
555         case OperandLifeTime::NO_VALUE:
556             // Not enough information to know whether
557             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
558             // appropriate choice -- is this operand written (then
559             // TEMPORARY_VARIABLE would be appropriate) or not (then
560             // CONSTANT_COPY would be appropriate)?
561             break;
562         default:
563             ADD_FAILURE();
564             break;
565     }
566 
567     return std::nullopt;
568 }
569 
mutateOperandInputOutputTest(const sp<IDevice> & device,const V1_1::Model & model)570 static void mutateOperandInputOutputTest(const sp<IDevice>& device, const V1_1::Model& model) {
571     const size_t modelSize = sizeForBinder(model);
572     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
573         const std::optional<OperandLifeTime> changedLifeTime =
574                 getInputOutputLifeTime(model, modelSize, model.operands[operand]);
575         if (changedLifeTime) {
576             const std::string message = "mutateOperandInputOutputTest: operand " +
577                                         std::to_string(operand) + " has lifetime " +
578                                         toString(*changedLifeTime) + " instead of lifetime " +
579                                         toString(model.operands[operand].lifetime);
580             validate(device, message, model,
581                      [operand, changedLifeTime](Model* model, ExecutionPreference*) {
582                          static const DataLocation kZeroDataLocation = {};
583                          Operand& operandObj = model->operands[operand];
584                          operandObj.lifetime = *changedLifeTime;
585                          operandObj.location = kZeroDataLocation;
586                          if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
587                              becomeConstantCopy(model, &operandObj);
588                          }
589                      });
590         }
591     }
592 }
593 
594 ///////////////////////// VALIDATE OPERAND NUMBER OF CONSUMERS //////////////////////////////////
595 
// Returns consumer counts that disagree with |numberOfConsumers|: {1} when it
// is zero, and otherwise one less and one more than the actual count.
static std::vector<uint32_t> getInvalidNumberOfConsumers(uint32_t numberOfConsumers) {
    if (numberOfConsumers != 0) {
        return {numberOfConsumers - 1, numberOfConsumers + 1};
    }
    return {1};
}
603 
mutateOperandNumberOfConsumersTest(const sp<IDevice> & device,const V1_1::Model & model)604 static void mutateOperandNumberOfConsumersTest(const sp<IDevice>& device,
605                                                const V1_1::Model& model) {
606     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
607         const std::vector<uint32_t> invalidNumberOfConsumersVec =
608                 getInvalidNumberOfConsumers(model.operands[operand].numberOfConsumers);
609         for (uint32_t invalidNumberOfConsumers : invalidNumberOfConsumersVec) {
610             const std::string message =
611                     "mutateOperandNumberOfConsumersTest: operand " + std::to_string(operand) +
612                     " numberOfConsumers = " + std::to_string(invalidNumberOfConsumers);
613             validate(device, message, model,
614                      [operand, invalidNumberOfConsumers](Model* model, ExecutionPreference*) {
615                          model->operands[operand].numberOfConsumers = invalidNumberOfConsumers;
616                      });
617         }
618     }
619 }
620 
621 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
622 
mutateOperandAddWriterTest(const sp<IDevice> & device,const V1_1::Model & model)623 static void mutateOperandAddWriterTest(const sp<IDevice>& device, const V1_1::Model& model) {
624     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
625         for (size_t badOutputNum = 0; badOutputNum < model.operations[operation].outputs.size();
626              ++badOutputNum) {
627             const uint32_t outputOperandIndex = model.operations[operation].outputs[badOutputNum];
628             const std::string message = "mutateOperandAddWriterTest: operation " +
629                                         std::to_string(operation) + " writes to " +
630                                         std::to_string(outputOperandIndex);
631             // We'll insert a copy of the operation, all of whose
632             // OTHER output operands are newly-created -- i.e.,
633             // there'll only be a duplicate write of ONE of that
634             // operation's output operands.
635             validate(device, message, model,
636                      [operation, badOutputNum](Model* model, ExecutionPreference*) {
637                          Operation newOperation = model->operations[operation];
638                          for (uint32_t input : newOperation.inputs) {
639                              ++model->operands[input].numberOfConsumers;
640                          }
641                          for (size_t outputNum = 0; outputNum < newOperation.outputs.size();
642                               ++outputNum) {
643                              if (outputNum == badOutputNum) continue;
644 
645                              Operand operandValue =
646                                      model->operands[newOperation.outputs[outputNum]];
647                              operandValue.numberOfConsumers = 0;
648                              if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
649                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
650                              } else {
651                                  ASSERT_EQ(operandValue.lifetime,
652                                            OperandLifeTime::TEMPORARY_VARIABLE);
653                              }
654                              newOperation.outputs[outputNum] =
655                                      hidl_vec_push_back(&model->operands, operandValue);
656                          }
657                          // Where do we insert the extra writer (a new
658                          // operation)?  It has to be later than all the
659                          // writers of its inputs.  The easiest thing to do
660                          // is to insert it at the end of the operation
661                          // sequence.
662                          hidl_vec_push_back(&model->operations, newOperation);
663                      });
664         }
665     }
666 }
667 
///////////////////////// VALIDATE OPERAND LOCATION /////////////////////////
669 
670 // TODO: Operand::location
671 
672 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
673 
mutateOperand(Operand * operand,OperandType type)674 static void mutateOperand(Operand* operand, OperandType type) {
675     Operand newOperand = *operand;
676     newOperand.type = type;
677     switch (type) {
678         case OperandType::FLOAT32:
679         case OperandType::INT32:
680         case OperandType::UINT32:
681             newOperand.dimensions = hidl_vec<uint32_t>();
682             newOperand.scale = 0.0f;
683             newOperand.zeroPoint = 0;
684             break;
685         case OperandType::TENSOR_FLOAT32:
686             newOperand.dimensions =
687                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
688             newOperand.scale = 0.0f;
689             newOperand.zeroPoint = 0;
690             break;
691         case OperandType::TENSOR_INT32:
692             newOperand.dimensions =
693                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
694             newOperand.zeroPoint = 0;
695             break;
696         case OperandType::TENSOR_QUANT8_ASYMM:
697             newOperand.dimensions =
698                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
699             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
700             break;
701         case OperandType::OEM:
702         case OperandType::TENSOR_OEM_BYTE:
703         default:
704             break;
705     }
706     *operand = newOperand;
707 }
708 
mutateOperationOperandTypeSkip(size_t operand,const Model & model)709 static bool mutateOperationOperandTypeSkip(size_t operand, const Model& model) {
710     // LSH_PROJECTION's second argument is allowed to have any type. This is the
711     // only operation that currently has a type that can be anything independent
712     // from any other type. Changing the operand type to any other type will
713     // result in a valid model for LSH_PROJECTION. If this is the case, skip the
714     // test.
715     for (const Operation& operation : model.operations) {
716         if (operation.type == OperationType::LSH_PROJECTION && operand == operation.inputs[1]) {
717             return true;
718         }
719     }
720     return false;
721 }
722 
mutateOperationOperandTypeTest(const sp<IDevice> & device,const Model & model)723 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Model& model) {
724     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
725         if (mutateOperationOperandTypeSkip(operand, model)) {
726             continue;
727         }
728         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
729             // Do not test OEM types
730             if (invalidOperandType == model.operands[operand].type ||
731                 invalidOperandType == OperandType::OEM ||
732                 invalidOperandType == OperandType::TENSOR_OEM_BYTE) {
733                 continue;
734             }
735             const std::string message = "mutateOperationOperandTypeTest: operand " +
736                                         std::to_string(operand) + " set to type " +
737                                         toString(invalidOperandType);
738             validate(device, message, model,
739                      [operand, invalidOperandType](Model* model, ExecutionPreference*) {
740                          mutateOperand(&model->operands[operand], invalidOperandType);
741                      });
742         }
743     }
744 }
745 
746 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
747 
748 static const int32_t invalidOperationTypes[] = {
749         static_cast<int32_t>(OperationType::ADD) - 1,            // lower bound fundamental
750         static_cast<int32_t>(OperationType::TRANSPOSE) + 1,      // upper bound fundamental
751         static_cast<int32_t>(OperationType::OEM_OPERATION) - 1,  // lower bound OEM
752         static_cast<int32_t>(OperationType::OEM_OPERATION) + 1,  // upper bound OEM
753 };
754 
mutateOperationTypeTest(const sp<IDevice> & device,const Model & model)755 static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& model) {
756     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
757         for (int32_t invalidOperationType : invalidOperationTypes) {
758             const std::string message = "mutateOperationTypeTest: operation " +
759                                         std::to_string(operation) + " set to value " +
760                                         std::to_string(invalidOperationType);
761             validate(device, message, model,
762                      [operation, invalidOperationType](Model* model, ExecutionPreference*) {
763                          model->operations[operation].type =
764                                  static_cast<OperationType>(invalidOperationType);
765                      });
766         }
767     }
768 }
769 
770 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
771 
mutateOperationInputOperandIndexTest(const sp<IDevice> & device,const Model & model)772 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
773     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
774         const uint32_t invalidOperand = model.operands.size();
775         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
776             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
777                                         std::to_string(operation) + " input " +
778                                         std::to_string(input);
779             validate(device, message, model,
780                      [operation, input, invalidOperand](Model* model, ExecutionPreference*) {
781                          model->operations[operation].inputs[input] = invalidOperand;
782                      });
783         }
784     }
785 }
786 
787 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
788 
mutateOperationOutputOperandIndexTest(const sp<IDevice> & device,const Model & model)789 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
790     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
791         const uint32_t invalidOperand = model.operands.size();
792         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
793             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
794                                         std::to_string(operation) + " output " +
795                                         std::to_string(output);
796             validate(device, message, model,
797                      [operation, output, invalidOperand](Model* model, ExecutionPreference*) {
798                          model->operations[operation].outputs[output] = invalidOperand;
799                      });
800         }
801     }
802 }
803 
804 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
805 
mutateOperationRemoveWriteTest(const sp<IDevice> & device,const V1_1::Model & model)806 static void mutateOperationRemoveWriteTest(const sp<IDevice>& device, const V1_1::Model& model) {
807     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
808         for (size_t outputNum = 0; outputNum < model.operations[operation].outputs.size();
809              ++outputNum) {
810             const uint32_t outputOperandIndex = model.operations[operation].outputs[outputNum];
811             if (model.operands[outputOperandIndex].numberOfConsumers > 0) {
812                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
813                                             std::to_string(operation) + " writes to " +
814                                             std::to_string(outputOperandIndex);
815                 validate(device, message, model,
816                          [operation, outputNum](Model* model, ExecutionPreference*) {
817                              uint32_t& outputOperandIndex =
818                                      model->operations[operation].outputs[outputNum];
819                              Operand operandValue = model->operands[outputOperandIndex];
820                              operandValue.numberOfConsumers = 0;
821                              if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
822                                  operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
823                              } else {
824                                  ASSERT_EQ(operandValue.lifetime,
825                                            OperandLifeTime::TEMPORARY_VARIABLE);
826                              }
827                              outputOperandIndex =
828                                      hidl_vec_push_back(&model->operands, operandValue);
829                          });
830             }
831         }
832     }
833 }
834 
835 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
836 
// Erases every occurrence of `value` from `*vec`, then decrements each remaining
// element greater than `value` by one. This keeps an index list consistent
// after the element at index `value` has been removed from the container it
// indexes. A null `vec` is ignored.
//
// `Vec` is any resizable container of uint32_t indices supporting begin()/end()/
// resize(); the callers in this file pass hidl_vec<uint32_t>*.
template <typename Vec>
static void removeValueAndDecrementGreaterValues(Vec* vec, uint32_t value) {
    if (vec == nullptr) {
        return;
    }

    // Remove elements matching "value".
    const auto last = std::remove(vec->begin(), vec->end(), value);
    vec->resize(std::distance(vec->begin(), last));

    // Decrement elements exceeding "value". This must yield `v - 1`: a
    // post-decrement (`v--`) would return the unmodified value, and no index
    // would ever be renumbered.
    std::transform(vec->begin(), vec->end(), vec->begin(),
                   [value](uint32_t v) { return v > value ? v - 1 : v; });
}
848 
removeOperand(Model * model,uint32_t index)849 static void removeOperand(Model* model, uint32_t index) {
850     hidl_vec_removeAt(&model->operands, index);
851     for (Operation& operation : model->operations) {
852         removeValueAndDecrementGreaterValues(&operation.inputs, index);
853         removeValueAndDecrementGreaterValues(&operation.outputs, index);
854     }
855     removeValueAndDecrementGreaterValues(&model->inputIndexes, index);
856     removeValueAndDecrementGreaterValues(&model->outputIndexes, index);
857 }
858 
removeOperandTest(const sp<IDevice> & device,const Model & model)859 static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
860     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
861         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
862         validate(device, message, model,
863                  [operand](Model* model, ExecutionPreference*) { removeOperand(model, operand); });
864     }
865 }
866 
867 ///////////////////////// REMOVE OPERATION /////////////////////////
868 
removeOperation(Model * model,uint32_t index)869 static void removeOperation(Model* model, uint32_t index) {
870     for (uint32_t operand : model->operations[index].inputs) {
871         model->operands[operand].numberOfConsumers--;
872     }
873     hidl_vec_removeAt(&model->operations, index);
874 }
875 
removeOperationTest(const sp<IDevice> & device,const Model & model)876 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
877     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
878         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
879         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
880             removeOperation(model, operation);
881         });
882     }
883 }
884 
885 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
886 
removeOperationInputTest(const sp<IDevice> & device,const Model & model)887 static void removeOperationInputTest(const sp<IDevice>& device, const Model& model) {
888     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
889         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
890             const Operation& op = model.operations[operation];
891             // CONCATENATION has at least 2 inputs, with the last element being
892             // INT32. Skip this test if removing one of CONCATENATION's
893             // inputs still produces a valid model.
894             if (op.type == OperationType::CONCATENATION && op.inputs.size() > 2 &&
895                 input != op.inputs.size() - 1) {
896                 continue;
897             }
898             const std::string message = "removeOperationInputTest: operation " +
899                                         std::to_string(operation) + ", input " +
900                                         std::to_string(input);
901             validate(device, message, model,
902                      [operation, input](Model* model, ExecutionPreference*) {
903                          uint32_t operand = model->operations[operation].inputs[input];
904                          model->operands[operand].numberOfConsumers--;
905                          hidl_vec_removeAt(&model->operations[operation].inputs, input);
906                      });
907         }
908     }
909 }
910 
911 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
912 
removeOperationOutputTest(const sp<IDevice> & device,const Model & model)913 static void removeOperationOutputTest(const sp<IDevice>& device, const Model& model) {
914     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
915         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
916             const std::string message = "removeOperationOutputTest: operation " +
917                                         std::to_string(operation) + ", output " +
918                                         std::to_string(output);
919             validate(device, message, model,
920                      [operation, output](Model* model, ExecutionPreference*) {
921                          hidl_vec_removeAt(&model->operations[operation].outputs, output);
922                      });
923         }
924     }
925 }
926 
927 ///////////////////////// MODEL VALIDATION /////////////////////////
928 
929 // TODO: remove model input
930 // TODO: remove model output
931 // TODO: add unused operation
932 
933 ///////////////////////// ADD OPERATION INPUT /////////////////////////
934 
addOperationInputTest(const sp<IDevice> & device,const Model & model)935 static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
936     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
937         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
938         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
939             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
940             hidl_vec_push_back(&model->operations[operation].inputs, index);
941             hidl_vec_push_back(&model->inputIndexes, index);
942         });
943     }
944 }
945 
946 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
947 
addOperationOutputTest(const sp<IDevice> & device,const Model & model)948 static void addOperationOutputTest(const sp<IDevice>& device, const Model& model) {
949     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
950         const std::string message =
951                 "addOperationOutputTest: operation " + std::to_string(operation);
952         validate(device, message, model, [operation](Model* model, ExecutionPreference*) {
953             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
954             hidl_vec_push_back(&model->operations[operation].outputs, index);
955             hidl_vec_push_back(&model->outputIndexes, index);
956         });
957     }
958 }
959 
960 ///////////////////////// VALIDATE EXECUTION PREFERENCE /////////////////////////
961 
962 static const int32_t invalidExecutionPreferences[] = {
963         static_cast<int32_t>(ExecutionPreference::LOW_POWER) - 1,        // lower bound
964         static_cast<int32_t>(ExecutionPreference::SUSTAINED_SPEED) + 1,  // upper bound
965 };
966 
mutateExecutionPreferenceTest(const sp<IDevice> & device,const Model & model)967 static void mutateExecutionPreferenceTest(const sp<IDevice>& device, const Model& model) {
968     for (int32_t invalidPreference : invalidExecutionPreferences) {
969         const std::string message =
970                 "mutateExecutionPreferenceTest: preference " + std::to_string(invalidPreference);
971         validate(device, message, model,
972                  [invalidPreference](Model*, ExecutionPreference* preference) {
973                      *preference = static_cast<ExecutionPreference>(invalidPreference);
974                  });
975     }
976 }
977 
978 ////////////////////////// ENTRY POINT //////////////////////////////
979 
// Entry point for model validation: runs each mutation test below, in order,
// against `device` using `model` as the starting point. `model` is taken by
// const reference; the individual tests derive mutated copies of it.
void validateModel(const sp<IDevice>& device, const Model& model) {
    mutateExecutionOrderTest(device, model);
    mutateOperandTypeTest(device, model);
    mutateOperandRankTest(device, model);
    mutateOperandScaleTest(device, model);
    mutateOperandZeroPointTest(device, model);
    mutateOperandLifeTimeTest(device, model);
    mutateOperandInputOutputTest(device, model);
    // Off-by-one consumer counts and duplicated writers of an operand.
    mutateOperandNumberOfConsumersTest(device, model);
    mutateOperandAddWriterTest(device, model);
    // Operation-level mutations: operand retyping, bad operation type codes,
    // out-of-range operand indexes, and outputs redirected away from consumed operands.
    mutateOperationOperandTypeTest(device, model);
    mutateOperationTypeTest(device, model);
    mutateOperationInputOperandIndexTest(device, model);
    mutateOperationOutputOperandIndexTest(device, model);
    mutateOperationRemoveWriteTest(device, model);
    // Structural removals.
    removeOperandTest(device, model);
    removeOperationTest(device, model);
    removeOperationInputTest(device, model);
    removeOperationOutputTest(device, model);
    // Structural additions (extra model input/output operands on an operation).
    addOperationInputTest(device, model);
    addOperationOutputTest(device, model);
    // Out-of-range execution preference with an unmodified model.
    mutateExecutionPreferenceTest(device, model);
}
1003 
1004 }  // namespace android::hardware::neuralnetworks::V1_1::vts::functional
1005