1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "kws_plugin.h"
17
18 #include "aie_log.h"
19 #include "aie_retcode_inner.h"
20 #include "encdec_facade.h"
21 #include "norm_processor.h"
22 #include "plugin_helper.h"
23 #include "slide_window_processor.h"
24 #include "type_converter.h"
25
26 #ifdef USE_NNIE
27 #include "nnie_adapter.h"
28 #endif
29
30 using namespace OHOS::AI;
31 using namespace OHOS::AI::Feature;
namespace {
// File-local settings of the keyword spotting (KWS) plugin.

// Model path handed to the engine adapter's Init() in KWSPlugin::Prepare().
const std::string PLUGIN_MODEL_PATH = "/storage/data/keyword_spotting.wk";
// Value reported by KWSPlugin::GetInferMode().
const std::string DEFAULT_INFER_MODE = "SYNC";
// Value reported by KWSPlugin::GetName().
const std::string ALGORITHM_NAME_KWS = "KWS";

// Option types understood by KWSPlugin::GetOption().
constexpr int32_t OPTION_GET_INPUT_SIZE = 1001;
constexpr int32_t OPTION_GET_OUTPUT_SIZE = 1002;

// Address value that KWSPlugin::MakeInference() treats as "buffer not available".
constexpr intptr_t EMPTY_UINTPTR = 0;

// Node indices passed to the adapter's GetInputAddr()/GetOutputAddr().
constexpr int32_t MODEL_INPUT_NODE_ID = 0;
constexpr int32_t MODEL_OUTPUT_NODE_ID = 0;
} // namespace
42
InitWorkplace(KWSWorkplace & worker,SlideWindowProcessorConfig & slideCfg,TypeConverterConfig & convertCfg,NormProcessorConfig & normCfg)43 static int32_t InitWorkplace(KWSWorkplace &worker, SlideWindowProcessorConfig &slideCfg,
44 TypeConverterConfig &convertCfg, NormProcessorConfig &normCfg)
45 {
46 worker.slideProcessor = std::make_shared<SlideWindowProcessor>();
47 worker.typeConverter = std::make_shared<TypeConverter>();
48 worker.normProcessor = std::make_shared<NormProcessor>();
49 if (worker.slideProcessor == nullptr ||
50 worker.typeConverter == nullptr ||
51 worker.normProcessor == nullptr) {
52 HILOGE("[KWSPlugin]Fail to allocate workplaces");
53 return RETCODE_FAILURE;
54 }
55 if (worker.slideProcessor->Init(&slideCfg) != RETCODE_SUCCESS) {
56 HILOGE("[KWSPlugin]Fail to init slideProcessor");
57 return RETCODE_FAILURE;
58 }
59 if (worker.typeConverter->Init(&convertCfg) != RETCODE_SUCCESS) {
60 HILOGE("[KWSPlugin]Fail to init typeConverter");
61 return RETCODE_FAILURE;
62 }
63 if (worker.normProcessor->Init(&normCfg) != RETCODE_SUCCESS) {
64 HILOGE("[KWSPlugin]Fail to init normConfig");
65 return RETCODE_FAILURE;
66 }
67 return RETCODE_SUCCESS;
68 }
69
KWSPlugin()70 KWSPlugin::KWSPlugin()
71 {
72 HILOGD("[KWSPlugin]ctor");
73 handles_.clear();
74 }
75
~KWSPlugin()76 KWSPlugin::~KWSPlugin()
77 {
78 ReleaseAllHandles();
79 HILOGD("[KWSPlugin]dtor");
80 }
81
ReleaseAllHandles()82 void KWSPlugin::ReleaseAllHandles()
83 {
84 for (auto iter = handles_.begin(); iter != handles_.end(); ++iter) {
85 (void)adapter_->ReleaseHandle(iter->first);
86 }
87 adapter_->Deinit();
88 handles_.clear();
89 }
90
Prepare(long long transactionId,const DataInfo & inputInfo,DataInfo & outputInfo)91 int32_t KWSPlugin::Prepare(long long transactionId, const DataInfo &inputInfo, DataInfo &outputInfo)
92 {
93 HILOGI("[KWSPlugin]Begin to prepare, transactionId = %lld", transactionId);
94 std::lock_guard<std::mutex> lock(mutex_);
95 if (adapter_ == nullptr) {
96 #ifdef USE_NNIE
97 adapter_ = std::make_shared<NNIEAdapter>();
98 #endif
99 if (adapter_ == nullptr) {
100 HILOGE("[KWSPlugin]Fail to create engine adapter");
101 return RETCODE_FAILURE;
102 }
103 }
104 intptr_t handle = 0;
105 if (adapter_->Init(PLUGIN_MODEL_PATH.c_str(), handle) != RETCODE_SUCCESS) {
106 HILOGE("[KWSPlugin]NNIEAdapterInit failed");
107 return RETCODE_FAILURE;
108 }
109 const auto iter = handles_.find(handle);
110 if (iter != handles_.end()) {
111 HILOGE("[KWSPlugin]handle=%lld has already existed", (long long)handle);
112 return RETCODE_SUCCESS;
113 }
114 PluginConfig config;
115 if (BuildConfig(handle, config) != RETCODE_SUCCESS) {
116 HILOGE("[KWSPlugin]BuildConfig failed");
117 return RETCODE_FAILURE;
118 }
119 KWSWorkplace worker = {
120 .config = config,
121 .normProcessor = nullptr,
122 .typeConverter = nullptr,
123 .slideProcessor = nullptr
124 };
125 if (InitComponents(worker) != RETCODE_SUCCESS) {
126 HILOGE("[KWSPlugin]InitComponents failed");
127 return RETCODE_FAILURE;
128 }
129 handles_.emplace(handle, worker);
130 return EncdecFacade::ProcessEncode(outputInfo, handle);
131 }
132
GetVersion() const133 const long long KWSPlugin::GetVersion() const
134 {
135 return ALGOTYPE_VERSION_KWS;
136 }
137
GetName() const138 const char *KWSPlugin::GetName() const
139 {
140 return ALGORITHM_NAME_KWS.c_str();
141 }
142
GetInferMode() const143 const char *KWSPlugin::GetInferMode() const
144 {
145 return DEFAULT_INFER_MODE.c_str();
146 }
147
SyncProcess(IRequest * request,IResponse * & response)148 int32_t KWSPlugin::SyncProcess(IRequest *request, IResponse *&response)
149 {
150 HILOGI("[KWSPlugin]SyncProcess start");
151 std::lock_guard<std::mutex> lock(mutex_);
152 if (request == nullptr) {
153 HILOGE("[KWSPlugin]SyncProcess request is nullptr");
154 return RETCODE_NULL_PARAM;
155 }
156 DataInfo inputInfo = request->GetMsg();
157 if (inputInfo.data == nullptr || inputInfo.length <= 0) {
158 HILOGE("[KWSPlugin]SyncProcess inputInfo data is nullptr");
159 return RETCODE_NULL_PARAM;
160 }
161 intptr_t handle = 0;
162 Array<uint16_t> audioInput = {
163 .data = nullptr,
164 .size = 0
165 };
166 int32_t ret = EncdecFacade::ProcessDecode(inputInfo, handle, audioInput);
167 if (ret != RETCODE_SUCCESS) {
168 HILOGE("[KWSPlugin]SyncProcess load inputData failed");
169 return RETCODE_FAILURE;
170 }
171 const auto iter = handles_.find(handle);
172 if (iter == handles_.end()) {
173 HILOGE("[KWSPlugin]SyncProcess no matched handle [%lld]", (long long)handle);
174 return RETCODE_NULL_PARAM;
175 }
176 Array<int32_t> processorOutput = {
177 .data = nullptr,
178 .size = 0
179 };
180 ret = GetNormedFeatures(audioInput, processorOutput, iter->second);
181 if (ret != RETCODE_SUCCESS) {
182 HILOGE("[KWSPlugin]Fail to get normed features");
183 return RETCODE_FAILURE;
184 }
185 DataInfo outputInfo = {
186 .data = nullptr,
187 .length = 0
188 };
189 ret = MakeInference(handle, processorOutput, iter->second.config, outputInfo);
190 if (ret != RETCODE_SUCCESS) {
191 HILOGE("[KWSPlugin]SyncProcess MakeInference failed");
192 return RETCODE_FAILURE;
193 }
194 response = IResponse::Create(request);
195 response->SetResult(outputInfo);
196 return RETCODE_SUCCESS;
197 }
198
GetNormedFeatures(const Array<uint16_t> & input,Array<int32_t> & output,const KWSWorkplace & worker)199 int32_t KWSPlugin::GetNormedFeatures(const Array<uint16_t> &input, Array<int32_t> &output, const KWSWorkplace &worker)
200 {
201 FeatureData inputData = {
202 .dataType = UINT16,
203 .data = static_cast<void *>(input.data),
204 .size = input.size
205 };
206 FeatureData normedOutput = {
207 .dataType = FLOAT,
208 .data = nullptr,
209 .size = 0
210 };
211 int32_t retCode = worker.normProcessor->Process(inputData, normedOutput);
212 if (retCode != RETCODE_SUCCESS) {
213 HILOGE("[KWSPlugin]Fail to get nomred output via normProcessor");
214 return RETCODE_FAILURE;
215 }
216 FeatureData convertedOutput = {
217 .dataType = INT32,
218 .data = nullptr,
219 .size = 0
220 };
221 retCode = worker.typeConverter->Process(normedOutput, convertedOutput);
222 if (retCode != RETCODE_SUCCESS) {
223 HILOGE("[KWSPlugin]Fail to convert normed output via typeConverter");
224 return RETCODE_FAILURE;
225 }
226 FeatureData slideOutput = {
227 .dataType = INT32,
228 .data = nullptr,
229 .size = 0
230 };
231 retCode = worker.slideProcessor->Process(convertedOutput, slideOutput);
232 if (retCode != RETCODE_SUCCESS) {
233 HILOGE("[KWSPlugin]Fail to get slided output via slideProcessor");
234 return RETCODE_FAILURE;
235 }
236 output.data = static_cast<int32_t *>(slideOutput.data);
237 output.size = slideOutput.size;
238 return RETCODE_SUCCESS;
239 }
240
AsyncProcess(IRequest * request,IPluginCallback * callback)241 int32_t KWSPlugin::AsyncProcess(IRequest *request, IPluginCallback *callback)
242 {
243 return RETCODE_SUCCESS;
244 }
245
SetOption(int optionType,const DataInfo & inputInfo)246 int32_t KWSPlugin::SetOption(int optionType, const DataInfo &inputInfo)
247 {
248 std::lock_guard<std::mutex> lock(mutex_);
249 if (inputInfo.data == nullptr) {
250 HILOGE("[KWSPlugin]SetOption inputInfo data is [NULL]");
251 return RETCODE_FAILURE;
252 }
253 int retCode = RETCODE_SUCCESS;
254 switch (optionType) {
255 default:
256 HILOGE("[KWSPlugin]SetOption optionType[%d] undefined", optionType);
257 break;
258 }
259 return retCode;
260 }
261
GetOption(int optionType,const DataInfo & inputInfo,DataInfo & outputInfo)262 int32_t KWSPlugin::GetOption(int optionType, const DataInfo &inputInfo, DataInfo &outputInfo)
263 {
264 std::lock_guard<std::mutex> lock(mutex_);
265 if (inputInfo.data == nullptr || inputInfo.length <= 0) {
266 HILOGE("[KWSPlugin]GetOption failed for empty inputInfo");
267 return RETCODE_FAILURE;
268 }
269 intptr_t handle = 0;
270 int32_t ret = EncdecFacade::ProcessDecode(inputInfo, handle);
271 if (ret != RETCODE_SUCCESS) {
272 HILOGE("[KWSPlugin]GetOption get handle from inputInfo failed");
273 return RETCODE_FAILURE;
274 }
275 const auto &iter = handles_.find(handle);
276 if (iter == handles_.end()) {
277 HILOGE("[KWSPlugin]GetOption no matched handle [%lld]", (long long)handle);
278 return RETCODE_FAILURE;
279 }
280 outputInfo.length = 0;
281 switch (optionType) {
282 case OPTION_GET_INPUT_SIZE:
283 return EncdecFacade::ProcessEncode(outputInfo, handle, iter->second.config.inputSize);
284 case OPTION_GET_OUTPUT_SIZE:
285 return EncdecFacade::ProcessEncode(outputInfo, handle, iter->second.config.outputSize);
286 default:
287 HILOGE("[KWSPlugin]GetOption optionType[%d] undefined", optionType);
288 return RETCODE_FAILURE;
289 }
290 return RETCODE_SUCCESS;
291 }
292
Release(bool isFullUnload,long long transactionId,const DataInfo & inputInfo)293 int32_t KWSPlugin::Release(bool isFullUnload, long long transactionId, const DataInfo &inputInfo)
294 {
295 if (adapter_ == nullptr) {
296 HILOGE("[KWSPlugin]The engine adapter has not been created");
297 return RETCODE_FAILURE;
298 }
299 HILOGI("[KWSPlugin]Begin to release, transactionId = %lld", transactionId);
300 intptr_t handle = 0;
301 int32_t ret = EncdecFacade::ProcessDecode(inputInfo, handle);
302 if (ret != RETCODE_SUCCESS) {
303 HILOGE("[KWSPlugin]UnSerializeHandle Failed");
304 return RETCODE_FAILURE;
305 }
306 std::lock_guard<std::mutex> lock(mutex_);
307 ret = adapter_->ReleaseHandle(handle);
308 if (ret != RETCODE_SUCCESS) {
309 HILOGE("[KWSPlugin]ReleaseHandle failed");
310 return RETCODE_FAILURE;
311 }
312 FreeHandle(handle);
313 if (isFullUnload) {
314 ret = adapter_->Deinit();
315 if (ret != RETCODE_SUCCESS) {
316 HILOGE("[KWSPlugin]Engine adapter deinit failed");
317 return RETCODE_FAILURE;
318 }
319 }
320 return RETCODE_SUCCESS;
321 }
322
FreeHandle(intptr_t handle)323 void KWSPlugin::FreeHandle(intptr_t handle)
324 {
325 const auto iter = handles_.find(handle);
326 if (iter != handles_.end()) {
327 (void)handles_.erase(iter);
328 }
329 }
330
InitComponents(KWSWorkplace & worker)331 int32_t KWSPlugin::InitComponents(KWSWorkplace &worker)
332 {
333 SlideWindowProcessorConfig slideWindowConfig;
334 slideWindowConfig.dataType = INT32;
335 slideWindowConfig.stepSize = DEFAULT_SLIDE_STEP_SIZE;
336 slideWindowConfig.windowSize = DEFAULT_SLIDE_WINDOW_SIZE;
337 TypeConverterConfig convertConfig;
338 convertConfig.dataType = INT32;
339 convertConfig.size = DEFAULT_SLIDE_STEP_SIZE;
340 NormProcessorConfig normConfig;
341 normConfig.meanFilePath = DEFAULT_NORM_MEAN_FILE_PATH;
342 normConfig.stdFilePath = DEFAULT_NORM_STD_FILE_PATH;
343 normConfig.numChannels = DEFAULT_NORM_NUM_CHANNELS;
344 normConfig.inputSize = DEFAULT_NORM_INPUT_SIZE;
345 normConfig.scale = DEFAULT_NORM_SCALE;
346 if (InitWorkplace(worker, slideWindowConfig, convertConfig, normConfig) != RETCODE_SUCCESS) {
347 HILOGE("[KWSPlugin]Fail to init workplace");
348 worker.slideProcessor = nullptr;
349 worker.typeConverter = nullptr;
350 worker.normProcessor = nullptr;
351 return RETCODE_FAILURE;
352 }
353 return RETCODE_SUCCESS;
354 }
355
BuildConfig(intptr_t handle,PluginConfig & config)356 int32_t KWSPlugin::BuildConfig(intptr_t handle, PluginConfig &config)
357 {
358 if (adapter_ == nullptr) {
359 return RETCODE_NULL_PARAM;
360 }
361 int32_t retcode = adapter_->GetInputAddr(handle, MODEL_INPUT_NODE_ID, config.inputAddr, config.inputSize);
362 if (retcode != RETCODE_SUCCESS) {
363 HILOGE("[KWSPlugin]NNIEAdapter GetInputAddr failed with [%d]", retcode);
364 return RETCODE_NULL_PARAM;
365 }
366 retcode = adapter_->GetOutputAddr(handle, MODEL_OUTPUT_NODE_ID, config.outputAddr, config.outputSize);
367 if (retcode != RETCODE_SUCCESS) {
368 HILOGE("[KWSPlugin]NNIEAdapter GetOutputAddr failed with [%d]", retcode);
369 return RETCODE_NULL_PARAM;
370 }
371 return RETCODE_SUCCESS;
372 }
373
MakeInference(intptr_t handle,Array<int32_t> & input,PluginConfig & config,DataInfo & outputInfo)374 int32_t KWSPlugin::MakeInference(intptr_t handle, Array<int32_t> &input, PluginConfig &config, DataInfo &outputInfo)
375 {
376 HILOGI("[KWSPlugin]start with handle = %lld", (long long)handle);
377 if (adapter_ == nullptr || config.inputAddr == EMPTY_UINTPTR || config.outputAddr == EMPTY_UINTPTR) {
378 HILOGE("[KWSPlugin]MakeInference inference engine is not ready");
379 return RETCODE_NULL_PARAM;
380 }
381 if (input.data == nullptr || input.size != config.inputSize) {
382 HILOGE("[KWSPlugin]The input size is not equal to the size of model input");
383 return RETCODE_FAILURE;
384 }
385 int32_t *inputData = reinterpret_cast<int32_t *>(config.inputAddr);
386 size_t bufferSize = config.inputSize * sizeof(input.data[0]);
387 errno_t retCode = memcpy_s(inputData, bufferSize, input.data, bufferSize);
388 if (retCode != EOK) {
389 HILOGE("[KWSPlugin]MakeInference memory copy failed");
390 return RETCODE_NULL_PARAM;
391 }
392 int32_t ret = adapter_->Invoke(handle);
393 if (ret != RETCODE_SUCCESS) {
394 HILOGE("[KWSPlugin]MakeInference failed");
395 return RETCODE_FAILURE;
396 }
397 Array<int32_t> result;
398 result.size = config.outputSize;
399 result.data = reinterpret_cast<int32_t *>(config.outputAddr);
400 return EncdecFacade::ProcessEncode(outputInfo, handle, result);
401 }
402
// Expands the project's PLUGIN_INTERFACE_IMPL macro for KWSPlugin.
// NOTE(review): presumably this generates the entry points through which the
// plugin is loaded - confirm against the macro definition (likely plugin_helper.h).
PLUGIN_INTERFACE_IMPL(KWSPlugin);