1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "kws_plugin.h"
17 
18 #include "aie_log.h"
19 #include "aie_retcode_inner.h"
20 #include "encdec_facade.h"
21 #include "norm_processor.h"
22 #include "plugin_helper.h"
23 #include "slide_window_processor.h"
24 #include "type_converter.h"
25 
26 #ifdef USE_NNIE
27 #include "nnie_adapter.h"
28 #endif
29 
30 using namespace OHOS::AI;
31 using namespace OHOS::AI::Feature;
namespace {
// Path of the compiled keyword-spotting model that Prepare() loads through the engine adapter.
const std::string PLUGIN_MODEL_PATH {"/storage/data/keyword_spotting.wk"};
// Inference-mode string reported by KWSPlugin::GetInferMode().
const std::string DEFAULT_INFER_MODE {"SYNC"};
// Algorithm name reported by KWSPlugin::GetName().
const std::string ALGORITHM_NAME_KWS {"KWS"};
// Option identifiers understood by KWSPlugin::GetOption().
const int32_t OPTION_GET_INPUT_SIZE {1001};
const int32_t OPTION_GET_OUTPUT_SIZE {1002};
// "No address" sentinel compared against the model input/output addresses in MakeInference().
const intptr_t EMPTY_UINTPTR {0};
// Index of the model input / output node queried from the engine adapter in BuildConfig().
const int32_t MODEL_INPUT_NODE_ID {0};
const int32_t MODEL_OUTPUT_NODE_ID {0};
}
42 
InitWorkplace(KWSWorkplace & worker,SlideWindowProcessorConfig & slideCfg,TypeConverterConfig & convertCfg,NormProcessorConfig & normCfg)43 static int32_t InitWorkplace(KWSWorkplace &worker, SlideWindowProcessorConfig &slideCfg,
44     TypeConverterConfig &convertCfg, NormProcessorConfig &normCfg)
45 {
46     worker.slideProcessor = std::make_shared<SlideWindowProcessor>();
47     worker.typeConverter = std::make_shared<TypeConverter>();
48     worker.normProcessor = std::make_shared<NormProcessor>();
49     if (worker.slideProcessor == nullptr ||
50         worker.typeConverter == nullptr ||
51         worker.normProcessor == nullptr) {
52         HILOGE("[KWSPlugin]Fail to allocate workplaces");
53         return RETCODE_FAILURE;
54     }
55     if (worker.slideProcessor->Init(&slideCfg) != RETCODE_SUCCESS) {
56         HILOGE("[KWSPlugin]Fail to init slideProcessor");
57         return RETCODE_FAILURE;
58     }
59     if (worker.typeConverter->Init(&convertCfg) != RETCODE_SUCCESS) {
60         HILOGE("[KWSPlugin]Fail to init typeConverter");
61         return RETCODE_FAILURE;
62     }
63     if (worker.normProcessor->Init(&normCfg) != RETCODE_SUCCESS) {
64         HILOGE("[KWSPlugin]Fail to init normConfig");
65         return RETCODE_FAILURE;
66     }
67     return RETCODE_SUCCESS;
68 }
69 
KWSPlugin()70 KWSPlugin::KWSPlugin()
71 {
72     HILOGD("[KWSPlugin]ctor");
73     handles_.clear();
74 }
75 
~KWSPlugin()76 KWSPlugin::~KWSPlugin()
77 {
78     ReleaseAllHandles();
79     HILOGD("[KWSPlugin]dtor");
80 }
81 
ReleaseAllHandles()82 void KWSPlugin::ReleaseAllHandles()
83 {
84     for (auto iter = handles_.begin(); iter != handles_.end(); ++iter) {
85         (void)adapter_->ReleaseHandle(iter->first);
86     }
87     adapter_->Deinit();
88     handles_.clear();
89 }
90 
Prepare(long long transactionId,const DataInfo & inputInfo,DataInfo & outputInfo)91 int32_t KWSPlugin::Prepare(long long transactionId, const DataInfo &inputInfo, DataInfo &outputInfo)
92 {
93     HILOGI("[KWSPlugin]Begin to prepare, transactionId = %lld", transactionId);
94     std::lock_guard<std::mutex> lock(mutex_);
95     if (adapter_ ==  nullptr) {
96 #ifdef USE_NNIE
97         adapter_ = std::make_shared<NNIEAdapter>();
98 #endif
99         if (adapter_ == nullptr) {
100             HILOGE("[KWSPlugin]Fail to create engine adapter");
101             return RETCODE_FAILURE;
102         }
103     }
104     intptr_t handle = 0;
105     if (adapter_->Init(PLUGIN_MODEL_PATH.c_str(), handle) != RETCODE_SUCCESS) {
106         HILOGE("[KWSPlugin]NNIEAdapterInit failed");
107         return RETCODE_FAILURE;
108     }
109     const auto iter = handles_.find(handle);
110     if (iter != handles_.end()) {
111         HILOGE("[KWSPlugin]handle=%lld has already existed", (long long)handle);
112         return RETCODE_SUCCESS;
113     }
114     PluginConfig config;
115     if (BuildConfig(handle, config) != RETCODE_SUCCESS) {
116         HILOGE("[KWSPlugin]BuildConfig failed");
117         return RETCODE_FAILURE;
118     }
119     KWSWorkplace worker = {
120         .config = config,
121         .normProcessor = nullptr,
122         .typeConverter = nullptr,
123         .slideProcessor = nullptr
124     };
125     if (InitComponents(worker) != RETCODE_SUCCESS) {
126         HILOGE("[KWSPlugin]InitComponents failed");
127         return RETCODE_FAILURE;
128     }
129     handles_.emplace(handle, worker);
130     return EncdecFacade::ProcessEncode(outputInfo, handle);
131 }
132 
GetVersion() const133 const long long KWSPlugin::GetVersion() const
134 {
135     return ALGOTYPE_VERSION_KWS;
136 }
137 
GetName() const138 const char *KWSPlugin::GetName() const
139 {
140     return ALGORITHM_NAME_KWS.c_str();
141 }
142 
GetInferMode() const143 const char *KWSPlugin::GetInferMode() const
144 {
145     return DEFAULT_INFER_MODE.c_str();
146 }
147 
SyncProcess(IRequest * request,IResponse * & response)148 int32_t KWSPlugin::SyncProcess(IRequest *request, IResponse *&response)
149 {
150     HILOGI("[KWSPlugin]SyncProcess start");
151     std::lock_guard<std::mutex> lock(mutex_);
152     if (request == nullptr) {
153         HILOGE("[KWSPlugin]SyncProcess request is nullptr");
154         return RETCODE_NULL_PARAM;
155     }
156     DataInfo inputInfo = request->GetMsg();
157     if (inputInfo.data == nullptr || inputInfo.length <= 0) {
158         HILOGE("[KWSPlugin]SyncProcess inputInfo data is nullptr");
159         return RETCODE_NULL_PARAM;
160     }
161     intptr_t handle = 0;
162     Array<uint16_t> audioInput = {
163         .data = nullptr,
164         .size = 0
165     };
166     int32_t ret = EncdecFacade::ProcessDecode(inputInfo, handle, audioInput);
167     if (ret != RETCODE_SUCCESS) {
168         HILOGE("[KWSPlugin]SyncProcess load inputData failed");
169         return RETCODE_FAILURE;
170     }
171     const auto iter = handles_.find(handle);
172     if (iter == handles_.end()) {
173         HILOGE("[KWSPlugin]SyncProcess no matched handle [%lld]", (long long)handle);
174         return RETCODE_NULL_PARAM;
175     }
176     Array<int32_t> processorOutput = {
177         .data = nullptr,
178         .size = 0
179     };
180     ret = GetNormedFeatures(audioInput, processorOutput, iter->second);
181     if (ret != RETCODE_SUCCESS) {
182         HILOGE("[KWSPlugin]Fail to get normed features");
183         return RETCODE_FAILURE;
184     }
185     DataInfo outputInfo = {
186         .data = nullptr,
187         .length = 0
188     };
189     ret = MakeInference(handle, processorOutput, iter->second.config, outputInfo);
190     if (ret != RETCODE_SUCCESS) {
191         HILOGE("[KWSPlugin]SyncProcess MakeInference failed");
192         return RETCODE_FAILURE;
193     }
194     response = IResponse::Create(request);
195     response->SetResult(outputInfo);
196     return RETCODE_SUCCESS;
197 }
198 
GetNormedFeatures(const Array<uint16_t> & input,Array<int32_t> & output,const KWSWorkplace & worker)199 int32_t KWSPlugin::GetNormedFeatures(const Array<uint16_t> &input, Array<int32_t> &output, const KWSWorkplace &worker)
200 {
201     FeatureData inputData = {
202         .dataType = UINT16,
203         .data = static_cast<void *>(input.data),
204         .size = input.size
205     };
206     FeatureData normedOutput = {
207         .dataType = FLOAT,
208         .data = nullptr,
209         .size = 0
210     };
211     int32_t retCode = worker.normProcessor->Process(inputData, normedOutput);
212     if (retCode != RETCODE_SUCCESS) {
213         HILOGE("[KWSPlugin]Fail to get nomred output via normProcessor");
214         return RETCODE_FAILURE;
215     }
216     FeatureData convertedOutput = {
217         .dataType = INT32,
218         .data = nullptr,
219         .size = 0
220     };
221     retCode = worker.typeConverter->Process(normedOutput, convertedOutput);
222     if (retCode != RETCODE_SUCCESS) {
223         HILOGE("[KWSPlugin]Fail to convert normed output via typeConverter");
224         return RETCODE_FAILURE;
225     }
226     FeatureData slideOutput = {
227         .dataType = INT32,
228         .data = nullptr,
229         .size = 0
230     };
231     retCode = worker.slideProcessor->Process(convertedOutput, slideOutput);
232     if (retCode != RETCODE_SUCCESS) {
233         HILOGE("[KWSPlugin]Fail to get slided output via slideProcessor");
234         return RETCODE_FAILURE;
235     }
236     output.data = static_cast<int32_t *>(slideOutput.data);
237     output.size = slideOutput.size;
238     return RETCODE_SUCCESS;
239 }
240 
AsyncProcess(IRequest * request,IPluginCallback * callback)241 int32_t KWSPlugin::AsyncProcess(IRequest *request, IPluginCallback *callback)
242 {
243     return RETCODE_SUCCESS;
244 }
245 
SetOption(int optionType,const DataInfo & inputInfo)246 int32_t KWSPlugin::SetOption(int optionType, const DataInfo &inputInfo)
247 {
248     std::lock_guard<std::mutex> lock(mutex_);
249     if (inputInfo.data == nullptr) {
250         HILOGE("[KWSPlugin]SetOption inputInfo data is [NULL]");
251         return RETCODE_FAILURE;
252     }
253     int retCode = RETCODE_SUCCESS;
254     switch (optionType) {
255         default:
256             HILOGE("[KWSPlugin]SetOption optionType[%d] undefined", optionType);
257             break;
258     }
259     return retCode;
260 }
261 
GetOption(int optionType,const DataInfo & inputInfo,DataInfo & outputInfo)262 int32_t KWSPlugin::GetOption(int optionType, const DataInfo &inputInfo, DataInfo &outputInfo)
263 {
264     std::lock_guard<std::mutex> lock(mutex_);
265     if (inputInfo.data == nullptr || inputInfo.length <= 0) {
266         HILOGE("[KWSPlugin]GetOption failed for empty inputInfo");
267         return RETCODE_FAILURE;
268     }
269     intptr_t handle = 0;
270     int32_t ret = EncdecFacade::ProcessDecode(inputInfo, handle);
271     if (ret != RETCODE_SUCCESS) {
272         HILOGE("[KWSPlugin]GetOption get handle from inputInfo failed");
273         return RETCODE_FAILURE;
274     }
275     const auto &iter = handles_.find(handle);
276     if (iter == handles_.end()) {
277         HILOGE("[KWSPlugin]GetOption no matched handle [%lld]", (long long)handle);
278         return RETCODE_FAILURE;
279     }
280     outputInfo.length = 0;
281     switch (optionType) {
282         case OPTION_GET_INPUT_SIZE:
283             return EncdecFacade::ProcessEncode(outputInfo, handle, iter->second.config.inputSize);
284         case OPTION_GET_OUTPUT_SIZE:
285             return EncdecFacade::ProcessEncode(outputInfo, handle, iter->second.config.outputSize);
286         default:
287             HILOGE("[KWSPlugin]GetOption optionType[%d] undefined", optionType);
288             return RETCODE_FAILURE;
289     }
290     return RETCODE_SUCCESS;
291 }
292 
Release(bool isFullUnload,long long transactionId,const DataInfo & inputInfo)293 int32_t KWSPlugin::Release(bool isFullUnload, long long transactionId, const DataInfo &inputInfo)
294 {
295     if (adapter_ == nullptr) {
296         HILOGE("[KWSPlugin]The engine adapter has not been created");
297         return RETCODE_FAILURE;
298     }
299     HILOGI("[KWSPlugin]Begin to release, transactionId = %lld", transactionId);
300     intptr_t handle = 0;
301     int32_t ret = EncdecFacade::ProcessDecode(inputInfo, handle);
302     if (ret != RETCODE_SUCCESS) {
303         HILOGE("[KWSPlugin]UnSerializeHandle Failed");
304         return RETCODE_FAILURE;
305     }
306     std::lock_guard<std::mutex> lock(mutex_);
307     ret = adapter_->ReleaseHandle(handle);
308     if (ret != RETCODE_SUCCESS) {
309         HILOGE("[KWSPlugin]ReleaseHandle failed");
310         return RETCODE_FAILURE;
311     }
312     FreeHandle(handle);
313     if (isFullUnload) {
314         ret = adapter_->Deinit();
315         if (ret != RETCODE_SUCCESS) {
316             HILOGE("[KWSPlugin]Engine adapter deinit failed");
317             return RETCODE_FAILURE;
318         }
319     }
320     return RETCODE_SUCCESS;
321 }
322 
FreeHandle(intptr_t handle)323 void KWSPlugin::FreeHandle(intptr_t handle)
324 {
325     const auto iter = handles_.find(handle);
326     if (iter != handles_.end()) {
327         (void)handles_.erase(iter);
328     }
329 }
330 
InitComponents(KWSWorkplace & worker)331 int32_t KWSPlugin::InitComponents(KWSWorkplace &worker)
332 {
333     SlideWindowProcessorConfig slideWindowConfig;
334     slideWindowConfig.dataType = INT32;
335     slideWindowConfig.stepSize = DEFAULT_SLIDE_STEP_SIZE;
336     slideWindowConfig.windowSize = DEFAULT_SLIDE_WINDOW_SIZE;
337     TypeConverterConfig convertConfig;
338     convertConfig.dataType = INT32;
339     convertConfig.size = DEFAULT_SLIDE_STEP_SIZE;
340     NormProcessorConfig normConfig;
341     normConfig.meanFilePath = DEFAULT_NORM_MEAN_FILE_PATH;
342     normConfig.stdFilePath = DEFAULT_NORM_STD_FILE_PATH;
343     normConfig.numChannels = DEFAULT_NORM_NUM_CHANNELS;
344     normConfig.inputSize = DEFAULT_NORM_INPUT_SIZE;
345     normConfig.scale = DEFAULT_NORM_SCALE;
346     if (InitWorkplace(worker, slideWindowConfig, convertConfig, normConfig) != RETCODE_SUCCESS) {
347         HILOGE("[KWSPlugin]Fail to init workplace");
348         worker.slideProcessor = nullptr;
349         worker.typeConverter = nullptr;
350         worker.normProcessor = nullptr;
351         return RETCODE_FAILURE;
352     }
353     return RETCODE_SUCCESS;
354 }
355 
BuildConfig(intptr_t handle,PluginConfig & config)356 int32_t KWSPlugin::BuildConfig(intptr_t handle, PluginConfig &config)
357 {
358     if (adapter_ == nullptr) {
359         return RETCODE_NULL_PARAM;
360     }
361     int32_t retcode = adapter_->GetInputAddr(handle, MODEL_INPUT_NODE_ID, config.inputAddr, config.inputSize);
362     if (retcode != RETCODE_SUCCESS) {
363         HILOGE("[KWSPlugin]NNIEAdapter GetInputAddr failed with [%d]", retcode);
364         return RETCODE_NULL_PARAM;
365     }
366     retcode = adapter_->GetOutputAddr(handle, MODEL_OUTPUT_NODE_ID, config.outputAddr, config.outputSize);
367     if (retcode != RETCODE_SUCCESS) {
368         HILOGE("[KWSPlugin]NNIEAdapter GetOutputAddr failed with [%d]", retcode);
369         return RETCODE_NULL_PARAM;
370     }
371     return RETCODE_SUCCESS;
372 }
373 
MakeInference(intptr_t handle,Array<int32_t> & input,PluginConfig & config,DataInfo & outputInfo)374 int32_t KWSPlugin::MakeInference(intptr_t handle, Array<int32_t> &input, PluginConfig &config, DataInfo &outputInfo)
375 {
376     HILOGI("[KWSPlugin]start with handle = %lld", (long long)handle);
377     if (adapter_ == nullptr || config.inputAddr == EMPTY_UINTPTR || config.outputAddr == EMPTY_UINTPTR) {
378         HILOGE("[KWSPlugin]MakeInference inference engine is not ready");
379         return RETCODE_NULL_PARAM;
380     }
381     if (input.data == nullptr || input.size != config.inputSize) {
382         HILOGE("[KWSPlugin]The input size is not equal to the size of model input");
383         return RETCODE_FAILURE;
384     }
385     int32_t *inputData = reinterpret_cast<int32_t *>(config.inputAddr);
386     size_t bufferSize = config.inputSize * sizeof(input.data[0]);
387     errno_t retCode = memcpy_s(inputData, bufferSize, input.data, bufferSize);
388     if (retCode != EOK) {
389         HILOGE("[KWSPlugin]MakeInference memory copy failed");
390         return RETCODE_NULL_PARAM;
391     }
392     int32_t ret = adapter_->Invoke(handle);
393     if (ret != RETCODE_SUCCESS) {
394         HILOGE("[KWSPlugin]MakeInference failed");
395         return RETCODE_FAILURE;
396     }
397     Array<int32_t> result;
398     result.size = config.outputSize;
399     result.data = reinterpret_cast<int32_t *>(config.outputAddr);
400     return EncdecFacade::ProcessEncode(outputInfo, handle, result);
401 }
402 
// Expands the plugin-framework interface macro for KWSPlugin. The macro is defined in
// a framework header not shown here; presumably it exports the plugin's factory entry
// points to the AI engine (confirm against its definition).
PLUGIN_INTERFACE_IMPL(KWSPlugin);