1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "neuralnetworks_hidl_hal_test"
18 
19 #include "1.0/Callbacks.h"
20 #include "1.0/Utils.h"
21 #include "GeneratedTestHarness.h"
22 #include "VtsHalNeuralnetworks.h"
23 
24 #include <optional>
25 #include <type_traits>
26 #include <utility>
27 
28 namespace android::hardware::neuralnetworks::V1_0::vts::functional {
29 
// Callback object that receives the asynchronous result of IDevice::prepareModel().
using implementation::PreparedModelCallback;

// In-place mutation applied to a copy of a valid model to make it invalid; see validate().
using PrepareModelMutation = std::function<void(Model*)>;
33 
34 ///////////////////////// UTILITY FUNCTIONS /////////////////////////
35 
validateGetSupportedOperations(const sp<IDevice> & device,const std::string & message,const Model & model)36 static void validateGetSupportedOperations(const sp<IDevice>& device, const std::string& message,
37                                            const Model& model) {
38     SCOPED_TRACE(message + " [getSupportedOperations]");
39 
40     Return<void> ret =
41             device->getSupportedOperations(model, [&](ErrorStatus status, const hidl_vec<bool>&) {
42                 EXPECT_EQ(ErrorStatus::INVALID_ARGUMENT, status);
43             });
44     EXPECT_TRUE(ret.isOk());
45 }
46 
validatePrepareModel(const sp<IDevice> & device,const std::string & message,const Model & model)47 static void validatePrepareModel(const sp<IDevice>& device, const std::string& message,
48                                  const Model& model) {
49     SCOPED_TRACE(message + " [prepareModel]");
50 
51     sp<PreparedModelCallback> preparedModelCallback = new PreparedModelCallback();
52     Return<ErrorStatus> prepareLaunchStatus = device->prepareModel(model, preparedModelCallback);
53     ASSERT_TRUE(prepareLaunchStatus.isOk());
54     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, static_cast<ErrorStatus>(prepareLaunchStatus));
55 
56     preparedModelCallback->wait();
57     ErrorStatus prepareReturnStatus = preparedModelCallback->getStatus();
58     ASSERT_EQ(ErrorStatus::INVALID_ARGUMENT, prepareReturnStatus);
59     sp<IPreparedModel> preparedModel = preparedModelCallback->getPreparedModel();
60     ASSERT_EQ(nullptr, preparedModel.get());
61 }
62 
63 // Primary validation function. This function will take a valid model, apply a
64 // mutation to invalidate the model, then pass these to supportedOperations and
65 // prepareModel.
validate(const sp<IDevice> & device,const std::string & message,const Model & originalModel,const PrepareModelMutation & mutate)66 static void validate(const sp<IDevice>& device, const std::string& message,
67                      const Model& originalModel, const PrepareModelMutation& mutate) {
68     Model model = originalModel;
69     mutate(&model);
70 
71     validateGetSupportedOperations(device, message, model);
72     validatePrepareModel(device, message, model);
73 }
74 
addOperand(Model * model)75 static uint32_t addOperand(Model* model) {
76     return hidl_vec_push_back(&model->operands,
77                               {
78                                       .type = OperandType::INT32,
79                                       .dimensions = {},
80                                       .numberOfConsumers = 0,
81                                       .scale = 0.0f,
82                                       .zeroPoint = 0,
83                                       .lifetime = OperandLifeTime::MODEL_INPUT,
84                                       .location = {.poolIndex = 0, .offset = 0, .length = 0},
85                               });
86 }
87 
addOperand(Model * model,OperandLifeTime lifetime)88 static uint32_t addOperand(Model* model, OperandLifeTime lifetime) {
89     uint32_t index = addOperand(model);
90     model->operands[index].numberOfConsumers = 1;
91     model->operands[index].lifetime = lifetime;
92     return index;
93 }
94 
95 // If we introduce a CONSTANT_COPY for an operand of size operandSize,
96 // how much will this increase the size of the model?  This assumes
97 // that we can (re)use all of model.operandValues for the operand
98 // value.
constantCopyExtraSize(const Model & model,size_t operandSize)99 static size_t constantCopyExtraSize(const Model& model, size_t operandSize) {
100     const size_t operandValuesSize = model.operandValues.size();
101     return (operandValuesSize < operandSize) ? (operandSize - operandValuesSize) : 0;
102 }
103 
104 // Highly specialized utility routine for converting an operand to
105 // CONSTANT_COPY lifetime.
106 //
107 // Expects that:
108 // - operand has a known size
109 // - operand->lifetime has already been set to CONSTANT_COPY
110 // - operand->location has been zeroed out
111 //
112 // Does the following:
113 // - initializes operand->location to point to the beginning of model->operandValues
114 // - resizes model->operandValues (if necessary) to be large enough for the operand
115 //   value, padding it with zeroes on the end
116 //
117 // Potential problem:
118 // By changing the operand to CONSTANT_COPY lifetime, this function is effectively initializing the
119 // operand with unspecified (but deterministic) data. This means that the model may be invalidated
120 // in two ways: not only is the lifetime of CONSTANT_COPY invalid, but the operand's value in the
121 // graph may also be invalid (e.g., if the operand is used as an activation code and has an invalid
122 // value). For now, this should be fine because it just means we're not testing what we think we're
123 // testing in certain cases; but we can handwave this and assume we're probabilistically likely to
124 // exercise the validation code over the span of the entire test set and operand space.
125 //
126 // Aborts if the specified operand type is an extension type or OEM type.
becomeConstantCopy(Model * model,Operand * operand)127 static void becomeConstantCopy(Model* model, Operand* operand) {
128     // sizeOfData will abort if the specified type is an extension type or OEM type.
129     const size_t sizeOfOperand = sizeOfData(*operand);
130     EXPECT_NE(sizeOfOperand, size_t(0));
131     operand->location.poolIndex = 0;
132     operand->location.offset = 0;
133     operand->location.length = sizeOfOperand;
134     if (model->operandValues.size() < sizeOfOperand) {
135         model->operandValues.resize(sizeOfOperand);
136     }
137 }
138 
139 // The sizeForBinder() functions estimate the size of the
140 // representation of a value when sent to binder.  It's probably a bit
141 // of an under-estimate, because we don't know the size of the
142 // metadata in the binder format (e.g., representation of the size of
143 // a vector); but at least it adds up "big" things like vector
144 // contents.  However, it doesn't treat inter-field or end-of-struct
145 // padding in a methodical way -- there's no attempt to be consistent
146 // in whether or not padding in the native (C++) representation
147 // contributes to the estimated size for the binder representation;
148 // and there's no attempt to understand what padding (if any) is
149 // needed in the binder representation.
150 //
151 // This assumes that non-metadata uses a fixed length encoding (e.g.,
152 // a uint32_t is always encoded in sizeof(uint32_t) bytes, rather than
153 // using an encoding whose length is related to the magnitude of the
154 // encoded value).
155 
// Fallback estimate for a plain value (enum, integer, float, simple struct):
// its in-memory size. Restricted to trivially copyable types so that a type
// with out-of-line data (vectors, strings, ...) cannot silently land here.
template <typename Type>
static size_t sizeForBinder(const Type& val) {
    using ValueType = std::remove_reference_t<Type>;
    static_assert(std::is_trivially_copyable_v<ValueType>, "expected a trivially copyable type");
    const size_t bytes = sizeof(val);
    return bytes;
}
162 
163 template <typename Type>
sizeForBinder(const hidl_vec<Type> & vec)164 static size_t sizeForBinder(const hidl_vec<Type>& vec) {
165     return std::accumulate(vec.begin(), vec.end(), 0,
166                            [](size_t acc, const Type& x) { return acc + sizeForBinder(x); });
167 }
168 
169 template <>
sizeForBinder(const Operand & operand)170 size_t sizeForBinder(const Operand& operand) {
171     size_t size = 0;
172 
173     size += sizeForBinder(operand.type);
174     size += sizeForBinder(operand.dimensions);
175     size += sizeForBinder(operand.numberOfConsumers);
176     size += sizeForBinder(operand.scale);
177     size += sizeForBinder(operand.zeroPoint);
178     size += sizeForBinder(operand.lifetime);
179     size += sizeForBinder(operand.location);
180 
181     return size;
182 }
183 
184 template <>
sizeForBinder(const Operation & operation)185 size_t sizeForBinder(const Operation& operation) {
186     size_t size = 0;
187 
188     size += sizeForBinder(operation.type);
189     size += sizeForBinder(operation.inputs);
190     size += sizeForBinder(operation.outputs);
191 
192     return size;
193 }
194 
195 template <>
sizeForBinder(const hidl_string & name)196 size_t sizeForBinder(const hidl_string& name) {
197     return name.size();
198 }
199 
200 template <>
sizeForBinder(const hidl_memory & memory)201 size_t sizeForBinder(const hidl_memory& memory) {
202     // This is just a guess.
203 
204     size_t size = 0;
205 
206     if (const native_handle_t* handle = memory.handle()) {
207         size += sizeof(*handle);
208         size += sizeof(handle->data[0] * (handle->numFds + handle->numInts));
209     }
210     size += sizeForBinder(memory.name());
211 
212     return size;
213 }
214 
215 template <>
sizeForBinder(const Model & model)216 size_t sizeForBinder(const Model& model) {
217     size_t size = 0;
218 
219     size += sizeForBinder(model.operands);
220     size += sizeForBinder(model.operations);
221     size += sizeForBinder(model.inputIndexes);
222     size += sizeForBinder(model.outputIndexes);
223     size += sizeForBinder(model.operandValues);
224     size += sizeForBinder(model.pools);
225 
226     return size;
227 }
228 
// https://developer.android.com/reference/android/os/TransactionTooLargeException.html
//
//     "The Binder transaction buffer has a limited fixed size,
//     currently 1Mb, which is shared by all transactions in progress
//     for the process."
//
// Will our representation fit under this limit? There are two complications:
// - Our representation size is only approximate (see sizeForBinder()).
// - This object may not be the only occupant of the Binder transaction buffer.
// So we are deliberately conservative: a representation counts as too large
// once it exceeds half of the transaction buffer size.
//
// If the representation still fits within the transaction buffer but, combined
// with other in-flight transactions, exceeds it, we may see intermittent HAL
// transport errors.
//
// Returns true iff `representationSize` is strictly greater than
// 1024 * 1024 / 2.
static bool exceedsBinderSizeLimit(size_t representationSize) {
    // Instead of this fixed buffer size we might be able to use
    // ProcessState::self()->getMmapSize(). However, the binder/mmap size of the
    // current process does not necessarily indicate that of the service (the
    // other process); it is only a good indication when both the current
    // process and the service use the default size.
    constexpr size_t kTransactionBufferSize = 1024 * 1024;
    constexpr size_t kHalfBufferSize = kTransactionBufferSize / 2;
    return kHalfBufferSize < representationSize;
}
256 
257 ///////////////////////// VALIDATE EXECUTION ORDER ////////////////////////////
258 
mutateExecutionOrderTest(const sp<IDevice> & device,const V1_0::Model & model)259 static void mutateExecutionOrderTest(const sp<IDevice>& device, const V1_0::Model& model) {
260     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
261         const Operation& operationObj = model.operations[operation];
262         for (uint32_t input : operationObj.inputs) {
263             if (model.operands[input].lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
264                 model.operands[input].lifetime == OperandLifeTime::MODEL_OUTPUT) {
265                 // This operation reads an operand written by some
266                 // other operation.  Move this operation to the
267                 // beginning of the sequence, ensuring that it reads
268                 // the operand before that operand is written, thereby
269                 // violating execution order rules.
270                 const std::string message = "mutateExecutionOrderTest: operation " +
271                                             std::to_string(operation) + " is a reader";
272                 validate(device, message, model, [operation](Model* model) {
273                     auto& operations = model->operations;
274                     std::rotate(operations.begin(), operations.begin() + operation,
275                                 operations.begin() + operation + 1);
276                 });
277                 break;  // only need to do this once per operation
278             }
279         }
280         for (uint32_t output : operationObj.outputs) {
281             if (model.operands[output].numberOfConsumers > 0) {
282                 // This operation writes an operand read by some other
283                 // operation.  Move this operation to the end of the
284                 // sequence, ensuring that it writes the operand after
285                 // that operand is read, thereby violating execution
286                 // order rules.
287                 const std::string message = "mutateExecutionOrderTest: operation " +
288                                             std::to_string(operation) + " is a writer";
289                 validate(device, message, model, [operation](Model* model) {
290                     auto& operations = model->operations;
291                     std::rotate(operations.begin() + operation, operations.begin() + operation + 1,
292                                 operations.end());
293                 });
294                 break;  // only need to do this once per operation
295             }
296         }
297     }
298 }
299 
300 ///////////////////////// VALIDATE MODEL OPERAND TYPE /////////////////////////
301 
302 static const int32_t invalidOperandTypes[] = {
303         static_cast<int32_t>(OperandType::FLOAT32) - 1,              // lower bound fundamental
304         static_cast<int32_t>(OperandType::TENSOR_QUANT8_ASYMM) + 1,  // upper bound fundamental
305         static_cast<int32_t>(OperandType::OEM) - 1,                  // lower bound OEM
306         static_cast<int32_t>(OperandType::TENSOR_OEM_BYTE) + 1,      // upper bound OEM
307 };
308 
mutateOperandTypeTest(const sp<IDevice> & device,const Model & model)309 static void mutateOperandTypeTest(const sp<IDevice>& device, const Model& model) {
310     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
311         for (int32_t invalidOperandType : invalidOperandTypes) {
312             const std::string message = "mutateOperandTypeTest: operand " +
313                                         std::to_string(operand) + " set to value " +
314                                         std::to_string(invalidOperandType);
315             validate(device, message, model, [operand, invalidOperandType](Model* model) {
316                 model->operands[operand].type = static_cast<OperandType>(invalidOperandType);
317             });
318         }
319     }
320 }
321 
322 ///////////////////////// VALIDATE OPERAND RANK /////////////////////////
323 
getInvalidRank(OperandType type)324 static uint32_t getInvalidRank(OperandType type) {
325     switch (type) {
326         case OperandType::FLOAT32:
327         case OperandType::INT32:
328         case OperandType::UINT32:
329             return 1;
330         case OperandType::TENSOR_FLOAT32:
331         case OperandType::TENSOR_INT32:
332         case OperandType::TENSOR_QUANT8_ASYMM:
333             return 0;
334         default:
335             return 0;
336     }
337 }
338 
mutateOperandRankTest(const sp<IDevice> & device,const Model & model)339 static void mutateOperandRankTest(const sp<IDevice>& device, const Model& model) {
340     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
341         const uint32_t invalidRank = getInvalidRank(model.operands[operand].type);
342         const std::string message = "mutateOperandRankTest: operand " + std::to_string(operand) +
343                                     " has rank of " + std::to_string(invalidRank);
344         validate(device, message, model, [operand, invalidRank](Model* model) {
345             model->operands[operand].dimensions = std::vector<uint32_t>(invalidRank, 0);
346         });
347     }
348 }
349 
350 ///////////////////////// VALIDATE OPERAND SCALE /////////////////////////
351 
getInvalidScale(OperandType type)352 static float getInvalidScale(OperandType type) {
353     switch (type) {
354         case OperandType::FLOAT32:
355         case OperandType::INT32:
356         case OperandType::UINT32:
357         case OperandType::TENSOR_FLOAT32:
358             return 1.0f;
359         case OperandType::TENSOR_INT32:
360             return -1.0f;
361         case OperandType::TENSOR_QUANT8_ASYMM:
362             return 0.0f;
363         default:
364             return 0.0f;
365     }
366 }
367 
mutateOperandScaleTest(const sp<IDevice> & device,const Model & model)368 static void mutateOperandScaleTest(const sp<IDevice>& device, const Model& model) {
369     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
370         const float invalidScale = getInvalidScale(model.operands[operand].type);
371         const std::string message = "mutateOperandScaleTest: operand " + std::to_string(operand) +
372                                     " has scale of " + std::to_string(invalidScale);
373         validate(device, message, model, [operand, invalidScale](Model* model) {
374             model->operands[operand].scale = invalidScale;
375         });
376     }
377 }
378 
379 ///////////////////////// VALIDATE OPERAND ZERO POINT /////////////////////////
380 
getInvalidZeroPoints(OperandType type)381 static std::vector<int32_t> getInvalidZeroPoints(OperandType type) {
382     switch (type) {
383         case OperandType::FLOAT32:
384         case OperandType::INT32:
385         case OperandType::UINT32:
386         case OperandType::TENSOR_FLOAT32:
387         case OperandType::TENSOR_INT32:
388             return {1};
389         case OperandType::TENSOR_QUANT8_ASYMM:
390             return {-1, 256};
391         default:
392             return {};
393     }
394 }
395 
mutateOperandZeroPointTest(const sp<IDevice> & device,const Model & model)396 static void mutateOperandZeroPointTest(const sp<IDevice>& device, const Model& model) {
397     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
398         const std::vector<int32_t> invalidZeroPoints =
399                 getInvalidZeroPoints(model.operands[operand].type);
400         for (int32_t invalidZeroPoint : invalidZeroPoints) {
401             const std::string message = "mutateOperandZeroPointTest: operand " +
402                                         std::to_string(operand) + " has zero point of " +
403                                         std::to_string(invalidZeroPoint);
404             validate(device, message, model, [operand, invalidZeroPoint](Model* model) {
405                 model->operands[operand].zeroPoint = invalidZeroPoint;
406             });
407         }
408     }
409 }
410 
411 ///////////////////////// VALIDATE OPERAND LIFETIME /////////////////////////////////////////////
412 
getInvalidLifeTimes(const Model & model,size_t modelSize,const Operand & operand)413 static std::vector<OperandLifeTime> getInvalidLifeTimes(const Model& model, size_t modelSize,
414                                                         const Operand& operand) {
415     // TODO: Support OperandLifeTime::CONSTANT_REFERENCE as an invalid lifetime
416     // TODO: Support OperandLifeTime::NO_VALUE as an invalid lifetime
417 
418     // Ways to get an invalid lifetime:
419     // - change whether a lifetime means an operand should have a writer
420     std::vector<OperandLifeTime> ret;
421     switch (operand.lifetime) {
422         case OperandLifeTime::MODEL_OUTPUT:
423         case OperandLifeTime::TEMPORARY_VARIABLE:
424             ret = {
425                     OperandLifeTime::MODEL_INPUT,
426                     OperandLifeTime::CONSTANT_COPY,
427             };
428             break;
429         case OperandLifeTime::CONSTANT_COPY:
430         case OperandLifeTime::CONSTANT_REFERENCE:
431         case OperandLifeTime::MODEL_INPUT:
432             ret = {
433                     OperandLifeTime::TEMPORARY_VARIABLE,
434                     OperandLifeTime::MODEL_OUTPUT,
435             };
436             break;
437         case OperandLifeTime::NO_VALUE:
438             // Not enough information to know whether
439             // TEMPORARY_VARIABLE or CONSTANT_COPY would be invalid --
440             // is this operand written (then CONSTANT_COPY would be
441             // invalid) or not (then TEMPORARY_VARIABLE would be
442             // invalid)?
443             break;
444         default:
445             ADD_FAILURE();
446             break;
447     }
448 
449     const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
450     if (!operandSize ||
451         exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
452         // Unknown size or too-large size
453         ret.erase(std::remove(ret.begin(), ret.end(), OperandLifeTime::CONSTANT_COPY), ret.end());
454     }
455 
456     return ret;
457 }
458 
mutateOperandLifeTimeTest(const sp<IDevice> & device,const V1_0::Model & model)459 static void mutateOperandLifeTimeTest(const sp<IDevice>& device, const V1_0::Model& model) {
460     const size_t modelSize = sizeForBinder(model);
461     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
462         const std::vector<OperandLifeTime> invalidLifeTimes =
463                 getInvalidLifeTimes(model, modelSize, model.operands[operand]);
464         for (OperandLifeTime invalidLifeTime : invalidLifeTimes) {
465             const std::string message = "mutateOperandLifetimeTest: operand " +
466                                         std::to_string(operand) + " has lifetime " +
467                                         toString(invalidLifeTime) + " instead of lifetime " +
468                                         toString(model.operands[operand].lifetime);
469             validate(device, message, model, [operand, invalidLifeTime](Model* model) {
470                 static const DataLocation kZeroDataLocation = {};
471                 Operand& operandObj = model->operands[operand];
472                 switch (operandObj.lifetime) {
473                     case OperandLifeTime::MODEL_INPUT: {
474                         hidl_vec_remove(&model->inputIndexes, uint32_t(operand));
475                         break;
476                     }
477                     case OperandLifeTime::MODEL_OUTPUT: {
478                         hidl_vec_remove(&model->outputIndexes, uint32_t(operand));
479                         break;
480                     }
481                     default:
482                         break;
483                 }
484                 operandObj.lifetime = invalidLifeTime;
485                 operandObj.location = kZeroDataLocation;
486                 switch (invalidLifeTime) {
487                     case OperandLifeTime::CONSTANT_COPY: {
488                         becomeConstantCopy(model, &operandObj);
489                         break;
490                     }
491                     case OperandLifeTime::MODEL_INPUT:
492                         hidl_vec_push_back(&model->inputIndexes, uint32_t(operand));
493                         break;
494                     case OperandLifeTime::MODEL_OUTPUT:
495                         hidl_vec_push_back(&model->outputIndexes, uint32_t(operand));
496                         break;
497                     default:
498                         break;
499                 }
500             });
501         }
502     }
503 }
504 
505 ///////////////////////// VALIDATE OPERAND INPUT-or-OUTPUT //////////////////////////////////////
506 
getInputOutputLifeTime(const Model & model,size_t modelSize,const Operand & operand)507 static std::optional<OperandLifeTime> getInputOutputLifeTime(const Model& model, size_t modelSize,
508                                                              const Operand& operand) {
509     // Ways to get an invalid lifetime (with respect to model inputIndexes and outputIndexes):
510     // - change whether a lifetime means an operand is a model input, a model output, or neither
511     // - preserve whether or not a lifetime means an operand should have a writer
512     switch (operand.lifetime) {
513         case OperandLifeTime::CONSTANT_COPY:
514         case OperandLifeTime::CONSTANT_REFERENCE:
515             return OperandLifeTime::MODEL_INPUT;
516         case OperandLifeTime::MODEL_INPUT: {
517             const size_t operandSize = sizeOfData(operand);  // will be zero if shape is unknown
518             if (!operandSize ||
519                 exceedsBinderSizeLimit(modelSize + constantCopyExtraSize(model, operandSize))) {
520                 // Unknown size or too-large size
521                 break;
522             }
523             return OperandLifeTime::CONSTANT_COPY;
524         }
525         case OperandLifeTime::MODEL_OUTPUT:
526             return OperandLifeTime::TEMPORARY_VARIABLE;
527         case OperandLifeTime::TEMPORARY_VARIABLE:
528             return OperandLifeTime::MODEL_OUTPUT;
529         case OperandLifeTime::NO_VALUE:
530             // Not enough information to know whether
531             // TEMPORARY_VARIABLE or CONSTANT_COPY would be an
532             // appropriate choice -- is this operand written (then
533             // TEMPORARY_VARIABLE would be appropriate) or not (then
534             // CONSTANT_COPY would be appropriate)?
535             break;
536         default:
537             ADD_FAILURE();
538             break;
539     }
540 
541     return std::nullopt;
542 }
543 
mutateOperandInputOutputTest(const sp<IDevice> & device,const V1_0::Model & model)544 static void mutateOperandInputOutputTest(const sp<IDevice>& device, const V1_0::Model& model) {
545     const size_t modelSize = sizeForBinder(model);
546     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
547         const std::optional<OperandLifeTime> changedLifeTime =
548                 getInputOutputLifeTime(model, modelSize, model.operands[operand]);
549         if (changedLifeTime) {
550             const std::string message = "mutateOperandInputOutputTest: operand " +
551                                         std::to_string(operand) + " has lifetime " +
552                                         toString(*changedLifeTime) + " instead of lifetime " +
553                                         toString(model.operands[operand].lifetime);
554             validate(device, message, model, [operand, changedLifeTime](Model* model) {
555                 static const DataLocation kZeroDataLocation = {};
556                 Operand& operandObj = model->operands[operand];
557                 operandObj.lifetime = *changedLifeTime;
558                 operandObj.location = kZeroDataLocation;
559                 if (*changedLifeTime == OperandLifeTime::CONSTANT_COPY) {
560                     becomeConstantCopy(model, &operandObj);
561                 }
562             });
563         }
564     }
565 }
566 
567 ///////////////////////// VALIDATE OPERAND NUMBER OF CONSUMERS //////////////////////////////////
568 
// Returns consumer counts that differ from the actual `numberOfConsumers`:
// {1} for an operand with no consumers, otherwise one fewer and one more than
// the actual count.
static std::vector<uint32_t> getInvalidNumberOfConsumers(uint32_t numberOfConsumers) {
    if (numberOfConsumers != 0) {
        const uint32_t fewer = numberOfConsumers - 1;
        const uint32_t more = numberOfConsumers + 1;
        return {fewer, more};
    }
    return {1};
}
576 
mutateOperandNumberOfConsumersTest(const sp<IDevice> & device,const V1_0::Model & model)577 static void mutateOperandNumberOfConsumersTest(const sp<IDevice>& device,
578                                                const V1_0::Model& model) {
579     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
580         const std::vector<uint32_t> invalidNumberOfConsumersVec =
581                 getInvalidNumberOfConsumers(model.operands[operand].numberOfConsumers);
582         for (uint32_t invalidNumberOfConsumers : invalidNumberOfConsumersVec) {
583             const std::string message =
584                     "mutateOperandNumberOfConsumersTest: operand " + std::to_string(operand) +
585                     " numberOfConsumers = " + std::to_string(invalidNumberOfConsumers);
586             validate(device, message, model, [operand, invalidNumberOfConsumers](Model* model) {
587                 model->operands[operand].numberOfConsumers = invalidNumberOfConsumers;
588             });
589         }
590     }
591 }
592 
593 ///////////////////////// VALIDATE OPERAND NUMBER OF WRITERS ////////////////////////////////////
594 
mutateOperandAddWriterTest(const sp<IDevice> & device,const V1_0::Model & model)595 static void mutateOperandAddWriterTest(const sp<IDevice>& device, const V1_0::Model& model) {
596     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
597         for (size_t badOutputNum = 0; badOutputNum < model.operations[operation].outputs.size();
598              ++badOutputNum) {
599             const uint32_t outputOperandIndex = model.operations[operation].outputs[badOutputNum];
600             const std::string message = "mutateOperandAddWriterTest: operation " +
601                                         std::to_string(operation) + " writes to " +
602                                         std::to_string(outputOperandIndex);
603             // We'll insert a copy of the operation, all of whose
604             // OTHER output operands are newly-created -- i.e.,
605             // there'll only be a duplicate write of ONE of that
606             // operation's output operands.
607             validate(device, message, model, [operation, badOutputNum](Model* model) {
608                 Operation newOperation = model->operations[operation];
609                 for (uint32_t input : newOperation.inputs) {
610                     ++model->operands[input].numberOfConsumers;
611                 }
612                 for (size_t outputNum = 0; outputNum < newOperation.outputs.size(); ++outputNum) {
613                     if (outputNum == badOutputNum) continue;
614 
615                     Operand operandValue = model->operands[newOperation.outputs[outputNum]];
616                     operandValue.numberOfConsumers = 0;
617                     if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
618                         operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
619                     } else {
620                         ASSERT_EQ(operandValue.lifetime, OperandLifeTime::TEMPORARY_VARIABLE);
621                     }
622                     newOperation.outputs[outputNum] =
623                             hidl_vec_push_back(&model->operands, operandValue);
624                 }
625                 // Where do we insert the extra writer (a new
626                 // operation)?  It has to be later than all the
627                 // writers of its inputs.  The easiest thing to do
628                 // is to insert it at the end of the operation
629                 // sequence.
630                 hidl_vec_push_back(&model->operations, newOperation);
631             });
632         }
633     }
634 }
635 
636 ///////////////////////// VALIDATE EXTRA ??? /////////////////////////
637 
638 // TODO: Operand::location
639 
640 ///////////////////////// VALIDATE OPERATION OPERAND TYPE /////////////////////////
641 
mutateOperand(Operand * operand,OperandType type)642 static void mutateOperand(Operand* operand, OperandType type) {
643     Operand newOperand = *operand;
644     newOperand.type = type;
645     switch (type) {
646         case OperandType::FLOAT32:
647         case OperandType::INT32:
648         case OperandType::UINT32:
649             newOperand.dimensions = hidl_vec<uint32_t>();
650             newOperand.scale = 0.0f;
651             newOperand.zeroPoint = 0;
652             break;
653         case OperandType::TENSOR_FLOAT32:
654             newOperand.dimensions =
655                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
656             newOperand.scale = 0.0f;
657             newOperand.zeroPoint = 0;
658             break;
659         case OperandType::TENSOR_INT32:
660             newOperand.dimensions =
661                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
662             newOperand.zeroPoint = 0;
663             break;
664         case OperandType::TENSOR_QUANT8_ASYMM:
665             newOperand.dimensions =
666                     operand->dimensions.size() > 0 ? operand->dimensions : hidl_vec<uint32_t>({1});
667             newOperand.scale = operand->scale != 0.0f ? operand->scale : 1.0f;
668             break;
669         case OperandType::OEM:
670         case OperandType::TENSOR_OEM_BYTE:
671         default:
672             break;
673     }
674     *operand = newOperand;
675 }
676 
mutateOperationOperandTypeSkip(size_t operand,const Model & model)677 static bool mutateOperationOperandTypeSkip(size_t operand, const Model& model) {
678     // LSH_PROJECTION's second argument is allowed to have any type. This is the
679     // only operation that currently has a type that can be anything independent
680     // from any other type. Changing the operand type to any other type will
681     // result in a valid model for LSH_PROJECTION. If this is the case, skip the
682     // test.
683     for (const Operation& operation : model.operations) {
684         if (operation.type == OperationType::LSH_PROJECTION && operand == operation.inputs[1]) {
685             return true;
686         }
687     }
688     return false;
689 }
690 
mutateOperationOperandTypeTest(const sp<IDevice> & device,const Model & model)691 static void mutateOperationOperandTypeTest(const sp<IDevice>& device, const Model& model) {
692     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
693         if (mutateOperationOperandTypeSkip(operand, model)) {
694             continue;
695         }
696         for (OperandType invalidOperandType : hidl_enum_range<OperandType>{}) {
697             // Do not test OEM types
698             if (invalidOperandType == model.operands[operand].type ||
699                 invalidOperandType == OperandType::OEM ||
700                 invalidOperandType == OperandType::TENSOR_OEM_BYTE) {
701                 continue;
702             }
703             const std::string message = "mutateOperationOperandTypeTest: operand " +
704                                         std::to_string(operand) + " set to type " +
705                                         toString(invalidOperandType);
706             validate(device, message, model, [operand, invalidOperandType](Model* model) {
707                 mutateOperand(&model->operands[operand], invalidOperandType);
708             });
709         }
710     }
711 }
712 
713 ///////////////////////// VALIDATE MODEL OPERATION TYPE /////////////////////////
714 
715 static const int32_t invalidOperationTypes[] = {
716         static_cast<int32_t>(OperationType::ADD) - 1,            // lower bound fundamental
717         static_cast<int32_t>(OperationType::TANH) + 1,           // upper bound fundamental
718         static_cast<int32_t>(OperationType::OEM_OPERATION) - 1,  // lower bound OEM
719         static_cast<int32_t>(OperationType::OEM_OPERATION) + 1,  // upper bound OEM
720 };
721 
mutateOperationTypeTest(const sp<IDevice> & device,const Model & model)722 static void mutateOperationTypeTest(const sp<IDevice>& device, const Model& model) {
723     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
724         for (int32_t invalidOperationType : invalidOperationTypes) {
725             const std::string message = "mutateOperationTypeTest: operation " +
726                                         std::to_string(operation) + " set to value " +
727                                         std::to_string(invalidOperationType);
728             validate(device, message, model, [operation, invalidOperationType](Model* model) {
729                 model->operations[operation].type =
730                         static_cast<OperationType>(invalidOperationType);
731             });
732         }
733     }
734 }
735 
736 ///////////////////////// VALIDATE MODEL OPERATION INPUT OPERAND INDEX /////////////////////////
737 
mutateOperationInputOperandIndexTest(const sp<IDevice> & device,const Model & model)738 static void mutateOperationInputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
739     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
740         const uint32_t invalidOperand = model.operands.size();
741         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
742             const std::string message = "mutateOperationInputOperandIndexTest: operation " +
743                                         std::to_string(operation) + " input " +
744                                         std::to_string(input);
745             validate(device, message, model, [operation, input, invalidOperand](Model* model) {
746                 model->operations[operation].inputs[input] = invalidOperand;
747             });
748         }
749     }
750 }
751 
752 ///////////////////////// VALIDATE MODEL OPERATION OUTPUT OPERAND INDEX /////////////////////////
753 
mutateOperationOutputOperandIndexTest(const sp<IDevice> & device,const Model & model)754 static void mutateOperationOutputOperandIndexTest(const sp<IDevice>& device, const Model& model) {
755     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
756         const uint32_t invalidOperand = model.operands.size();
757         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
758             const std::string message = "mutateOperationOutputOperandIndexTest: operation " +
759                                         std::to_string(operation) + " output " +
760                                         std::to_string(output);
761             validate(device, message, model, [operation, output, invalidOperand](Model* model) {
762                 model->operations[operation].outputs[output] = invalidOperand;
763             });
764         }
765     }
766 }
767 
768 ///////////////////////// VALIDATE MODEL OPERANDS WRITTEN ///////////////////////////////////////
769 
mutateOperationRemoveWriteTest(const sp<IDevice> & device,const V1_0::Model & model)770 static void mutateOperationRemoveWriteTest(const sp<IDevice>& device, const V1_0::Model& model) {
771     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
772         for (size_t outputNum = 0; outputNum < model.operations[operation].outputs.size();
773              ++outputNum) {
774             const uint32_t outputOperandIndex = model.operations[operation].outputs[outputNum];
775             if (model.operands[outputOperandIndex].numberOfConsumers > 0) {
776                 const std::string message = "mutateOperationRemoveWriteTest: operation " +
777                                             std::to_string(operation) + " writes to " +
778                                             std::to_string(outputOperandIndex);
779                 validate(device, message, model, [operation, outputNum](Model* model) {
780                     uint32_t& outputOperandIndex = model->operations[operation].outputs[outputNum];
781                     Operand operandValue = model->operands[outputOperandIndex];
782                     operandValue.numberOfConsumers = 0;
783                     if (operandValue.lifetime == OperandLifeTime::MODEL_OUTPUT) {
784                         operandValue.lifetime = OperandLifeTime::TEMPORARY_VARIABLE;
785                     } else {
786                         ASSERT_EQ(operandValue.lifetime, OperandLifeTime::TEMPORARY_VARIABLE);
787                     }
788                     outputOperandIndex = hidl_vec_push_back(&model->operands, operandValue);
789                 });
790             }
791         }
792     }
793 }
794 
795 ///////////////////////// REMOVE OPERAND FROM EVERYTHING /////////////////////////
796 
removeValueAndDecrementGreaterValues(hidl_vec<uint32_t> * vec,uint32_t value)797 static void removeValueAndDecrementGreaterValues(hidl_vec<uint32_t>* vec, uint32_t value) {
798     if (vec) {
799         // remove elements matching "value"
800         auto last = std::remove(vec->begin(), vec->end(), value);
801         vec->resize(std::distance(vec->begin(), last));
802 
803         // decrement elements exceeding "value"
804         std::transform(vec->begin(), vec->end(), vec->begin(),
805                        [value](uint32_t v) { return v > value ? v-- : v; });
806     }
807 }
808 
removeOperand(Model * model,uint32_t index)809 static void removeOperand(Model* model, uint32_t index) {
810     hidl_vec_removeAt(&model->operands, index);
811     for (Operation& operation : model->operations) {
812         removeValueAndDecrementGreaterValues(&operation.inputs, index);
813         removeValueAndDecrementGreaterValues(&operation.outputs, index);
814     }
815     removeValueAndDecrementGreaterValues(&model->inputIndexes, index);
816     removeValueAndDecrementGreaterValues(&model->outputIndexes, index);
817 }
818 
removeOperandTest(const sp<IDevice> & device,const Model & model)819 static void removeOperandTest(const sp<IDevice>& device, const Model& model) {
820     for (size_t operand = 0; operand < model.operands.size(); ++operand) {
821         const std::string message = "removeOperandTest: operand " + std::to_string(operand);
822         validate(device, message, model,
823                  [operand](Model* model) { removeOperand(model, operand); });
824     }
825 }
826 
827 ///////////////////////// REMOVE OPERATION /////////////////////////
828 
removeOperation(Model * model,uint32_t index)829 static void removeOperation(Model* model, uint32_t index) {
830     for (uint32_t operand : model->operations[index].inputs) {
831         model->operands[operand].numberOfConsumers--;
832     }
833     hidl_vec_removeAt(&model->operations, index);
834 }
835 
removeOperationTest(const sp<IDevice> & device,const Model & model)836 static void removeOperationTest(const sp<IDevice>& device, const Model& model) {
837     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
838         const std::string message = "removeOperationTest: operation " + std::to_string(operation);
839         validate(device, message, model,
840                  [operation](Model* model) { removeOperation(model, operation); });
841     }
842 }
843 
844 ///////////////////////// REMOVE OPERATION INPUT /////////////////////////
845 
removeOperationInputTest(const sp<IDevice> & device,const Model & model)846 static void removeOperationInputTest(const sp<IDevice>& device, const Model& model) {
847     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
848         for (size_t input = 0; input < model.operations[operation].inputs.size(); ++input) {
849             const Operation& op = model.operations[operation];
850             // CONCATENATION has at least 2 inputs, with the last element being
851             // INT32. Skip this test if removing one of CONCATENATION's
852             // inputs still produces a valid model.
853             if (op.type == OperationType::CONCATENATION && op.inputs.size() > 2 &&
854                 input != op.inputs.size() - 1) {
855                 continue;
856             }
857             const std::string message = "removeOperationInputTest: operation " +
858                                         std::to_string(operation) + ", input " +
859                                         std::to_string(input);
860             validate(device, message, model, [operation, input](Model* model) {
861                 uint32_t operand = model->operations[operation].inputs[input];
862                 model->operands[operand].numberOfConsumers--;
863                 hidl_vec_removeAt(&model->operations[operation].inputs, input);
864             });
865         }
866     }
867 }
868 
869 ///////////////////////// REMOVE OPERATION OUTPUT /////////////////////////
870 
removeOperationOutputTest(const sp<IDevice> & device,const Model & model)871 static void removeOperationOutputTest(const sp<IDevice>& device, const Model& model) {
872     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
873         for (size_t output = 0; output < model.operations[operation].outputs.size(); ++output) {
874             const std::string message = "removeOperationOutputTest: operation " +
875                                         std::to_string(operation) + ", output " +
876                                         std::to_string(output);
877             validate(device, message, model, [operation, output](Model* model) {
878                 hidl_vec_removeAt(&model->operations[operation].outputs, output);
879             });
880         }
881     }
882 }
883 
884 ///////////////////////// MODEL VALIDATION /////////////////////////
885 
886 // TODO: remove model input
887 // TODO: remove model output
888 // TODO: add unused operation
889 
890 ///////////////////////// ADD OPERATION INPUT /////////////////////////
891 
addOperationInputTest(const sp<IDevice> & device,const Model & model)892 static void addOperationInputTest(const sp<IDevice>& device, const Model& model) {
893     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
894         const std::string message = "addOperationInputTest: operation " + std::to_string(operation);
895         validate(device, message, model, [operation](Model* model) {
896             uint32_t index = addOperand(model, OperandLifeTime::MODEL_INPUT);
897             hidl_vec_push_back(&model->operations[operation].inputs, index);
898             hidl_vec_push_back(&model->inputIndexes, index);
899         });
900     }
901 }
902 
903 ///////////////////////// ADD OPERATION OUTPUT /////////////////////////
904 
addOperationOutputTest(const sp<IDevice> & device,const Model & model)905 static void addOperationOutputTest(const sp<IDevice>& device, const Model& model) {
906     for (size_t operation = 0; operation < model.operations.size(); ++operation) {
907         const std::string message =
908                 "addOperationOutputTest: operation " + std::to_string(operation);
909         validate(device, message, model, [operation](Model* model) {
910             uint32_t index = addOperand(model, OperandLifeTime::MODEL_OUTPUT);
911             hidl_vec_push_back(&model->operations[operation].outputs, index);
912             hidl_vec_push_back(&model->outputIndexes, index);
913         });
914     }
915 }
916 
917 ////////////////////////// ENTRY POINT //////////////////////////////
918 
validateModel(const sp<IDevice> & device,const Model & model)919 void validateModel(const sp<IDevice>& device, const Model& model) {
920     mutateExecutionOrderTest(device, model);
921     mutateOperandTypeTest(device, model);
922     mutateOperandRankTest(device, model);
923     mutateOperandScaleTest(device, model);
924     mutateOperandZeroPointTest(device, model);
925     mutateOperandLifeTimeTest(device, model);
926     mutateOperandInputOutputTest(device, model);
927     mutateOperandNumberOfConsumersTest(device, model);
928     mutateOperandAddWriterTest(device, model);
929     mutateOperationOperandTypeTest(device, model);
930     mutateOperationTypeTest(device, model);
931     mutateOperationInputOperandIndexTest(device, model);
932     mutateOperationOutputOperandIndexTest(device, model);
933     mutateOperationRemoveWriteTest(device, model);
934     removeOperandTest(device, model);
935     removeOperationTest(device, model);
936     removeOperationInputTest(device, model);
937     removeOperationOutputTest(device, model);
938     addOperationInputTest(device, model);
939     addOperationOutputTest(device, model);
940 }
941 
942 }  // namespace android::hardware::neuralnetworks::V1_0::vts::functional
943