1 /*
2  * Copyright (c) 2022 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *    http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
#include <algorithm>
#include <atomic>
#include <gtest/gtest.h>
#include <iterator>
#include <pthread.h>
#include <random>
#include <rwlock.h>
#include <thread>
#include <unistd.h>

#include "hks_api.h"
#include "hks_condition.h"
#include "hks_log.h"
28 
29 using namespace testing::ext;
30 namespace Unittest::HksUtilsConditionTest {
31 class HksConditionTest : public testing::Test {
32 public:
33     static void SetUpTestCase(void);
34 
35     static void TearDownTestCase(void);
36 
37     void SetUp();
38 
39     void TearDown();
40 };
41 
SetUpTestCase(void)42 void HksConditionTest::SetUpTestCase(void)
43 {
44 }
45 
TearDownTestCase(void)46 void HksConditionTest::TearDownTestCase(void)
47 {
48 }
49 
SetUp()50 void HksConditionTest::SetUp()
51 {
52 }
53 
TearDown()54 void HksConditionTest::TearDown()
55 {
56 }
57 
58 /**
59  * @tc.name: HksConditionTest.HksConditionTest001
60  * @tc.desc: tdd HksConditionWait, with nullptr input, expect -1
61  * @tc.type: FUNC
62  */
63 HWTEST_F(HksConditionTest, HksConditionTest001, TestSize.Level0)
64 {
65     HKS_LOG_I("enter HksConditionTest001");
66     int32_t ret = HksConditionWait(nullptr);
67     EXPECT_EQ(ret, -1) << "HksConditionTest001 failed, ret = " << ret;
68 }
69 
NotifyCondition(HksCondition * condition)70 void NotifyCondition(HksCondition *condition)
71 {
72     sleep(1);
73     int32_t ret = HksConditionNotify(condition);
74     if (ret != HKS_SUCCESS) {
75         HKS_LOG_E("HksConditionNotify failed, ret = %" LOG_PUBLIC "d", ret);
76     }
77 }
78 
79 /**
80  * @tc.name: HksConditionTest.HksConditionTest002
81  * @tc.desc: tdd HksConditionWait, with notified false, expecting HKS_SUCCESS
82  * @tc.type: FUNC
83  */
84 HWTEST_F(HksConditionTest, HksConditionTest002, TestSize.Level0)
85 {
86     HKS_LOG_I("enter HksConditionTest002");
87     HksCondition *condition = HksConditionCreate();
88     EXPECT_NE(condition, nullptr) << "HksConditionCreate failed";
89     std::thread thObj(NotifyCondition, condition);
90     int32_t ret = HksConditionWait(condition);
91     EXPECT_EQ(ret, HKS_SUCCESS) << "HksConditionTest002 failed, ret = " << ret;
92     HksConditionDestroy(condition);
93     thObj.join();
94 }
95 
96 /**
97  * @tc.name: HksConditionTest.HksConditionTest003
98  * @tc.desc: tdd HksConditionWait, with notified true, expecting HKS_SUCCESS
99  * @tc.type: FUNC
100  */
101 HWTEST_F(HksConditionTest, HksConditionTest003, TestSize.Level0)
102 {
103     HKS_LOG_I("enter HksConditionTest003");
104     HksCondition *condition = HksConditionCreate();
105     int32_t ret = HksConditionNotify(condition);
106     EXPECT_EQ(ret, HKS_SUCCESS) << "HksConditionTest003 failed, ret = " << ret;
107     ret = HksConditionWait(condition);
108     EXPECT_EQ(ret, HKS_SUCCESS) << "HksConditionTest003 failed, ret = " << ret;
109     HksConditionDestroy(condition);
110 }
111 
112 /**
113  * @tc.name: HksConditionTest.HksConditionTest004
114  * @tc.desc: tdd HksConditionNotify, with waited false, expecting HKS_SUCCESS
115  * @tc.type: FUNC
116  */
117 HWTEST_F(HksConditionTest, HksConditionTest004, TestSize.Level0)
118 {
119     HKS_LOG_I("enter HksConditionTest004");
120     HksCondition *condition = HksConditionCreate();
121     int32_t ret = HksConditionNotify(condition);
122     EXPECT_EQ(ret, HKS_SUCCESS) << "HksConditionTest004 failed, ret = " << ret;
123     ret = HksConditionNotify(condition);
124     EXPECT_EQ(ret, HKS_SUCCESS) << "HksConditionTest004 failed, ret = " << ret;
125     HksConditionDestroy(condition);
126 }
127 
128 /**
129  * @tc.name: HksConditionTest.HksConditionTest005
130  * @tc.desc: tdd HksConditionDestroy, with nullptr input
131  * @tc.type: FUNC
132  */
133 HWTEST_F(HksConditionTest, HksConditionTest005, TestSize.Level0)
134 {
135     HKS_LOG_I("enter HksConditionTest005");
136     HksConditionDestroy(nullptr);
137     int32_t ret = HksInitialize();
138     EXPECT_EQ(ret, HKS_SUCCESS);
139 }
140 
WaitThread(void * p)141 static void *WaitThread(void *p)
142 {
143     EXPECT_EQ(HksConditionWait(static_cast<HksCondition *>(p)), HKS_SUCCESS);
144     return nullptr;
145 }
146 
147 /**
148  * @tc.name: HksConditionTest.HksConditionTest006
149  * @tc.desc: case fail if stuck
150  * @tc.type: FUNC
151  */
152 HWTEST_F(HksConditionTest, HksConditionTest006, TestSize.Level0)
153 {
154     HKS_LOG_I("enter HksConditionTest006");
155     enum {
156         TSET_THREADS_COUNT = 10,
157         TEST_TIMES = 100,
158     };
159     for (int no = 0; no < TEST_TIMES; ++no) {
160         HksCondition *condition = HksConditionCreate();
161         EXPECT_NE(condition, nullptr);
162         pthread_t threads[TSET_THREADS_COUNT] {};
163         for (int i = 0; i < TSET_THREADS_COUNT; ++i) {
164             EXPECT_EQ(pthread_create(&threads[i], nullptr, WaitThread, condition), 0);
165         }
166         EXPECT_EQ(HksConditionNotifyAll(condition), HKS_SUCCESS);
167         for (int i = 0; i < TSET_THREADS_COUNT; ++i) {
168             EXPECT_EQ(pthread_join(threads[i], nullptr), 0);
169         }
170         HksConditionDestroy(condition);
171     }
172 }
173 
/**
 * Computes the n-th Fibonacci number iteratively, with Fib(0) = 0 and Fib(1) = Fib(2) = 1.
 * Used as deterministic CPU-bound busy work by the lock tests in this file.
 *
 * @param n index of the Fibonacci number.
 * @return Fib(n). The running values are kept in uint64_t, so results are exact up to
 *         Fib(93). Larger indices wrap modulo 2^64, which is well defined for unsigned types,
 *         instead of overflowing a signed int (undefined behavior).
 */
