1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "Conversions.h"
18 
19 #include <android-base/logging.h>
20 #include <android/hardware/neuralnetworks/1.0/types.h>
21 #include <nnapi/OperandTypes.h>
22 #include <nnapi/OperationTypes.h>
23 #include <nnapi/Result.h>
24 #include <nnapi/SharedMemory.h>
25 #include <nnapi/TypeUtils.h>
26 #include <nnapi/Types.h>
27 #include <nnapi/Validation.h>
28 #include <nnapi/hal/CommonUtils.h>
29 
30 #include <algorithm>
31 #include <functional>
32 #include <iterator>
33 #include <memory>
34 #include <type_traits>
35 #include <utility>
36 #include <variant>
37 
38 #include "Utils.h"
39 
40 namespace {
41 
// Returns the numeric value of an enumerator, typed as the enum's underlying integer type.
// Used in this file to print unrecognized enum values in error messages.
template <typename Type>
constexpr auto underlyingType(Type value) -> std::underlying_type_t<Type> {
    using Underlying = std::underlying_type_t<Type>;
    return static_cast<Underlying>(value);
}
46 
47 }  // namespace
48 
49 namespace android::nn {
50 namespace {
51 
52 using hardware::hidl_memory;
53 using hardware::hidl_vec;
54 
// The decayed type of `.value()` on the result of unvalidatedConvert(Input), i.e. the canonical
// type that a HAL object of type Input converts into.
template <typename Input>
using UnvalidatedConvertOutput = typename std::decay<
        decltype(unvalidatedConvert(std::declval<Input>()).value())>::type;
58 
59 template <typename Type>
unvalidatedConvert(const hidl_vec<Type> & arguments)60 GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
61         const hidl_vec<Type>& arguments) {
62     std::vector<UnvalidatedConvertOutput<Type>> canonical;
63     canonical.reserve(arguments.size());
64     for (const auto& argument : arguments) {
65         canonical.push_back(NN_TRY(nn::unvalidatedConvert(argument)));
66     }
67     return canonical;
68 }
69 
70 template <typename Type>
validatedConvert(const Type & halObject)71 GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& halObject) {
72     auto canonical = NN_TRY(nn::unvalidatedConvert(halObject));
73     NN_TRY(hal::V1_0::utils::compliantVersion(canonical));
74     return canonical;
75 }
76 
77 }  // anonymous namespace
78 
unvalidatedConvert(const hal::V1_0::OperandType & operandType)79 GeneralResult<OperandType> unvalidatedConvert(const hal::V1_0::OperandType& operandType) {
80     return static_cast<OperandType>(operandType);
81 }
82 
unvalidatedConvert(const hal::V1_0::OperationType & operationType)83 GeneralResult<OperationType> unvalidatedConvert(const hal::V1_0::OperationType& operationType) {
84     return static_cast<OperationType>(operationType);
85 }
86 
unvalidatedConvert(const hal::V1_0::OperandLifeTime & lifetime)87 GeneralResult<Operand::LifeTime> unvalidatedConvert(const hal::V1_0::OperandLifeTime& lifetime) {
88     return static_cast<Operand::LifeTime>(lifetime);
89 }
90 
unvalidatedConvert(const hal::V1_0::DeviceStatus & deviceStatus)91 GeneralResult<DeviceStatus> unvalidatedConvert(const hal::V1_0::DeviceStatus& deviceStatus) {
92     return static_cast<DeviceStatus>(deviceStatus);
93 }
94 
unvalidatedConvert(const hal::V1_0::PerformanceInfo & performanceInfo)95 GeneralResult<Capabilities::PerformanceInfo> unvalidatedConvert(
96         const hal::V1_0::PerformanceInfo& performanceInfo) {
97     return Capabilities::PerformanceInfo{
98             .execTime = performanceInfo.execTime,
99             .powerUsage = performanceInfo.powerUsage,
100     };
101 }
102 
unvalidatedConvert(const hal::V1_0::Capabilities & capabilities)103 GeneralResult<Capabilities> unvalidatedConvert(const hal::V1_0::Capabilities& capabilities) {
104     const auto quantized8Performance =
105             NN_TRY(unvalidatedConvert(capabilities.quantized8Performance));
106     const auto float32Performance = NN_TRY(unvalidatedConvert(capabilities.float32Performance));
107 
108     auto table = hal::utils::makeQuantized8PerformanceConsistentWithP(float32Performance,
109                                                                       quantized8Performance);
110 
111     return Capabilities{
112             .relaxedFloat32toFloat16PerformanceScalar = float32Performance,
113             .relaxedFloat32toFloat16PerformanceTensor = float32Performance,
114             .operandPerformance = std::move(table),
115     };
116 }
117 
unvalidatedConvert(const hal::V1_0::DataLocation & location)118 GeneralResult<DataLocation> unvalidatedConvert(const hal::V1_0::DataLocation& location) {
119     return DataLocation{
120             .poolIndex = location.poolIndex,
121             .offset = location.offset,
122             .length = location.length,
123     };
124 }
125 
unvalidatedConvert(const hal::V1_0::Operand & operand)126 GeneralResult<Operand> unvalidatedConvert(const hal::V1_0::Operand& operand) {
127     return Operand{
128             .type = NN_TRY(unvalidatedConvert(operand.type)),
129             .dimensions = operand.dimensions,
130             .scale = operand.scale,
131             .zeroPoint = operand.zeroPoint,
132             .lifetime = NN_TRY(unvalidatedConvert(operand.lifetime)),
133             .location = NN_TRY(unvalidatedConvert(operand.location)),
134     };
135 }
136 
unvalidatedConvert(const hal::V1_0::Operation & operation)137 GeneralResult<Operation> unvalidatedConvert(const hal::V1_0::Operation& operation) {
138     return Operation{
139             .type = NN_TRY(unvalidatedConvert(operation.type)),
140             .inputs = operation.inputs,
141             .outputs = operation.outputs,
142     };
143 }
144 
unvalidatedConvert(const hidl_vec<uint8_t> & operandValues)145 GeneralResult<Model::OperandValues> unvalidatedConvert(const hidl_vec<uint8_t>& operandValues) {
146     return Model::OperandValues(operandValues.data(), operandValues.size());
147 }
148 
unvalidatedConvert(const hidl_memory & memory)149 GeneralResult<SharedMemory> unvalidatedConvert(const hidl_memory& memory) {
150     return hal::utils::createSharedMemoryFromHidlMemory(memory);
151 }
152 
unvalidatedConvert(const hal::V1_0::Model & model)153 GeneralResult<Model> unvalidatedConvert(const hal::V1_0::Model& model) {
154     auto operations = NN_TRY(unvalidatedConvert(model.operations));
155 
156     // Verify number of consumers.
157     const auto numberOfConsumers =
158             NN_TRY(hal::utils::countNumberOfConsumers(model.operands.size(), operations));
159     CHECK(model.operands.size() == numberOfConsumers.size());
160     for (size_t i = 0; i < model.operands.size(); ++i) {
161         if (model.operands[i].numberOfConsumers != numberOfConsumers[i]) {
162             return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
163                    << "Invalid numberOfConsumers for operand " << i << ", expected "
164                    << numberOfConsumers[i] << " but found " << model.operands[i].numberOfConsumers;
165         }
166     }
167 
168     auto main = Model::Subgraph{
169             .operands = NN_TRY(unvalidatedConvert(model.operands)),
170             .operations = std::move(operations),
171             .inputIndexes = model.inputIndexes,
172             .outputIndexes = model.outputIndexes,
173     };
174 
175     return Model{
176             .main = std::move(main),
177             .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
178             .pools = NN_TRY(unvalidatedConvert(model.pools)),
179     };
180 }
181 
unvalidatedConvert(const hal::V1_0::RequestArgument & argument)182 GeneralResult<Request::Argument> unvalidatedConvert(const hal::V1_0::RequestArgument& argument) {
183     const auto lifetime = argument.hasNoValue ? Request::Argument::LifeTime::NO_VALUE
184                                               : Request::Argument::LifeTime::POOL;
185     return Request::Argument{
186             .lifetime = lifetime,
187             .location = NN_TRY(unvalidatedConvert(argument.location)),
188             .dimensions = argument.dimensions,
189     };
190 }
191 
unvalidatedConvert(const hal::V1_0::Request & request)192 GeneralResult<Request> unvalidatedConvert(const hal::V1_0::Request& request) {
193     auto memories = NN_TRY(unvalidatedConvert(request.pools));
194     std::vector<Request::MemoryPool> pools;
195     pools.reserve(memories.size());
196     std::move(memories.begin(), memories.end(), std::back_inserter(pools));
197 
198     return Request{
199             .inputs = NN_TRY(unvalidatedConvert(request.inputs)),
200             .outputs = NN_TRY(unvalidatedConvert(request.outputs)),
201             .pools = std::move(pools),
202     };
203 }
204 
unvalidatedConvert(const hal::V1_0::ErrorStatus & status)205 GeneralResult<ErrorStatus> unvalidatedConvert(const hal::V1_0::ErrorStatus& status) {
206     switch (status) {
207         case hal::V1_0::ErrorStatus::NONE:
208         case hal::V1_0::ErrorStatus::DEVICE_UNAVAILABLE:
209         case hal::V1_0::ErrorStatus::GENERAL_FAILURE:
210         case hal::V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
211         case hal::V1_0::ErrorStatus::INVALID_ARGUMENT:
212             return static_cast<ErrorStatus>(status);
213     }
214     return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
215            << "Invalid ErrorStatus " << underlyingType(status);
216 }
217 
convert(const hal::V1_0::DeviceStatus & deviceStatus)218 GeneralResult<DeviceStatus> convert(const hal::V1_0::DeviceStatus& deviceStatus) {
219     return validatedConvert(deviceStatus);
220 }
221 
convert(const hal::V1_0::Capabilities & capabilities)222 GeneralResult<Capabilities> convert(const hal::V1_0::Capabilities& capabilities) {
223     return validatedConvert(capabilities);
224 }
225 
convert(const hal::V1_0::Model & model)226 GeneralResult<Model> convert(const hal::V1_0::Model& model) {
227     return validatedConvert(model);
228 }
229 
convert(const hal::V1_0::Request & request)230 GeneralResult<Request> convert(const hal::V1_0::Request& request) {
231     return validatedConvert(request);
232 }
233 
convert(const hal::V1_0::ErrorStatus & status)234 GeneralResult<ErrorStatus> convert(const hal::V1_0::ErrorStatus& status) {
235     return validatedConvert(status);
236 }
237 
238 }  // namespace android::nn
239 
240 namespace android::hardware::neuralnetworks::V1_0::utils {
241 namespace {
242 
// The decayed type of `.value()` on the result of unvalidatedConvert(Input), i.e. the V1_0 HAL
// type that a canonical object of type Input converts into.
template <typename Input>
using UnvalidatedConvertOutput = typename std::decay<
        decltype(unvalidatedConvert(std::declval<Input>()).value())>::type;
246 
247 template <typename Type>
unvalidatedConvert(const std::vector<Type> & arguments)248 nn::GeneralResult<hidl_vec<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
249         const std::vector<Type>& arguments) {
250     hidl_vec<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
251     for (size_t i = 0; i < arguments.size(); ++i) {
252         halObject[i] = NN_TRY(utils::unvalidatedConvert(arguments[i]));
253     }
254     return halObject;
255 }
256 
257 template <typename Type>
validatedConvert(const Type & canonical)258 nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
259     NN_TRY(compliantVersion(canonical));
260     return utils::unvalidatedConvert(canonical);
261 }
262 
263 }  // anonymous namespace
264 
// Reinterprets a canonical operand type as the V1_0 enum. This assumes matching numeric values and
// does not check range.
nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
    const auto halType = static_cast<OperandType>(operandType);
    return halType;
}
268 
// Reinterprets a canonical operation type as the V1_0 enum. This assumes matching numeric values
// and does not check range.
nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
    const auto halType = static_cast<OperationType>(operationType);
    return halType;
}
272 
// Converts a canonical operand lifetime to the V1_0 enum. POINTER lifetimes are rejected with
// INVALID_ARGUMENT. Other values are cast directly, which assumes matching numeric values.
nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& lifetime) {
    const bool isPointer = lifetime == nn::Operand::LifeTime::POINTER;
    if (!isPointer) {
        return static_cast<OperandLifeTime>(lifetime);
    }
    return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
           << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
}
280 
// Reinterprets a canonical device status as the V1_0 enum. This assumes matching numeric values
// and does not check range.
nn::GeneralResult<DeviceStatus> unvalidatedConvert(const nn::DeviceStatus& deviceStatus) {
    const auto halStatus = static_cast<DeviceStatus>(deviceStatus);
    return halStatus;
}
284 
// Copies the execution-time and power-usage figures into a V1_0 PerformanceInfo.
nn::GeneralResult<PerformanceInfo> unvalidatedConvert(
        const nn::Capabilities::PerformanceInfo& performanceInfo) {
    PerformanceInfo halInfo{};
    halInfo.execTime = performanceInfo.execTime;
    halInfo.powerUsage = performanceInfo.powerUsage;
    return halInfo;
}
292 
// Converts canonical Capabilities to V1_0. Only the TENSOR_FLOAT32 and TENSOR_QUANT8_ASYMM
// entries of the operand performance table are looked up; everything else is dropped.
nn::GeneralResult<Capabilities> unvalidatedConvert(const nn::Capabilities& capabilities) {
    const auto& table = capabilities.operandPerformance;
    auto float32 = NN_TRY(unvalidatedConvert(table.lookup(nn::OperandType::TENSOR_FLOAT32)));
    auto quant8 = NN_TRY(unvalidatedConvert(table.lookup(nn::OperandType::TENSOR_QUANT8_ASYMM)));
    return Capabilities{
            .float32Performance = float32,
            .quantized8Performance = quant8,
    };
}
301 
// Copies pool index, offset and length into a V1_0 DataLocation. Other canonical fields are not
// read.
nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
    DataLocation halLocation{};
    halLocation.poolIndex = location.poolIndex;
    halLocation.offset = location.offset;
    halLocation.length = location.length;
    return halLocation;
}
309 
// Converts a canonical operand to V1_0. numberOfConsumers is set to 0 as a placeholder; the model
// conversion in this file overwrites it with the computed count.
nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
    auto halType = NN_TRY(unvalidatedConvert(operand.type));
    auto halLifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
    auto halLocation = NN_TRY(unvalidatedConvert(operand.location));

    return Operand{
            .type = halType,
            .dimensions = operand.dimensions,
            .numberOfConsumers = 0,
            .scale = operand.scale,
            .zeroPoint = operand.zeroPoint,
            .lifetime = halLifetime,
            .location = std::move(halLocation),
    };
}
321 
// Converts a canonical operation by mapping its type and copying its input and output indexes.
nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
    auto halType = NN_TRY(unvalidatedConvert(operation.type));
    return Operation{
            .type = halType,
            .inputs = operation.inputs,
            .outputs = operation.outputs,
    };
}
329 
// Copies the canonical constant-operand byte buffer into a newly allocated hidl_vec.
nn::GeneralResult<hidl_vec<uint8_t>> unvalidatedConvert(
        const nn::Model::OperandValues& operandValues) {
    const auto* const first = operandValues.data();
    const auto* const last = first + operandValues.size();
    return hidl_vec<uint8_t>(first, last);
}
334 
// Delegates to hal::utils::createHidlMemoryFromSharedMemory and returns its result unchanged.
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::SharedMemory& memory) {
    auto halMemory = hal::utils::createHidlMemoryFromSharedMemory(memory);
    return halMemory;
}
338 
// Converts a canonical model to a V1_0 HAL model.
//
// Fails with INVALID_ARGUMENT if hal::utils::hasNoPointerData reports pointer-based data.
// Only model.main is read; other subgraphs, if present, are not converted here. V1_0 operands
// carry an explicit numberOfConsumers, which is recomputed from the main subgraph's operations and
// written into each converted operand.
nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
    if (!hal::utils::hasNoPointerData(model)) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Model cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto operands = NN_TRY(unvalidatedConvert(model.main.operands));

    // Overwrite the placeholder consumer counts with the values derived from the operations.
    const auto numberOfConsumers =
            NN_TRY(hal::utils::countNumberOfConsumers(operands.size(), model.main.operations));
    CHECK(operands.size() == numberOfConsumers.size());
    for (size_t i = 0; i < operands.size(); ++i) {
        operands[i].numberOfConsumers = numberOfConsumers[i];
    }

    return Model{
            .operands = std::move(operands),
            .operations = NN_TRY(unvalidatedConvert(model.main.operations)),
            .inputIndexes = model.main.inputIndexes,
            .outputIndexes = model.main.outputIndexes,
            .operandValues = NN_TRY(unvalidatedConvert(model.operandValues)),
            .pools = NN_TRY(unvalidatedConvert(model.pools)),
    };
}
364 
// Converts a canonical request argument to V1_0. POINTER lifetimes are rejected with
// INVALID_ARGUMENT. A NO_VALUE lifetime sets hasNoValue, and any other lifetime clears it.
nn::GeneralResult<RequestArgument> unvalidatedConvert(
        const nn::Request::Argument& requestArgument) {
    using LifeTime = nn::Request::Argument::LifeTime;
    const LifeTime lifetime = requestArgument.lifetime;
    if (lifetime == LifeTime::POINTER) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto halLocation = NN_TRY(unvalidatedConvert(requestArgument.location));
    return RequestArgument{
            .hasNoValue = (lifetime == LifeTime::NO_VALUE),
            .location = std::move(halLocation),
            .dimensions = requestArgument.dimensions,
    };
}
378 
// Converts a canonical request memory pool to hidl_memory.
//
// V1_0 can only express pools backed by nn::SharedMemory. Any other alternative held by the pool
// variant is rejected with INVALID_ARGUMENT. Previously std::get was used, which throws
// std::bad_variant_access (or aborts in builds without exceptions) when called directly on such a
// pool, without going through the validated convert() path.
nn::GeneralResult<hidl_memory> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
    const auto* const sharedMemory = std::get_if<nn::SharedMemory>(&memoryPool);
    if (sharedMemory == nullptr) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains a memory pool that "
                  "is not SharedMemory";
    }
    return unvalidatedConvert(*sharedMemory);
}
382 
// Converts a canonical request to V1_0. Fails with INVALID_ARGUMENT when
// hal::utils::hasNoPointerData reports pointer-based data. Otherwise it converts inputs, outputs
// and pools in that order.
nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
    const bool containsPointerData = !hal::utils::hasNoPointerData(request);
    if (containsPointerData) {
        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
    }

    auto halInputs = NN_TRY(unvalidatedConvert(request.inputs));
    auto halOutputs = NN_TRY(unvalidatedConvert(request.outputs));
    auto halPools = NN_TRY(unvalidatedConvert(request.pools));
    return Request{
            .inputs = std::move(halInputs),
            .outputs = std::move(halOutputs),
            .pools = std::move(halPools),
    };
}
395 
// Maps a canonical error status to V1_0. The five listed statuses are cast directly, which assumes
// matching numeric values. Every other canonical status collapses to GENERAL_FAILURE, and this
// function never returns an error.
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& status) {
    switch (status) {
        case nn::ErrorStatus::NONE:
        case nn::ErrorStatus::DEVICE_UNAVAILABLE:
        case nn::ErrorStatus::GENERAL_FAILURE:
        case nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
        case nn::ErrorStatus::INVALID_ARGUMENT:
            return static_cast<ErrorStatus>(status);
        default:
            break;
    }
    return ErrorStatus::GENERAL_FAILURE;
}
408 
// Validated canonical -> V1_0 entry points. Each one forwards to validatedConvert, which runs the
// compliantVersion check before the unvalidated conversion.
nn::GeneralResult<DeviceStatus> convert(const nn::DeviceStatus& deviceStatus) {
    return validatedConvert<nn::DeviceStatus>(deviceStatus);
}

nn::GeneralResult<Capabilities> convert(const nn::Capabilities& capabilities) {
    return validatedConvert<nn::Capabilities>(capabilities);
}

nn::GeneralResult<Model> convert(const nn::Model& model) {
    return validatedConvert<nn::Model>(model);
}

nn::GeneralResult<Request> convert(const nn::Request& request) {
    return validatedConvert<nn::Request>(request);
}

nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& status) {
    return validatedConvert<nn::ErrorStatus>(status);
}
428 
429 }  // namespace android::hardware::neuralnetworks::V1_0::utils
430