/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17 #define LOG_TAG "Operations"
18
19 #include <cmath>
20 #include <vector>
21
22 #include "OperationResolver.h"
23 #include "Tracing.h"
24
25 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
26 #include "CpuOperationUtils.h"
27 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
28
29 namespace android {
30 namespace nn {
31 namespace instance_normalization {
32
constexpr char kOperationName[] = "INSTANCE_NORMALIZATION";

// Operand indices for the operation's inputs.
constexpr uint32_t kNumInputs = 5;
constexpr uint32_t kInputTensor = 0;    // 4-D tensor to normalize (NHWC or NCHW).
constexpr uint32_t kGammaScalar = 1;    // Scale multiplied into the normalized value.
constexpr uint32_t kBetaScalar = 2;     // Offset added after scaling.
constexpr uint32_t kEpsilonScalar = 3;  // Small constant added to the variance before sqrt.
constexpr uint32_t kLayoutScalar = 4;   // Layout flag: true = NCHW, false = NHWC.

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
44
45 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
46 namespace {
47
48 template <typename T>
instanceNormNhwc(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,T * outputData,const Shape & outputShape)49 inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta,
50 T epsilon, T* outputData, const Shape& outputShape) {
51 NNTRACE_TRANS("InstanceNormalizationNhwc");
52 uint32_t numBatches = getSizeOfDimension(inputShape, 0);
53 uint32_t height = getSizeOfDimension(inputShape, 1);
54 uint32_t width = getSizeOfDimension(inputShape, 2);
55 uint32_t depth = getSizeOfDimension(inputShape, 3);
56 for (uint32_t b = 0; b < numBatches; b++) {
57 for (uint32_t d = 0; d < depth; d++) {
58 uint32_t indexBase = b * height * width * depth + d;
59 T mean = 0, sigma = 0;
60
61 // Compute the mean of a single layer.
62 for (uint32_t h = 0; h < height; h++) {
63 for (uint32_t w = 0; w < width; w++) {
64 T val = inputData[indexBase + (h * width + w) * depth];
65 mean += val;
66 }
67 }
68 mean /= static_cast<T>(height * width);
69
70 // Compute the standard deviation (sigma) of a single layer.
71 for (uint32_t h = 0; h < height; h++) {
72 for (uint32_t w = 0; w < width; w++) {
73 T val = inputData[indexBase + (h * width + w) * depth] - mean;
74 sigma += val * val;
75 }
76 }
77 sigma = std::sqrt(static_cast<float>(sigma / static_cast<T>(height * width)) + epsilon);
78
79 // Apply instance normalization.
80 for (uint32_t h = 0; h < height; h++) {
81 for (uint32_t w = 0; w < width; w++) {
82 uint32_t ind = indexBase + (h * width + w) * depth;
83 outputData[ind] = (inputData[ind] - mean) * gamma / sigma + beta;
84 }
85 }
86 }
87 }
88 return true;
89 }
90
91 template <typename T>
instanceNorm(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,bool useNchw,T * outputData,const Shape & outputShape)92 inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon,
93 bool useNchw, T* outputData, const Shape& outputShape) {
94 InputWithLayout<T> input(useNchw);
95 OutputWithLayout<T> output(useNchw);
96 NN_RET_CHECK(input.initialize(inputData, inputShape));
97 NN_RET_CHECK(output.initialize(outputData, outputShape));
98 NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon,
99 output.getNhwcBuffer(), output.getNhwcShape()));
100 NN_RET_CHECK(output.commit());
101 return true;
102 }
103
104 } // namespace
105 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
106
validate(const IOperationValidationContext * context)107 Result<Version> validate(const IOperationValidationContext* context) {
108 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
109 NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
110 std::vector<OperandType> inExpectedTypes;
111 auto inputType = context->getInputType(kInputTensor);
112 if (inputType == OperandType::TENSOR_FLOAT32) {
113 inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32,
114 OperandType::FLOAT32, OperandType::BOOL};
115 } else if (inputType == OperandType::TENSOR_FLOAT16) {
116 inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16,
117 OperandType::FLOAT16, OperandType::BOOL};
118 } else {
119 return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
120 }
121 NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
122 NN_RET_CHECK(validateOutputTypes(context, {inputType}));
123 return Version::ANDROID_Q;
124 }
125
126 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
prepare(IOperationExecutionContext * context)127 bool prepare(IOperationExecutionContext* context) {
128 Shape input = context->getInputShape(kInputTensor);
129 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
130 return context->setOutputShape(kOutputTensor, input);
131 }
132
execute(IOperationExecutionContext * context)133 bool execute(IOperationExecutionContext* context) {
134 switch (context->getInputType(kInputTensor)) {
135 case OperandType::TENSOR_FLOAT16:
136 return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor),
137 context->getInputShape(kInputTensor),
138 context->getInputValue<_Float16>(kGammaScalar),
139 context->getInputValue<_Float16>(kBetaScalar),
140 context->getInputValue<_Float16>(kEpsilonScalar),
141 context->getInputValue<bool>(kLayoutScalar),
142 context->getOutputBuffer<_Float16>(kOutputTensor),
143 context->getOutputShape(kOutputTensor));
144 case OperandType::TENSOR_FLOAT32:
145 return instanceNorm(context->getInputBuffer<float>(kInputTensor),
146 context->getInputShape(kInputTensor),
147 context->getInputValue<float>(kGammaScalar),
148 context->getInputValue<float>(kBetaScalar),
149 context->getInputValue<float>(kEpsilonScalar),
150 context->getInputValue<bool>(kLayoutScalar),
151 context->getOutputBuffer<float>(kOutputTensor),
152 context->getOutputShape(kOutputTensor));
153 default:
154 NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
155 }
156 }
157 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
158
159 } // namespace instance_normalization
160
// Registers INSTANCE_NORMALIZATION with the operation resolver, wiring up the
// validate/prepare/execute entry points defined above.
NN_REGISTER_OPERATION(INSTANCE_NORMALIZATION, instance_normalization::kOperationName,
                      instance_normalization::validate, instance_normalization::prepare,
                      instance_normalization::execute);
164
165 } // namespace nn
166 } // namespace android
167