1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
#include <cmath>
#include <cstddef>
#include <vector>
21 
22 #include "OperationResolver.h"
23 #include "Tracing.h"
24 
25 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
26 #include "CpuOperationUtils.h"
27 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
28 
29 namespace android {
30 namespace nn {
31 namespace instance_normalization {
32 
// Operation name used for registration and in error messages.
constexpr char kOperationName[] = "INSTANCE_NORMALIZATION";

// Input operand layout: kNumInputs operands, addressed by the indices below.
constexpr uint32_t kNumInputs{5};
constexpr uint32_t kInputTensor{0};    // tensor to normalize
constexpr uint32_t kGammaScalar{1};    // scale applied to the normalized values
constexpr uint32_t kBetaScalar{2};     // offset added after scaling
constexpr uint32_t kEpsilonScalar{3};  // added to the variance before the square root
constexpr uint32_t kLayoutScalar{4};   // bool layout flag, passed on as `useNchw`

// Output operand layout.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
44 
45 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
46 namespace {
47 
48 template <typename T>
instanceNormNhwc(const T * inputData,const Shape & inputShape,T gamma,T beta,T epsilon,T * outputData,const Shape & outputShape)49 inline bool instanceNormNhwc(const T* inputData, const Shape& inputShape, T gamma, T beta,
50                              T epsilon, T* outputData, const Shape& outputShape) {
51     NNTRACE_TRANS("InstanceNormalizationNhwc");
52     uint32_t numBatches = getSizeOfDimension(inputShape, 0);
53     uint32_t height = getSizeOfDimension(inputShape, 1);
54     uint32_t width = getSizeOfDimension(inputShape, 2);
55     uint32_t depth = getSizeOfDimension(inputShape, 3);
56     for (uint32_t b = 0; b < numBatches; b++) {
57         for (uint32_t d = 0; d < depth; d++) {
58             uint32_t indexBase = b * height * width * depth + d;
59             T mean = 0, sigma = 0;
60 
61             // Compute the mean of a single layer.
62             for (uint32_t h = 0; h < height; h++) {
63                 for (uint32_t w = 0; w < width; w++) {
64                     T val = inputData[indexBase + (h * width + w) * depth];
65                     mean += val;
66                 }
67             }
68             mean /= static_cast<T>(height * width);
69 
70             // Compute the standard deviation (sigma) of a single layer.
71             for (uint32_t h = 0; h < height; h++) {
72                 for (uint32_t w = 0; w < width; w++) {
73                     T val = inputData[indexBase + (h * width + w) * depth] - mean;
74                     sigma += val * val;
75                 }
76             }
77             sigma = std::sqrt(static_cast<float>(sigma / static_cast<T>(height * width)) + epsilon);
78 
79             // Apply instance normalization.
80             for (uint32_t h = 0; h < height; h++) {
81                 for (uint32_t w = 0; w < width; w++) {
82                     uint32_t ind = indexBase + (h * width + w) * depth;
83                     outputData[ind] = (inputData[ind] - mean) * gamma / sigma + beta;
84                 }
85             }
86         }
87     }
88     return true;
89 }
90 
// Instance normalization for an input stored in either NHWC or NCHW layout.
//
// InputWithLayout / OutputWithLayout wrap the caller's buffers and expose NHWC views
// (getNhwcBuffer / getNhwcShape). The computation itself is therefore always done by
// instanceNormNhwc.
// NOTE(review): presumably these helpers transpose to/from NCHW when useNchw is
// true; confirm in CpuOperationUtils.h.
//
// Returns false as soon as any initialize / compute / commit step fails, and true
// otherwise.
template <typename T>
inline bool instanceNorm(const T* inputData, const Shape& inputShape, T gamma, T beta, T epsilon,
                         bool useNchw, T* outputData, const Shape& outputShape) {
    InputWithLayout<T> input(useNchw);
    OutputWithLayout<T> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(instanceNormNhwc(input.getNhwcBuffer(), input.getNhwcShape(), gamma, beta, epsilon,
                                  output.getNhwcBuffer(), output.getNhwcShape()));
    // commit() runs last, after the NHWC result has been produced.
    // NOTE(review): presumably it copies that result back into outputData in the
    // caller's layout; confirm in CpuOperationUtils.h.
    NN_RET_CHECK(output.commit());
    return true;
}
103 
104 }  // namespace
105 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
106 
validate(const IOperationValidationContext * context)107 Result<Version> validate(const IOperationValidationContext* context) {
108     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
109     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
110     std::vector<OperandType> inExpectedTypes;
111     auto inputType = context->getInputType(kInputTensor);
112     if (inputType == OperandType::TENSOR_FLOAT32) {
113         inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::FLOAT32, OperandType::FLOAT32,
114                            OperandType::FLOAT32, OperandType::BOOL};
115     } else if (inputType == OperandType::TENSOR_FLOAT16) {
116         inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::FLOAT16, OperandType::FLOAT16,
117                            OperandType::FLOAT16, OperandType::BOOL};
118     } else {
119         return NN_ERROR() << "Unsupported input tensor type for operation " << kOperationName;
120     }
121     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
122     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
123     return Version::ANDROID_Q;
124 }
125 
126 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
// Shape inference. Requires a rank-4 input and gives the output exactly the input's
// shape. Returns false if the rank check fails; otherwise returns the result of
// setOutputShape.
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    // Only 4-D tensors are handled. The layout operand selects NHWC or NCHW at execution.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    return context->setOutputShape(kOutputTensor, input);
}
132 
execute(IOperationExecutionContext * context)133 bool execute(IOperationExecutionContext* context) {
134     switch (context->getInputType(kInputTensor)) {
135         case OperandType::TENSOR_FLOAT16:
136             return instanceNorm(context->getInputBuffer<_Float16>(kInputTensor),
137                                 context->getInputShape(kInputTensor),
138                                 context->getInputValue<_Float16>(kGammaScalar),
139                                 context->getInputValue<_Float16>(kBetaScalar),
140                                 context->getInputValue<_Float16>(kEpsilonScalar),
141                                 context->getInputValue<bool>(kLayoutScalar),
142                                 context->getOutputBuffer<_Float16>(kOutputTensor),
143                                 context->getOutputShape(kOutputTensor));
144         case OperandType::TENSOR_FLOAT32:
145             return instanceNorm(context->getInputBuffer<float>(kInputTensor),
146                                 context->getInputShape(kInputTensor),
147                                 context->getInputValue<float>(kGammaScalar),
148                                 context->getInputValue<float>(kBetaScalar),
149                                 context->getInputValue<float>(kEpsilonScalar),
150                                 context->getInputValue<bool>(kLayoutScalar),
151                                 context->getOutputBuffer<float>(kOutputTensor),
152                                 context->getOutputShape(kOutputTensor));
153         default:
154             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
155     }
156 }
157 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
158 
159 }  // namespace instance_normalization
160 
// Registers validate / prepare / execute as the handlers for INSTANCE_NORMALIZATION.
// prepare and execute are only defined when NN_INCLUDE_CPU_IMPLEMENTATION is set.
// NOTE(review): presumably the macro ignores them otherwise; confirm against
// OperationResolver.h.
NN_REGISTER_OPERATION(INSTANCE_NORMALIZATION, instance_normalization::kOperationName,
                      instance_normalization::validate, instance_normalization::prepare,
                      instance_normalization::execute);
164 
165 }  // namespace nn
166 }  // namespace android
167