1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22
23 #include "OperationResolver.h"
24 #include "RNN.h"
25 #include "nnapi/TypeUtils.h"
26
27 namespace android {
28 namespace nn {
29 namespace unidirectional_sequence_rnn {
30
// Operand indices for the operation's 7 inputs. The first five are tensors;
// the last two are scalar INT32 parameters read via getInputValue.
constexpr uint32_t kNumInputs = 7;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kRecurrentWeightsTensor = 2;
constexpr uint32_t kBiasTensor = 3;
constexpr uint32_t kHiddenStateTensor = 4;
constexpr uint32_t kActivationParam = 5;  // scalar INT32, passed through to RNN::RNNStep
constexpr uint32_t kTimeMajorParam = 6;   // scalar INT32, must be 0 or 1 (checked in prepare)

// Operand indices for the outputs; the second (final hidden state) output is
// optional and only present when the model requests two outputs.
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kNumOutputsWithState = 2;
constexpr uint32_t kOutputTensor = 0;
constexpr uint32_t kStateOutputTensor = 1;
44
45 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
46 namespace {
47
48 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)49 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
50 const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
51 const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
52 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
53 for (int f = 0; f < firstDimSize; ++f) {
54 for (int s = 0; s < secondDimSize; ++s) {
55 for (int i = 0; i < inputSize; ++i) {
56 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
57 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
58 output[outputIndex] = input[inputIndex];
59 }
60 }
61 }
62 }
63
64 template <typename T>
executeTyped(IOperationExecutionContext * context)65 bool executeTyped(IOperationExecutionContext* context) {
66 const T* input = context->getInputBuffer<T>(kInputTensor);
67 Shape inputShape = context->getInputShape(kInputTensor);
68 const T* weights = context->getInputBuffer<T>(kWeightsTensor);
69 Shape weightsShape = context->getInputShape(kWeightsTensor);
70 const T* recurrentWeights = context->getInputBuffer<T>(kRecurrentWeightsTensor);
71 Shape recurrentWeightsShape = context->getInputShape(kRecurrentWeightsTensor);
72 const T* bias = context->getInputBuffer<T>(kBiasTensor);
73 const T* hiddenState = context->getInputBuffer<T>(kHiddenStateTensor);
74 int32_t activation = context->getInputValue<int32_t>(kActivationParam);
75
76 T* output = context->getOutputBuffer<T>(kOutputTensor);
77 Shape outputShape = context->getOutputShape(kOutputTensor);
78
79 int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
80 // If the input tensors are not in time major format, we transpose the first
81 // two dimensions, and set input and output pointers to temporary vectors
82 // which are transposed back after the RNN is applied.
83 std::vector<T> inputTransposed;
84 std::vector<T> outputTransposed;
85 if (!timeMajor) {
86 // Convert input and output to time major format.
87 inputTransposed.resize(getNumberOfElements(inputShape));
88 outputTransposed.resize(getNumberOfElements(outputShape));
89 transposeFirstTwoDims(input, inputShape, inputTransposed.data());
90 input = inputTransposed.data();
91 output = outputTransposed.data();
92 std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
93 std::swap(outputShape.dimensions[0], outputShape.dimensions[1]);
94 }
95
96 const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
97 const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
98 const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
99 const uint32_t numUnits = getSizeOfDimension(weightsShape, 0);
100
101 // A shape at a fixed step (removed time dimension).
102 Shape fixedTimeInputShape = inputShape;
103 fixedTimeInputShape.dimensions.resize(2);
104 fixedTimeInputShape.dimensions[0] = inputShape.dimensions[1];
105 fixedTimeInputShape.dimensions[1] = inputShape.dimensions[2];
106
107 for (int i = 0; i < maxTime; ++i) {
108 RNN::RNNStep<T>(input, fixedTimeInputShape, hiddenState, bias, weights, weightsShape,
109 recurrentWeights, recurrentWeightsShape, activation, output);
110 input += batchSize * inputSize;
111 hiddenState = output;
112 output += batchSize * numUnits;
113 }
114
115 if (!timeMajor) {
116 transposeFirstTwoDims(outputTransposed.data(), outputShape,
117 context->getOutputBuffer<T>(kOutputTensor));
118 }
119
120 if (context->getNumOutputs() == kNumOutputsWithState) {
121 // We checked that the state output is not omitted during preparation.
122 T* stateOutput = context->getOutputBuffer<T>(kStateOutputTensor);
123 std::copy(hiddenState, hiddenState + batchSize * numUnits, stateOutput);
124 }
125 return true;
126 }
127
128 } // namespace
129 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
130
validate(const IOperationValidationContext * context)131 Result<Version> validate(const IOperationValidationContext* context) {
132 NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
133 const int numOutputs = context->getNumOutputs();
134 NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
135 OperandType inputType = context->getInputType(kInputTensor);
136 if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
137 return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
138 << inputType;
139 }
140 NN_RET_CHECK(validateInputTypes(context, {inputType, inputType, inputType, inputType, inputType,
141 OperandType::INT32, OperandType::INT32}));
142 std::vector<OperandType> outputTypes = {inputType};
143 Version minVersionSupported = Version::ANDROID_Q;
144 if (numOutputs == kNumOutputsWithState) {
145 minVersionSupported = Version::ANDROID_R;
146 outputTypes.push_back(inputType);
147 }
148 NN_RET_CHECK(validateOutputTypes(context, outputTypes));
149 return minVersionSupported;
150 }
151
152 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
prepare(IOperationExecutionContext * context)153 bool prepare(IOperationExecutionContext* context) {
154 Shape input = context->getInputShape(kInputTensor);
155 Shape weights = context->getInputShape(kWeightsTensor);
156 Shape recurrentWeights = context->getInputShape(kRecurrentWeightsTensor);
157 Shape bias = context->getInputShape(kBiasTensor);
158 Shape hiddenState = context->getInputShape(kHiddenStateTensor);
159
160 int32_t timeMajor = context->getInputValue<int32_t>(kTimeMajorParam);
161 NN_RET_CHECK(timeMajor == 0 || timeMajor == 1);
162 const uint32_t batchSize =
163 timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
164 const uint32_t maxTime =
165 timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
166 const uint32_t numUnits = getSizeOfDimension(weights, 0);
167 const uint32_t inputSize = getSizeOfDimension(input, 2);
168
169 NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
170 NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
171 NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentWeights), 2);
172 NN_RET_CHECK_EQ(getNumberOfDimensions(bias), 1);
173 NN_RET_CHECK_EQ(getNumberOfDimensions(hiddenState), 2);
174
175 NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(weights, 1));
176 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(bias, 0));
177 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 0));
178 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(recurrentWeights, 1));
179 NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(hiddenState, 0));
180 NN_RET_CHECK_EQ(numUnits, getSizeOfDimension(hiddenState, 1));
181
182 Shape output = context->getOutputShape(kOutputTensor);
183 output.dimensions.resize(3);
184 output.dimensions[0] = timeMajor ? maxTime : batchSize;
185 output.dimensions[1] = timeMajor ? batchSize : maxTime;
186 output.dimensions[2] = numUnits;
187
188 if (context->getNumOutputs() == kNumOutputsWithState) {
189 NN_RET_CHECK(!context->isOmittedOutput(kStateOutputTensor));
190 Shape outputStateShape = context->getInputShape(kHiddenStateTensor);
191 outputStateShape.dimensions.resize(2);
192 outputStateShape.dimensions[0] = batchSize;
193 outputStateShape.dimensions[1] = numUnits;
194 NN_RET_CHECK(context->setOutputShape(kStateOutputTensor, outputStateShape));
195 }
196
197 return context->setOutputShape(kOutputTensor, output);
198 }
199
execute(IOperationExecutionContext * context)200 bool execute(IOperationExecutionContext* context) {
201 if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
202 executeTyped<_Float16>(context);
203 } else {
204 executeTyped<float>(context);
205 }
206 return true;
207 }
208 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
209
210 } // namespace unidirectional_sequence_rnn
211
// Registers the operation with the OperationResolver so the runtime can
// dispatch UNIDIRECTIONAL_SEQUENCE_RNN to the functions defined above.
NN_REGISTER_OPERATION(UNIDIRECTIONAL_SEQUENCE_RNN, "UNIDIRECTIONAL_SEQUENCE_RNN",
                      unidirectional_sequence_rnn::validate, unidirectional_sequence_rnn::prepare,
                      unidirectional_sequence_rnn::execute);
215
216 } // namespace nn
217 } // namespace android
218