1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include <algorithm>
20 #include <utility>
21 #include <vector>
22 
23 #include "OperationResolver.h"
24 #include "RNN.h"
25 
26 namespace android {
27 namespace nn {
28 namespace bidirectional_sequence_rnn {
29 
// Input operand indices (see the BIDIRECTIONAL_SEQUENCE_RNN spec).
constexpr uint32_t kNumInputs = 15;
constexpr uint32_t kInputTensor = 0;
// Forward cell tensors
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;
// Backward cell tensors
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;
// Auxiliary inputs
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional
// Cell parameters
constexpr uint32_t kActivationParam = 12;
constexpr uint32_t kTimeMajorParam = 13;
constexpr uint32_t kMergeOutputsParam = 14;

// Valid output counts: the fw/bw outputs may be merged into a single tensor,
// and the final hidden states may optionally be produced as extra outputs.
constexpr uint32_t kNumOutputs = 2;
constexpr uint32_t kNumOutputsMerged = 1;
constexpr uint32_t kNumOutputsWithState = 4;
constexpr uint32_t kNumOutputsMergedWithState = 3;

constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
// When mergeOutputs is true the hidden-state output indices below are shifted
// down by one (see prepare() and executeTyped()).
constexpr uint32_t kFwOutputHiddenStateTensor = 2;
constexpr uint32_t kBwOutputHiddenStateTensor = 3;
60 
61 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
62 namespace {
63 
64 template <typename T>
transposeFirstTwoDims(const T * input,const Shape & inputShape,T * output)65 void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
66     const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
67     const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
68     const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
69     for (int f = 0; f < firstDimSize; ++f) {
70         for (int s = 0; s < secondDimSize; ++s) {
71             for (int i = 0; i < inputSize; ++i) {
72                 const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
73                 const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
74                 output[outputIndex] = input[inputIndex];
75             }
76         }
77     }
78 }
79 
// Returns a copy of "input" with the leading (outermost) dimension dropped,
// e.g. [maxTime, batchSize, inputSize] -> [batchSize, inputSize]. All other
// Shape fields (type, scale, offset) are preserved by the copy.
// "input" must have at least one dimension.
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    // assign() replaces the resize-then-copy loop of the original, which
    // iterated with a signed index against size()-1 (and would underflow for
    // an empty dimensions vector).
    output.dimensions.assign(input.dimensions.begin() + 1, input.dimensions.end());
    return output;
}
88 
// How the optional auxiliary input is consumed by the two RNN cells
// (see getLinkingMode() for how the mode is derived from which optional
// operands are present).
enum class LinkingMode {
    NO_LINKING,        // no auxiliary input or auxiliary weights
    PARALLEL_LINKING,  // auxiliary input feeds the backward cell directly
    CROSS_LINKING,     // auxiliary input is multiplied by auxiliary weights
};
94 
getLinkingMode(IOperationExecutionContext * context,LinkingMode * linkingMode)95 bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
96     const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor);
97     const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor);
98     const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor);
99 
100     // Three possible configurations for three possible linking modes:
101     // 1) NO_LINKING -- no auxiliary tensors at all
102     // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
103     //    input to the backward network, so the auxiliary weights are omitted.
104     // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
105     //    auxiliary weights.
106     if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
107         *linkingMode = LinkingMode::NO_LINKING;
108     } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
109         *linkingMode = LinkingMode::PARALLEL_LINKING;
110     } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) {
111         *linkingMode = LinkingMode::CROSS_LINKING;
112     } else {
113         NN_RET_CHECK_FAIL()
114                 << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
115     }
116 
117     return true;
118 }
119 
// Runs the bidirectional RNN over the whole sequence for element type T
// (float or _Float16): optionally transposes batch-major data into
// time-major temporaries, runs the forward cell over t = 0..maxTime-1 and
// the backward cell over t = maxTime-1..0 one RNNStep at a time, then
// transposes results back if needed. Returns false only if the auxiliary
// tensor configuration is invalid.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    // Forward cell weights, bias and initial hidden state.
    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    // Backward cell weights, bias and initial hidden state.
    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary buffers are only fetched for the linking modes that use them;
    // they stay null otherwise.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    // NOTE: these shapes are only meaningful when the corresponding optional
    // operand is actually present.
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all buffers and shapes are time-major:
    // [maxTime, batchSize, featureSize].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // Per-timestep shapes with the time dimension removed, as consumed by
    // RNN::RNNStep.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // In PARALLEL_LINKING mode the auxiliary input is consumed as the regular
    // input of the backward cell, and auxInput is cleared so the loops below
    // never pass it as an auxiliary operand.
    // NOTE(review): the backward loop strides bwInput by inputSize and passes
    // fixedTimeInputShape for it; this assumes the auxiliary input's feature
    // size equals the primary input's in PARALLEL_LINKING mode -- confirm
    // against the op specification.
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        // With merged outputs the hidden-state output indices shift down by
        // one (kBwOutputTensor is absent).
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        // One scratch buffer serves both passes: they run sequentially and
        // never need their state at the same time.
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    for (int i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // With merged outputs the forward results occupy the first fwNumUnits
        // columns of each (fwNumUnits + bwNumUnits)-wide row.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // The state produced by this step feeds the next step.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Backward results fill the trailing bwNumUnits columns of the
            // merged rows, next to the forward results.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
314 
315 }  // namespace
316 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
317 
validate(const IOperationValidationContext * context)318 Result<Version> validate(const IOperationValidationContext* context) {
319     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
320     // Exact number is dependent on the mergeOutputs parameter and checked
321     // during preparation.
322     const uint32_t numOutputs = context->getNumOutputs();
323     NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsMerged ||
324                  numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
325 
326     OperandType inputType = context->getInputType(kInputTensor);
327     if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
328         return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: "
329                           << inputType;
330     }
331     NN_RET_CHECK(validateInputTypes(
332             context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
333                       inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
334                       OperandType::BOOL, OperandType::BOOL}));
335 
336     std::vector<OperandType> outExpectedTypes(numOutputs, inputType);
337     NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
338 
339     Version minSupportedVersion = Version::ANDROID_Q;
340     if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) {
341         minSupportedVersion = Version::ANDROID_R;
342     }
343     return minSupportedVersion;
344 }
345 
346 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
prepare(IOperationExecutionContext * context)347 bool prepare(IOperationExecutionContext* context) {
348     const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
349     const int32_t numOutputs = context->getNumOutputs();
350     if (mergeOutputs) {
351         NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
352     } else {
353         NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
354     }
355 
356     // Check that none of the required inputs are omitted.
357     const std::vector<int> requiredInputs = {
358             kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
359             kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
360             kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
361     };
362     for (const int requiredInput : requiredInputs) {
363         NN_RET_CHECK(!context->isOmittedInput(requiredInput))
364                 << "required input " << requiredInput << " is omitted";
365     }
366 
367     Shape input = context->getInputShape(kInputTensor);
368     Shape fwWeights = context->getInputShape(kFwWeightsTensor);
369     Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
370     Shape fwBias = context->getInputShape(kFwBiasTensor);
371     Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
372     Shape bwWeights = context->getInputShape(kBwWeightsTensor);
373     Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
374     Shape bwBias = context->getInputShape(kBwBiasTensor);
375     Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);
376 
377     Shape auxInput = context->getInputShape(kAuxInputTensor);
378     Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
379     Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);
380 
381     LinkingMode linkingMode;
382     NN_RET_CHECK(getLinkingMode(context, &linkingMode));
383 
384     bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
385     const uint32_t batchSize =
386             timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
387     const uint32_t maxTime =
388             timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
389     const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
390     const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
391     const uint32_t inputSize = getSizeOfDimension(input, 2);
392 
393     NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
394     NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
395     NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
396     NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
397     NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
398     NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
399     NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
400     NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
401     NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);
402 
403     NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
404     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
405     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
406     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
407     NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
408     NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));
409 
410     if (linkingMode != LinkingMode::PARALLEL_LINKING) {
411         NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
412     }
413     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
414     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
415     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
416     NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
417     NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));
418 
419     if (linkingMode == LinkingMode::CROSS_LINKING) {
420         NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
421         NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
422         NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);
423 
424         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
425         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
426         NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
427         NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
428         NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
429         NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
430     } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
431         NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
432 
433         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
434         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
435         NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
436     }
437 
438     Shape fwOutput = context->getOutputShape(kFwOutputTensor);
439     fwOutput.dimensions.resize(3);
440     fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
441     fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
442     fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
443     NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
444     if (!mergeOutputs) {
445         Shape bwOutput = context->getOutputShape(kBwOutputTensor);
446         bwOutput.dimensions.resize(3);
447         bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
448         bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
449         bwOutput.dimensions[2] = bwNumUnits;
450         NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
451     }
452 
453     const bool outputState =
454             (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
455     if (outputState) {
456         const int delta = mergeOutputs ? 1 : 0;
457         NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
458                                              context->getInputShape(kFwHiddenStateTensor)));
459         NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
460                                              context->getInputShape(kBwHiddenStateTensor)));
461     }
462 
463     return true;
464 }
465 
execute(IOperationExecutionContext * context)466 bool execute(IOperationExecutionContext* context) {
467     if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
468         executeTyped<_Float16>(context);
469     } else {
470         executeTyped<float>(context);
471     }
472     return true;
473 }
474 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
475 
476 }  // namespace bidirectional_sequence_rnn
477 
// Register the operation with the resolver. allowOmittedOperand is set
// because the auxiliary input/weight operands (and some outputs) are
// optional.
NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
                      bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
                      bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);
481 
482 }  // namespace nn
483 }  // namespace android
484