1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H 18 #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H 19 20 #include <stddef.h> 21 22 #include <cstdint> 23 #include <vector> 24 25 #include "ActivationFunctor.h" 26 27 #ifndef NN_COMPATIBILITY_LIBRARY_BUILD 28 #include "operations/BidirectionalSequenceLSTM.h" 29 #include "operations/Cast.h" 30 #include "operations/EmbeddingLookup.h" 31 #include "operations/ExpandDims.h" 32 #include "operations/HashtableLookup.h" 33 #include "operations/LSHProjection.h" 34 #include "operations/LSTM.h" 35 #include "operations/MaximumMinimum.h" 36 #include "operations/Multinomial.h" 37 #include "operations/Pow.h" 38 #include "operations/QuantizedLSTM.h" 39 #include "operations/RNN.h" 40 #include "operations/SVDF.h" 41 #include "operations/Tile.h" 42 #endif // NN_COMPATIBILITY_LIBRARY_BUILD 43 44 namespace android { 45 namespace nn { 46 47 struct Shape; 48 49 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape); 50 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape); 51 52 bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape, 53 const _Float16* filterData, const Shape& filterShape, 54 const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft, 55 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 56 int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor, 57 int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation, 58 _Float16* outputData, const Shape& outputShape); 59 bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData, 60 const Shape& filterShape, const float* biasData, const Shape& biasShape, 61 int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop, 62 int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight, 63 int32_t dilationWidthFactor, int32_t dilationHeightFactor, 64 int32_t depthMultiplier, int32_t activation, float* outputData, 65 const Shape& outputShape); 66 bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape, 67 const uint8_t* filterData, const Shape& filterShape, 68 const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft, 69 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 70 int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor, 71 int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation, 72 uint8_t* outputData, const Shape& outputShape); 73 bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape, 74 const int8_t* filterData, const Shape& filterShape, 75 const float* filterScales, const int32_t* biasData, 76 const Shape& biasShape, int32_t paddingLeft, 77 int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom, 78 int32_t strideWidth, int32_t strideHeight, 79 int32_t dilationWidthFactor, int32_t dilationHeightFactor, 80 int32_t depthMultiplier, int32_t activation, uint8_t* outputData, 81 const Shape& outputShape); 82 83 bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius, 84 float bias, float alpha, float beta, int32_t axis, 85 _Float16* outputData, const Shape& outputShape); 86 bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius, 87 float bias, float alpha, float beta, int32_t axis, float* outputData, 88 const Shape& outputShape); 89 90 bool copyData(const void* inputData, const Shape& inputShape, void* outputData, 91 const Shape& outputShape); 92 93 template <typename T> 94 bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize, 95 T* outputData, const Shape& outputShape); 96 template <typename T> 97 bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize, 98 T* outputData, const Shape& outputShape); 99 100 template <typename T> 101 bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value, 102 T* outputData, const Shape& outputShape); 103 104 template <typename T> 105 bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize, 106 T* outputData, const Shape& outputShape); 107 108 template <typename T> 109 bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize, 110 const int32_t* padding, const Shape& paddingShape, T* outputData, 111 const Shape& outputShape); 112 113 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis, 114 const Shape& axisShape, bool keepDims, _Float16* outputData, 115 const Shape& outputShape); 116 template <typename T, typename U> 117 bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape, 118 bool keepDims, T* outputData, const Shape& outputShape); 119 120 bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape, 121 const int32_t* beginData, const int32_t* endData, 122 const int32_t* stridesData, int32_t beginMask, int32_t endMask, 123 int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape); 124 125 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis, 126 bool isArgMin, uint8_t* outputData, const Shape& outputShape); 127 128 bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis, 129 const std::vector<_Float16*>* outputDataPtrs, 130 const std::vector<Shape>& outputShapes); 131 132 bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis, 133 const std::vector<float*>* outputDataPtrs, 134 const std::vector<Shape>& outputShapes); 135 136 bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis, 137 const std::vector<int32_t*>* outputDataPtrs, 138 const std::vector<Shape>& outputShapes); 139 140 bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis, 141 const std::vector<uint8_t*>* outputDataPtrs, 142 const std::vector<Shape>& outputShapes); 143 144 bool splitQuant8Signed(const int8_t* inputData, const Shape& inputShape, const int32_t axis, 145 const std::vector<int8_t*>* outputDataPtrs, 146 const std::vector<Shape>& outputShapes); 147 148 bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape, 149 const _Float16* filterData, const Shape& filterShape, 150 const _Float16* biasData, const Shape& biasShape, int32_t numGroups, 151 int32_t padding_left, int32_t padding_right, int32_t padding_top, 152 int32_t padding_bottom, int32_t stride_width, int32_t stride_height, 153 int32_t activation, _Float16* outputData, const Shape& outputShape); 154 155 bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData, 156 const Shape& filterShape, const float* biasData, const Shape& biasShape, 157 int32_t numGroups, int32_t padding_left, int32_t padding_right, 158 int32_t padding_top, int32_t padding_bottom, int32_t stride_width, 159 int32_t stride_height, int32_t activation, float* outputData, 160 const Shape& outputShape); 161 162 template <typename T> 163 bool groupedConvQuant8(const T* inputData, const Shape& inputShape, const T* filterData, 164 const Shape& filterShape, const int32_t* biasData, const Shape& biasShape, 165 int32_t numGroups, int32_t padding_left, int32_t padding_right, 166 int32_t padding_top, int32_t padding_bottom, int32_t stride_width, 167 int32_t stride_height, int32_t activation, T* outputData, 168 const Shape& outputShape); 169 170 template <typename T> 171 bool groupedConvQuant8PerChannel(const T* inputData, const Shape& inputShape, 172 const int8_t* filterData, const Shape& filterShape, 173 const float* filterScales, const int32_t* biasData, 174 const Shape& biasShape, int32_t padding_left, 175 int32_t padding_right, int32_t padding_top, int32_t padding_bottom, 176 int32_t stride_width, int32_t stride_height, int32_t numGroups, 177 int32_t activation, T* outputData, const Shape& outputShape); 178 179 bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups, 180 int32_t axis, uint8_t* outputData, const Shape& outputShape); 181 } // namespace nn 182 } // namespace android 183 #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATIONS_H 184