1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "model_manager.h"
17 
18 #include <dlfcn.h>
19 
20 #include "directory_ex.h"
21 
22 #include "config_data_manager.h"
23 #include "security_guard_log.h"
24 #include "model_manager_impl.h"
25 #include "database_manager.h"
26 #include "security_guard_define.h"
27 
28 namespace OHOS::Security::SecurityGuard {
29 std::shared_ptr<IModelManager> ModelManager::modelManagerApi_ = std::make_shared<ModelManagerImpl>();
30 
31 namespace {
32     constexpr const char *PREFIX_MODEL_PATH = "/system/lib";
33     constexpr uint32_t AUDIT_MODEL = 3001000003;
34 }
35 
Init()36 void ModelManager::Init()
37 {
38     std::vector<uint32_t> modelIds = ConfigDataManager::GetInstance().GetAllModelIds();
39     ModelCfg cfg;
40     for (uint32_t modelId : modelIds) {
41         bool success = ConfigDataManager::GetInstance().GetModelConfig(modelId, cfg);
42         if (!success) {
43             continue;
44         }
45         SGLOGI("modelId is %{public}u, start_mode: %{public}u", modelId, cfg.startMode);
46         if (cfg.startMode != START_ON_STARTUP) {
47             continue;
48         }
49         if (cfg.modelId != AUDIT_MODEL) {
50             (void) InitModel(modelId);
51             continue;
52         }
53     }
54 }
55 
InitModel(uint32_t modelId)56 int32_t ModelManager::InitModel(uint32_t modelId)
57 {
58     std::unordered_map<uint32_t, std::unique_ptr<ModelAttrs>>::iterator iter;
59     {
60         std::lock_guard<std::mutex> lock(mutex_);
61         iter = modelIdApiMap_.find(modelId);
62         if (iter != modelIdApiMap_.end() && iter->second != nullptr && iter->second->GetModelApi() != nullptr) {
63             iter->second->GetModelApi()->Release();
64             modelIdApiMap_.erase(iter);
65         }
66     }
67 
68     ModelCfg cfg;
69     bool success = ConfigDataManager::GetInstance().GetModelConfig(modelId, cfg);
70     if (!success) {
71         SGLOGE("the model not support, modelId=%{public}u", modelId);
72         return NOT_FOUND;
73     }
74     std::string realPath;
75     if (!PathToRealPath(cfg.path, realPath) || realPath.find(PREFIX_MODEL_PATH) != 0) {
76         return FILE_ERR;
77     }
78     void *handle = dlopen(realPath.c_str(), RTLD_LAZY);
79     if (handle == nullptr) {
80         SGLOGE("modelId=%{public}u, open failed, reason:%{public}s", modelId, dlerror());
81         return FAILED;
82     }
83     std::unique_ptr<ModelAttrs> attr = std::make_unique<ModelAttrs>();
84     attr->SetHandle(handle);
85     auto getModelApi = (GetModelApi)dlsym(handle, "GetModelApi");
86     if (getModelApi == nullptr) {
87         SGLOGE("get model api func is nullptr");
88         return FAILED;
89     }
90     IModel *api = getModelApi();
91     if (api == nullptr) {
92         SGLOGE("get model api is nullptr");
93         return FAILED;
94     }
95     attr->SetModelApi(api);
96     int32_t ret = attr->GetModelApi()->Init(modelManagerApi_);
97     if (ret != SUCCESS) {
98         SGLOGE("model api init failed, ret=%{public}d", ret);
99         return ret;
100     }
101     {
102         std::lock_guard<std::mutex> lock(mutex_);
103         modelIdApiMap_[modelId] = std::move(attr);
104     }
105     SGLOGI("init model success, modelId=%{public}u", modelId);
106     return SUCCESS;
107 }
108 
GetResult(uint32_t modelId,const std::string & param)109 std::string ModelManager::GetResult(uint32_t modelId, const std::string &param)
110 {
111     std::string result = "unknown";
112     int32_t ret = InitModel(modelId);
113     if (ret != SUCCESS) {
114         return result;
115     }
116 
117     {
118         std::lock_guard<std::mutex> lock(mutex_);
119         auto iter = modelIdApiMap_.find(modelId);
120         if (iter == modelIdApiMap_.end() || iter->second == nullptr || iter->second->GetModelApi() == nullptr) {
121             SGLOGI("the model has not been initialized, begin init, modelId=%{public}u", modelId);
122             return result;
123         }
124         result = iter->second->GetModelApi()->GetResult(modelId, param);
125     }
126     ModelCfg config;
127     bool success = ConfigDataManager::GetInstance().GetModelConfig(modelId, config);
128     if (success && config.startMode == START_ON_DEMAND) {
129         Release(modelId);
130     }
131     return result;
132 }
133 
SubscribeResult(uint32_t modelId,std::shared_ptr<IModelResultListener> listener)134 int32_t ModelManager::SubscribeResult(uint32_t modelId, std::shared_ptr<IModelResultListener> listener)
135 {
136     int32_t ret = InitModel(modelId);
137     if (ret != SUCCESS) {
138         return ret;
139     }
140 
141     std::lock_guard<std::mutex> lock(mutex_);
142     auto iter = modelIdApiMap_.find(modelId);
143     if (iter == modelIdApiMap_.end() || iter->second == nullptr || iter->second->GetModelApi() == nullptr) {
144         SGLOGI("the model has not been initialized, modelId=%{public}u", modelId);
145         return FAILED;
146     }
147 
148     return iter->second->GetModelApi()->SubscribeResult(listener);
149 }
150 
Release(uint32_t modelId)151 void ModelManager::Release(uint32_t modelId)
152 {
153     std::lock_guard<std::mutex> lock(mutex_);
154     auto iter = modelIdApiMap_.find(modelId);
155     if (iter == modelIdApiMap_.end()) {
156         SGLOGI("the model has not been initialized, modelId=%{public}u", modelId);
157         return;
158     }
159 
160     if (iter->second == nullptr || iter->second->GetModelApi() == nullptr) {
161         SGLOGI("the model attr is nullptr, modelId=%{public}u", modelId);
162         modelIdApiMap_.erase(iter);
163         return;
164     }
165 
166     iter->second->GetModelApi()->Release();
167     modelIdApiMap_.erase(iter);
168 }
169 } // namespace OHOS::Security::SecurityGuard
170