uint64_t Fib(uint64_t n)
{
    if (n < 1) {
        return 0;
    }
    uint64_t current = 1;
    uint64_t previous = 0;
    // After k iterations, current == Fib(k + 1); the loop runs n - 1 times.
    for (uint64_t i = 1; i < n; ++i) {
        uint64_t next = current + previous;
        previous = current;
        current = next;
    }
    return current;
}
194 
TimeConsumingWork(uint64_t repeatTimes,uint64_t fibNumber)195 static uint64_t TimeConsumingWork(uint64_t repeatTimes, uint64_t fibNumber)
196 {
197     uint64_t sum = 0;
198     for (uint64_t i = 0; i < repeatTimes; ++i) {
199         sum += Fib(fibNumber);
200     }
201     if (repeatTimes == 0) {
202         return Fib(fibNumber);
203     }
204     // avoid divisor zero
205     return sum / repeatTimes;
206 }
207 
208 static OHOS::Utils::RWLock g_rwLock(true);
209 
OnStartTest(void * p)210 static void *OnStartTest(void *p)
211 {
212     {
213         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> writeGuard(g_rwLock);
214         enum {
215             CALC_FIB_TIMES = 10'000'000,
216             CALC_FIB_NUMBER = 40,
217         };
218         // huks sa service start, upgrade keys
219         EXPECT_EQ(TimeConsumingWork(CALC_FIB_TIMES, CALC_FIB_NUMBER), Fib(CALC_FIB_NUMBER));
220     }
221     EXPECT_EQ(HksConditionNotifyAll(static_cast<HksCondition *>(p)), HKS_SUCCESS);
222     return nullptr;
223 }
224 
HksUpgradeOnUserUnlockTest()225 static void HksUpgradeOnUserUnlockTest()
226 {
227     g_rwLock.UnLockRead();
228 
229     {
230         OHOS::Utils::UniqueWriteGuard<OHOS::Utils::RWLock> writeGuard(g_rwLock);
231         enum {
232             CALC_FIB_TIMES = 10'000'000,
233             CALC_FIB_NUMBER = 40,
234         };
235         // upgrade keys in case that user unlocked, or first time someone using credential-encrypted level key.
236         EXPECT_EQ(TimeConsumingWork(CALC_FIB_TIMES, CALC_FIB_NUMBER), Fib(CALC_FIB_NUMBER));
237     }
238 
239     g_rwLock.LockRead();
240 }
241 
// Set by the first simulated credential-encrypted request in OnRemoteRequestTest that wins the
// compare-exchange. It is never reset, so only that one request runs the CE upgrade.
// A plain std::atomic_bool: atomics already synchronize between threads, and `volatile`
// is neither needed nor sufficient for that.
static std::atomic_bool g_isCeUpgradeSucc {false};
243 
OnRemoteRequestTest(void * p)244 static void *OnRemoteRequestTest(void *p)
245 {
246     EXPECT_EQ(HksConditionWait(static_cast<HksCondition *>(p)), HKS_SUCCESS);
247     OHOS::Utils::UniqueReadGuard<OHOS::Utils::RWLock> readGuard(g_rwLock);
248     enum {
249         CALC_FIB_TIMES = 1'000'000,
250         CALC_FIB_NUMBER = 40,
251     };
252     // someone is invoking huks
253     EXPECT_EQ(TimeConsumingWork(CALC_FIB_TIMES, CALC_FIB_NUMBER), Fib(CALC_FIB_NUMBER));
254 
255     enum {
256         IF_STORAGE_LEVEL_IS_CE = 2,
257     };
258     // someone is invoking huks for credential-encrypted level key
259     if (std::rand() % IF_STORAGE_LEVEL_IS_CE) {
260         bool flag = false;
261         if (std::atomic_compare_exchange_strong(&g_isCeUpgradeSucc, &flag, true)) {
262             HksUpgradeOnUserUnlockTest();
263         }
264     }
265     return nullptr;
266 }
267 
OnReceiveEventTest(void * p)268 static void *OnReceiveEventTest(void *p)
269 {
270     EXPECT_EQ(HksConditionWait(static_cast<HksCondition *>(p)), HKS_SUCCESS);
271     OHOS::Utils::UniqueReadGuard<OHOS::Utils::RWLock> readGuard(g_rwLock);
272     HksUpgradeOnUserUnlockTest();
273     return nullptr;
274 }
275 
276 /**
277  * @tc.name: HksConditionTest.HksConditionTest007
278  * @tc.desc: case fail if stuck
279  * @tc.type: FUNC
280  */
281 HWTEST_F(HksConditionTest, HksConditionTest007, TestSize.Level0)
282 {
283     HKS_LOG_I("enter HksConditionTest007");
284     enum {
285         TEST_THREADS_COUNT = 20,
286         TEST_TIMES = 100,
287         TEST_ON_RECEIVE_EVENT_THREADS_COUNT = 5,
288     };
289     for (int no = 0; no < TEST_TIMES; ++no) {
290         HksCondition *condition = HksConditionCreate();
291         EXPECT_NE(condition, nullptr);
292 
293         void *(*functions[TEST_THREADS_COUNT])(void *) {};
294         // 1 :> OnStartTest, TEST_ON_RECEIVE_EVENT_THREADS_COUNT :> OnReceiveEventTest, others :> OnRemoteRequestTest
295         std::fill(std::begin(functions), std::end(functions), OnRemoteRequestTest);
296         std::fill_n(std::begin(functions), TEST_ON_RECEIVE_EVENT_THREADS_COUNT, OnReceiveEventTest);
297         functions[TEST_THREADS_COUNT - 1] = OnStartTest;
298 
299         std::random_device rd;
300         std::mt19937 g(rd());
301         std::shuffle(std::begin(functions), std::end(functions), g);
302 
303         pthread_t threads[TEST_THREADS_COUNT] {};
304         for (int i = 0; i < TEST_THREADS_COUNT; ++i) {
305             EXPECT_EQ(pthread_create(&threads[i], nullptr, functions[i], condition), 0);
306         }
307 
308         for (int i = 0; i < TEST_THREADS_COUNT; ++i) {
309             EXPECT_EQ(pthread_join(threads[i], nullptr), 0);
310         }
311         HksConditionDestroy(condition);
312     }
313 }
314 }
315