1 /*
2  * Copyright (c) 2023 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 #include "distributeddb_tools_unit_test.h"
16 #include "mock_thread_pool.h"
17 
18 #include <gtest/gtest.h>
19 
20 #include "log_print.h"
21 #include "runtime_context.h"
22 
23 using namespace testing::ext;
24 using namespace testing;
25 using namespace DistributedDB;
26 using namespace DistributedDBUnitTest;
27 
28 namespace {
const int MAX_TIMER_COUNT = 5;            // a timer ends itself after firing this many times
const int ONE_HUNDRED_MILLISECONDS = 100; // short timer period, in milliseconds
const int ONE_SECOND = 1000;              // long timer period, in milliseconds
32 
33 class DistributedDBThreadPoolTest : public testing::Test {
34 public:
35     static void SetUpTestCase();
36     static void TearDownTestCase();
37     void SetUp() override;
38     void TearDown() override;
39     std::shared_ptr<MockThreadPool> threadPoolPtr_ = nullptr;
40     class TimerWatcher {
41     public:
42         TimerWatcher() = default;
43 
~TimerWatcher()44         ~TimerWatcher()
45         {
46             SafeExit();
47         }
48 
AddCount()49         void AddCount()
50         {
51             std::lock_guard<std::mutex> autoLock(countLock_);
52             count_++;
53         }
54 
DecCount()55         void DecCount()
56         {
57             {
58                 std::lock_guard<std::mutex> autoLock(countLock_);
59                 count_--;
60             }
61             countCv_.notify_one();
62         }
63 
SafeExit()64         void SafeExit()
65         {
66             std::unique_lock<std::mutex> uniqueLock(countLock_);
67             countCv_.wait(uniqueLock, [this]() {
68                 return count_ <= 0;
69             });
70         }
71     private:
72         std::mutex countLock_;
73         int count_ = 0;
74         std::condition_variable countCv_;
75     };
76 };
77 
// One-time suite setup: nothing to prepare for this suite.
void DistributedDBThreadPoolTest::SetUpTestCase()
{
}
81 
// One-time suite teardown: nothing to clean up for this suite.
void DistributedDBThreadPoolTest::TearDownTestCase()
{
}
85 
// Per-test setup: print the case banner, lazily create the mock pool, then
// inject it into the RuntimeContext so production code schedules through it.
void DistributedDBThreadPoolTest::SetUp()
{
    DistributedDBToolsUnitTest::PrintTestCaseInfo();
    if (threadPoolPtr_ == nullptr) {
        threadPoolPtr_ = std::make_shared<MockThreadPool>();
    }
    RuntimeContext::GetInstance()->SetThreadPool(std::dynamic_pointer_cast<IThreadPool>(threadPoolPtr_));
}
94 
// Per-test teardown: detach the mock pool from the RuntimeContext first so
// nothing is routed to the mock after its last owner is released below.
void DistributedDBThreadPoolTest::TearDown()
{
    RuntimeContext::GetInstance()->SetThreadPool(nullptr);
    threadPoolPtr_ = nullptr;
}
100 
CallScheduleTaskOnce()101 void CallScheduleTaskOnce()
102 {
103     int errCode = RuntimeContext::GetInstance()->ScheduleTask([]() {
104         LOGI("Task running");
105     });
106     EXPECT_EQ(errCode, E_OK);
107 }
108 
MockSchedule(const std::shared_ptr<MockThreadPool> & threadPoolPtr,const std::shared_ptr<DistributedDBThreadPoolTest::TimerWatcher> & watcher,std::atomic<TaskId> & taskId)109 void MockSchedule(const std::shared_ptr<MockThreadPool> &threadPoolPtr,
110     const std::shared_ptr<DistributedDBThreadPoolTest::TimerWatcher> &watcher, std::atomic<TaskId> &taskId)
111 {
112     ASSERT_NE(watcher, nullptr);
113     EXPECT_CALL(*threadPoolPtr, Execute(_, _)).
114         WillRepeatedly([watcher, &taskId](const Task &task, Duration time) {
115         watcher->AddCount();
116         std::thread workingThread = std::thread([task, time, watcher]() {
117             std::this_thread::sleep_for(time);
118             task();
119             watcher->DecCount();
120         });
121         workingThread.detach();
122         TaskId currentId = taskId++;
123         return currentId;
124     });
125 }
126 
MockRemove(const std::shared_ptr<MockThreadPool> & threadPoolPtr,bool removeRes,int & removeCount)127 void MockRemove(const std::shared_ptr<MockThreadPool> &threadPoolPtr, bool removeRes, int &removeCount)
128 {
129     EXPECT_CALL(*threadPoolPtr, Remove).WillRepeatedly([removeRes, &removeCount](const TaskId &taskId, bool) {
130         LOGI("Call remove task %" PRIu64, taskId);
131         removeCount++;
132         return removeRes;
133     });
134 }
135 
// Register a timer firing every timeOut ms that ends itself after
// MAX_TIMER_COUNT expirations; timerCount/finalizeCount record how often the
// callback and the finalizer ran. NOTE(review): the finalizer captures
// `timer` by value BEFORE SetTimer writes the real id through the
// out-parameter, so the "finalize" log below prints a stale value — confirm
// the id in that log line is not relied upon.
void SetTimer(int &timerCount, TimerId &timer, int &finalizeCount, int timeOut = ONE_HUNDRED_MILLISECONDS)
{
    int errCode = RuntimeContext::GetInstance()->SetTimer(timeOut, [&timerCount](TimerId timerId) {
        LOGI("Timer %" PRIu64 " running", timerId);
        timerCount++;
        if (timerCount < MAX_TIMER_COUNT) { // max timer count is 5
            return E_OK;
        }
        return -E_END_TIMER;
    }, [&finalizeCount, timer]() {
        finalizeCount++;
        LOGI("Timer %" PRIu64" finalize", timer + 1);
    }, timer);
    EXPECT_EQ(errCode, E_OK);
}
151 
SetTimer(int & timerCount,TimerId & timer)152 void SetTimer(int &timerCount, TimerId &timer)
153 {
154     int errCode = RuntimeContext::GetInstance()->SetTimer(ONE_HUNDRED_MILLISECONDS, [&timerCount](TimerId timerId) {
155         LOGI("Timer %" PRIu64 " running", timerId);
156         timerCount++;
157         if (timerCount < MAX_TIMER_COUNT) { // max timer count is 5
158             return E_OK;
159         }
160         return -E_END_TIMER;
161     }, nullptr, timer);
162     EXPECT_EQ(errCode, E_OK);
163 }
164 
165 /**
166  * @tc.name: ScheduleTask001
167  * @tc.desc: Test schedule task by thread pool
168  * @tc.type: FUNC
169  * @tc.require: AR000I0KU9
170  * @tc.author: zhangqiquan
171  */
172 HWTEST_F(DistributedDBThreadPoolTest, ScheduleTask001, TestSize.Level1)
173 {
174     /**
175      * @tc.steps: step1. set thread pool and schedule task
176      * @tc.expected: step1. thread pool execute task count is once.
177      */
178     ASSERT_NE(threadPoolPtr_, nullptr);
179     int callCount = 0;
180     std::thread workingThread;
__anonbed5c56d0a02(const Task &task) 181     EXPECT_CALL(*threadPoolPtr_, Execute(_)).WillRepeatedly([&callCount, &workingThread](const Task &task) {
182         callCount++;
183         workingThread = std::thread([task]() {
184             task();
185         });
186         return 1u; // task id is 1
187     });
188     ASSERT_NO_FATAL_FAILURE(CallScheduleTaskOnce());
189     if (workingThread.joinable()) {
190         workingThread.join();
191     }
192     EXPECT_EQ(callCount, 1);
193     /**
194      * @tc.steps: step2. reset thread pool and schedule task
195      * @tc.expected: step2. thread pool execute task count is once.
196      */
197     RuntimeContext::GetInstance()->SetThreadPool(nullptr);
198     callCount = 0;
199     ASSERT_NO_FATAL_FAILURE(CallScheduleTaskOnce());
200     if (workingThread.joinable()) {
201         workingThread.join();
202     }
203     EXPECT_EQ(callCount, 0);
204 }
205 
206 /**
207  * @tc.name: SetTimer001
208  * @tc.desc: Test set timer by thread pool
209  * @tc.type: FUNC
210  * @tc.require: AR000I0KU9
211  * @tc.author: zhangqiquan
212  */
HWTEST_F(DistributedDBThreadPoolTest, SetTimer001, TestSize.Level1)
{
    ASSERT_NE(threadPoolPtr_, nullptr);
    // Watcher blocks below until every detached task thread has finished.
    std::shared_ptr<TimerWatcher> watcher = std::make_shared<TimerWatcher>();
    std::atomic<TaskId> currentId = 1;
    ASSERT_NO_FATAL_FAILURE(MockSchedule(threadPoolPtr_, watcher, currentId));
    /**
     * @tc.steps: step1. set timer and record timer call count
     * @tc.expected: step1. call count is MAX_TIMER_COUNT and finalize once.
     */
    int timerCount = 0;
    TimerId timer = 0;
    int finalizeCount = 0;
    ASSERT_NO_FATAL_FAILURE(SetTimer(timerCount, timer, finalizeCount));
    /**
     * @tc.steps: step2. mock modify timer
     * @tc.expected: step2. can call modify timer when timer is runnning.
     */
    // Give the 100ms timer time to fire at least once before modifying it.
    std::this_thread::sleep_for(std::chrono::milliseconds(150)); // sleep 150ms
    EXPECT_CALL(*threadPoolPtr_, Reset).WillOnce([&currentId](const TaskId &id, Duration modifyTime) {
        LOGI("call modify timer task is %" PRIu64, id);
        // The reset task must be the most recently scheduled one.
        EXPECT_EQ(id, currentId - 1);
        // ModifyTimer passes the new period converted to steady_clock duration.
        Duration duration = std::chrono::duration_cast<std::chrono::steady_clock::duration>(
            std::chrono::milliseconds(ONE_HUNDRED_MILLISECONDS));
        EXPECT_EQ(modifyTime, duration);
        return id;
    });
    EXPECT_EQ(RuntimeContext::GetInstance()->ModifyTimer(timer, ONE_HUNDRED_MILLISECONDS), E_OK);
    /**
     * @tc.steps: step3. wait timer finished
     */
    watcher->SafeExit();
    EXPECT_EQ(timerCount, MAX_TIMER_COUNT);
    EXPECT_EQ(finalizeCount, 1);
}
248 
249 /**
250  * @tc.name: SetTimer002
251  * @tc.desc: Test repeat setting timer
252  * @tc.type: FUNC
253  * @tc.require: AR000I0KU9
254  * @tc.author: zhangqiquan
255  */
256 HWTEST_F(DistributedDBThreadPoolTest, SetTimer002, TestSize.Level1)
257 {
258     ASSERT_NE(threadPoolPtr_, nullptr);
259     std::shared_ptr<TimerWatcher> watcher = std::make_shared<TimerWatcher>();
260     std::atomic<TaskId> currentId = 1;
261     ASSERT_NO_FATAL_FAILURE(MockSchedule(threadPoolPtr_, watcher, currentId));
262     /**
263      * @tc.steps: step1. set timer and record timer call count
264      * @tc.expected: step1. call count is MAX_TIMER_COUNT and finalize once.
265      */
266     int timerCountArray[MAX_TIMER_COUNT] = {};
267     int finalizeCountArray[MAX_TIMER_COUNT] = {};
268     for (int i = 0; i < MAX_TIMER_COUNT; ++i) {
269         TimerId timer;
270         SetTimer(timerCountArray[i], timer, finalizeCountArray[i]);
271     }
272     /**
273      * @tc.steps: step2. wait timer finished
274      */
275     watcher->SafeExit();
276     for (int i = 0; i < MAX_TIMER_COUNT; ++i) {
277         EXPECT_EQ(timerCountArray[i], MAX_TIMER_COUNT);
278         EXPECT_EQ(finalizeCountArray[i], 1);
279     }
280 }
281 
282 /**
283  * @tc.name: SetTimer003
284  * @tc.desc: Test set timer and finalize is null
285  * @tc.type: FUNC
286  * @tc.require: AR000I0KU9
287  * @tc.author: zhangqiquan
288  */
289 HWTEST_F(DistributedDBThreadPoolTest, SetTimer003, TestSize.Level1)
290 {
291     ASSERT_NE(threadPoolPtr_, nullptr);
292     std::shared_ptr<TimerWatcher> watcher = std::make_shared<TimerWatcher>();
293     std::atomic<TaskId> currentId = 1;
294     ASSERT_NO_FATAL_FAILURE(MockSchedule(threadPoolPtr_, watcher, currentId));
295     /**
296      * @tc.steps: step1. set timer and record timer call count
297      * @tc.expected: step1. call count is MAX_TIMER_COUNT and finalize once.
298      */
299     int timerCount = 0;
300     TimerId timer = 0;
301     ASSERT_NO_FATAL_FAILURE(SetTimer(timerCount, timer));
302     /**
303      * @tc.steps: step3. wait timer finished
304      */
305     watcher->SafeExit();
306     EXPECT_EQ(timerCount, MAX_TIMER_COUNT);
307 }
308 
309 /**
310  * @tc.name: SetTimer004
311  * @tc.desc: Test remove timer function
312  * @tc.type: FUNC
313  * @tc.require: AR000I0KU9
314  * @tc.author: zhangqiquan
315  */
HWTEST_F(DistributedDBThreadPoolTest, SetTimer004, TestSize.Level2)
{
    ASSERT_NE(threadPoolPtr_, nullptr);
    // Watcher blocks below until every detached task thread has finished.
    std::shared_ptr<TimerWatcher> watcher = std::make_shared<TimerWatcher>();
    std::atomic<TaskId> currentId = 1;
    ASSERT_NO_FATAL_FAILURE(MockSchedule(threadPoolPtr_, watcher, currentId));
    int removeCount = 0;
    // Pool's Remove() initially reports failure.
    MockRemove(threadPoolPtr_, false, removeCount);
    /**
     * @tc.steps: step1. set timer and record timer call count
     * @tc.expected: step1. call count is MAX_TIMER_COUNT and finalize once.
     */
    int timerCount = 0;
    TimerId timer = 0;
    int finalizeCount = 0;
    // 1s period: the callback has no chance to fire before removal below.
    ASSERT_NO_FATAL_FAILURE(SetTimer(timerCount, timer, finalizeCount, ONE_SECOND));
    /**
     * @tc.steps: step2. remove timer
     * @tc.expected: step2. call count is zero when timerId no exist.
     */
    std::this_thread::sleep_for(std::chrono::milliseconds(ONE_HUNDRED_MILLISECONDS));
    // Unknown timer id: the pool's Remove() must not be reached.
    RuntimeContext::GetInstance()->RemoveTimer(timer - 1, true);
    EXPECT_EQ(removeCount, 0);
    MockRemove(threadPoolPtr_, true, removeCount);
    // Existing id: Remove() is invoked exactly once.
    RuntimeContext::GetInstance()->RemoveTimer(timer, true);
    EXPECT_EQ(removeCount, 1);
    /**
     * @tc.steps: step3. wait timer finished
     */
    watcher->SafeExit();
    // Removed before its first 1s expiration, so it never ran.
    EXPECT_EQ(timerCount, 0);
}
348 
349 /**
350  * @tc.name: SetTimer005
351  * @tc.desc: Test repeat remove timer
352  * @tc.type: FUNC
353  * @tc.require: AR000I0KU9
354  * @tc.author: zhangqiquan
355  */
HWTEST_F(DistributedDBThreadPoolTest, SetTimer005, TestSize.Level1)
{
    ASSERT_NE(threadPoolPtr_, nullptr);
    // Watcher blocks below until every detached task thread has finished.
    std::shared_ptr<TimerWatcher> watcher = std::make_shared<TimerWatcher>();
    std::atomic<TaskId> currentId = 1;
    ASSERT_NO_FATAL_FAILURE(MockSchedule(threadPoolPtr_, watcher, currentId));
    int removeCount = 0;
    // Pool's Remove() reports success for every removal.
    MockRemove(threadPoolPtr_, true, removeCount);
    /**
     * @tc.steps: step1. set timer and record timer call count
     * @tc.expected: step1. call count is MAX_TIMER_COUNT and finalize once.
     */
    int timerCountArray[MAX_TIMER_COUNT] = {};
    int finalizeCountArray[MAX_TIMER_COUNT] = {};
    TimerId timerIdArray[MAX_TIMER_COUNT] = {};
    for (int i = 0; i < MAX_TIMER_COUNT; ++i) {
        // 1s period so the timers are still running when removed below.
        SetTimer(timerCountArray[i], timerIdArray[i], finalizeCountArray[i], ONE_SECOND);
    }
    /**
     * @tc.steps: step2. remove all timer
     */
    // Sleep until just (200ms) before the timers would end on their own, so
    // removal races with their final expirations.
    int sleepTime = MAX_TIMER_COUNT * ONE_SECOND - 2 * ONE_HUNDRED_MILLISECONDS;
    std::this_thread::sleep_for(std::chrono::milliseconds(sleepTime));
    for (const auto &timerId : timerIdArray) {
        RuntimeContext::GetInstance()->RemoveTimer(timerId, true);
    }
    /**
     * @tc.steps: step3. wait timer finished
     */
    watcher->SafeExit();
    // Some timers may already have ended themselves, so removals and fire
    // counts are bounded above rather than exact.
    EXPECT_LE(removeCount, MAX_TIMER_COUNT);
    for (int i = 0; i < MAX_TIMER_COUNT; ++i) {
        EXPECT_LE(timerCountArray[i], MAX_TIMER_COUNT);
        EXPECT_EQ(finalizeCountArray[i], 1); // finalizer still runs exactly once each
    }
}
392 
393 /**
394  * @tc.name: SetTimer006
395  * @tc.desc: Test repeat remove timer when time action finished
396  * @tc.type: FUNC
397  * @tc.require: AR000I0KU9
398  * @tc.author: zhangqiquan
399  */
HWTEST_F(DistributedDBThreadPoolTest, SetTimer006, TestSize.Level1)
{
    ASSERT_NE(threadPoolPtr_, nullptr);
    std::atomic<TaskId> currentId = 1;
    // Watcher blocks below until every detached task thread has finished.
    std::shared_ptr<TimerWatcher> watcher = std::make_shared<TimerWatcher>();
    ASSERT_NO_FATAL_FAILURE(MockSchedule(threadPoolPtr_, watcher, currentId));
    int removeCount = 0;
    // Remove() reports failure, so a removed timer's action still runs once
    // more and must end itself via -E_END_TIMER.
    MockRemove(threadPoolPtr_, false, removeCount);

    std::mutex dataMutex;
    std::set<TimerId> removeSet; // ids RemoveTimer() has been called for
    std::set<TimerId> checkSet;  // ids whose action observed the removal
    std::set<TimerId> timerSet;  // all created timer ids
    for (int i = 0; i < 10; ++i) { // 10 timer
        TimerId id;
        // 1ms period: fires as fast as possible to race with removal below.
        int errCode = RuntimeContext::GetInstance()->SetTimer(1, [&dataMutex, &removeSet, &checkSet](TimerId timerId) {
            LOGI("Timer %" PRIu64 " running", timerId);
            std::lock_guard<std::mutex> autoLock(dataMutex);
            if (removeSet.find(timerId) != removeSet.end()) {
                // After removal the action must run at most once more.
                EXPECT_TRUE(checkSet.find(timerId) == checkSet.end());
                checkSet.insert(timerId);
                return -E_END_TIMER;
            }
            return E_OK;
        }, nullptr, id);
        EXPECT_EQ(errCode, E_OK);
        timerSet.insert(id);
    }
    for (const auto &timer: timerSet) {
        // NOTE(review): RemoveTimer is deliberately called before the id is
        // recorded in removeSet, exercising the window where the action fires
        // concurrently with (or just after) the removal.
        RuntimeContext::GetInstance()->RemoveTimer(timer);
        LOGI("Timer %" PRIu64 " remove", timer);
        std::lock_guard<std::mutex> autoLock(dataMutex);
        removeSet.insert(timer);
    }
    watcher->SafeExit();
}
436 
437 /**
438  * @tc.name: TaskPool001
439  * @tc.desc: Test TaskPool schedule task
440  * @tc.type: FUNC
441  * @tc.require:
442  * @tc.author: zhangqiquan
443  */
HWTEST_F(DistributedDBThreadPoolTest, TaskPool001, TestSize.Level1)
{
    // Clear the mock: this case exercises the built-in task pool instead.
    RuntimeContext::GetInstance()->SetThreadPool(nullptr);
    std::mutex dataMutex;
    std::condition_variable cv;
    int finishedTaskCount = 0;
    // One free task that finishes after 1s.
    int errCode = RuntimeContext::GetInstance()->ScheduleTask([&finishedTaskCount, &dataMutex, &cv]() {
        std::this_thread::sleep_for(std::chrono::seconds(1)); // sleep 1s
        LOGD("exec task ok");
        {
            std::lock_guard<std::mutex> autoLock(dataMutex);
            finishedTaskCount++;
        }
        cv.notify_one();
    });
    EXPECT_EQ(errCode, E_OK);
    constexpr int execTaskCount = 2;
    // Two tasks serialized on the same "TaskPool" queue.
    for (int i = 0; i < execTaskCount; ++i) {
        errCode = RuntimeContext::GetInstance()->ScheduleQueuedTask("TaskPool",
            [i, &finishedTaskCount, &dataMutex, &cv]() {
            LOGD("exec task %d", i);
            {
                std::lock_guard<std::mutex> autoLock(dataMutex);
                finishedTaskCount++;
            }
            cv.notify_one();
        });
        EXPECT_EQ(errCode, E_OK);
        std::this_thread::sleep_for(std::chrono::seconds(1)); // sleep 1s
    }
    {
        std::unique_lock<std::mutex> uniqueLock(dataMutex);
        LOGD("begin wait all task finished");
        // execTaskCount is constexpr, so the lambda may read it uncaptured;
        // wait for the free task plus both queued tasks.
        cv.wait(uniqueLock, [&finishedTaskCount]() {
            return finishedTaskCount == execTaskCount + 1;
        });
        LOGD("end wait all task finished");
    }
    RuntimeContext::GetInstance()->StopTaskPool();
}
484 }