1 /*
2  * Copyright (c) 2021 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #include "frame_combiner.h"
17 #include <set>
18 #include "log_print.h"
19 #include "protocol_proto.h"
20 
21 namespace DistributedDB {
// Upper bound of pending CombineWork entries kept per source target; exceeding it triggers an abort of one work.
static constexpr uint32_t MAX_WORK_PER_SRC_TARGET = 1;
// Period of the surveillance timer, in milliseconds (10 s).
static constexpr int COMBINER_SURVAIL_PERIOD_IN_MILLISECOND = 10 * 1000;
24 
Initialize()25 void FrameCombiner::Initialize()
26 {
27     RuntimeContext *context = RuntimeContext::GetInstance();
28     TimerAction action = [this](TimerId inTimerId)->int {
29         PeriodicalSurveillance();
30         return E_OK;
31     };
32     TimerFinalizer finalizer = [this]() {
33         timerRemovedIndicator_.SendSemaphore();
34     };
35     int errCode = context->SetTimer(COMBINER_SURVAIL_PERIOD_IN_MILLISECOND, action, finalizer, timerId_);
36     if (errCode != E_OK) {
37         LOGE("[Combiner][Init] Set timer fail, errCode=%d.", errCode);
38         return;
39     }
40     isTimerWork_ = true;
41 }
42 
Finalize()43 void FrameCombiner::Finalize()
44 {
45     // First: Stop the timer
46     if (isTimerWork_) {
47         RuntimeContext *context = RuntimeContext::GetInstance();
48         context->RemoveTimer(timerId_);
49         timerRemovedIndicator_.WaitSemaphore();
50     }
51 
52     // Second: Clear the combineWorkPool_
53     for (auto &eachSource : combineWorkPool_) {
54         for (auto &eachFrame : eachSource.second) {
55             delete eachFrame.second.buffer;
56             eachFrame.second.buffer = nullptr;
57         }
58     }
59 }
60 
AssembleFrameFragment(const uint8_t * bytes,uint32_t length,const ParseResult & inPacketInfo,ParseResult & outFrameInfo,int & outErrorNo)61 SerialBuffer *FrameCombiner::AssembleFrameFragment(const uint8_t *bytes, uint32_t length,
62     const ParseResult &inPacketInfo, ParseResult &outFrameInfo, int &outErrorNo)
63 {
64     uint64_t sourceId = inPacketInfo.GetSourceId();
65     uint32_t frameId = inPacketInfo.GetFrameId();
66     std::lock_guard<std::mutex> overallLockGuard(overallMutex_);
67     if (combineWorkPool_[sourceId].count(frameId) != 0) {
68         // CombineWork already exist
69         int errCode = ContinueExistCombineWork(bytes, length, inPacketInfo);
70         if (errCode != E_OK) {
71             LOGE("[Combiner][Assemble] Continue work fail, errCode=%d.", errCode);
72             outErrorNo = errCode;
73             return nullptr;
74         }
75 
76         if (combineWorkPool_[sourceId][frameId].status.IsCombineDone()) {
77             // We can parse the combined frame here, or outside this class.
78             LOGI("[Combiner][Assemble] Combine done, sourceId=%" PRIu64 ", frameId=%" PRIu32, ULL(sourceId), frameId);
79             SerialBuffer *outFrame = combineWorkPool_[sourceId][frameId].buffer;
80             outFrameInfo = combineWorkPool_[sourceId][frameId].frameInfo;
81             outErrorNo = E_OK;
82             combineWorkPool_[sourceId].erase(frameId);
83             return outFrame; // The caller is responsible for release the outFrame
84         }
85     } else {
86         // CombineWork not exist and even existing work number reaches the limitation. Try create work first.
87         int errCode = CreateNewCombineWork(bytes, length, inPacketInfo);
88         if (errCode != E_OK) {
89             LOGE("[Combiner][Assemble] Create work fail, errCode=%d.", errCode);
90             outErrorNo = errCode;
91             return nullptr;
92         }
93         // After successfully create work, the existing work number may exceed the limitation
94         // If so, choose one from works of this target with lowest progressId and abort it
95         if (combineWorkPool_[sourceId].size() > MAX_WORK_PER_SRC_TARGET) {
96             AbortCombineWorkBySource(sourceId);
97         }
98     }
99     outErrorNo = E_OK;
100     return nullptr;
101 }
102 
PeriodicalSurveillance()103 void FrameCombiner::PeriodicalSurveillance()
104 {
105     std::lock_guard<std::mutex> overallLockGuard(overallMutex_);
106     for (auto &eachSource : combineWorkPool_) {
107         std::set<uint32_t> frameToAbort;
108         for (auto &eachFrame : eachSource.second) {
109             if (!eachFrame.second.status.CheckProgress()) {
110                 LOGW("[Combiner][Surveil] Source=%" PRIu64 ", frame=%" PRIu32
111                     " has no progress, this combine work will be aborted.", ULL(eachSource.first), eachFrame.first);
112                 // Free this combine work first
113                 delete eachFrame.second.buffer;
114                 eachFrame.second.buffer = nullptr;
115                 // Record this frame in abort list
116                 frameToAbort.insert(eachFrame.first);
117             }
118         }
119         // Remove the combine work from map
120         for (auto &entry : frameToAbort) {
121             eachSource.second.erase(entry);
122         }
123     }
124 }
125 
ContinueExistCombineWork(const uint8_t * bytes,uint32_t length,const ParseResult & inPacketInfo)126 int FrameCombiner::ContinueExistCombineWork(const uint8_t *bytes, uint32_t length, const ParseResult &inPacketInfo)
127 {
128     uint64_t sourceId = inPacketInfo.GetSourceId();
129     uint32_t frameId = inPacketInfo.GetFrameId();
130     CombineWork &oriWork = combineWorkPool_[sourceId][frameId]; // Be care here must be reference
131     if (!CheckPacketWithOriWork(inPacketInfo, oriWork)) {
132         LOGE("[Combiner][ContinueWork] Check packet fail, sourceId=%" PRIu64 ", frameId=%" PRIu32, sourceId, frameId);
133         return -E_COMBINE_FAIL;
134     }
135 
136     uint32_t fragOffset = oriWork.status.GetThisFragmentOffset(inPacketInfo.GetFragNo());
137     uint32_t fragLength = oriWork.status.GetThisFragmentLength(inPacketInfo.GetFragNo());
138     int errCode = ProtocolProto::CombinePacketIntoFrame(oriWork.buffer, bytes, length, fragOffset, fragLength);
139     if (errCode != E_OK) {
140         // We can consider abort this work, but here we choose not to affect it
141         LOGE("[Combiner][ContinueWork] Combine packet fail, sourceId=%" PRIu64 ", frameId=%" PRIu32, sourceId, frameId);
142         return -E_COMBINE_FAIL;
143     }
144 
145     oriWork.status.UpdateProgressId(incProgressId_++);
146     oriWork.status.CheckInFragmentNo(inPacketInfo.GetFragNo());
147     return E_OK;
148 }
149 
CreateNewCombineWork(const uint8_t * bytes,uint32_t length,const ParseResult & inPacketInfo)150 int FrameCombiner::CreateNewCombineWork(const uint8_t *bytes, uint32_t length, const ParseResult &inPacketInfo)
151 {
152     uint32_t fragLen = 0;
153     uint32_t lastFragLen = 0;
154     int errCode = ProtocolProto::AnalyzeSplitStructure(inPacketInfo, fragLen, lastFragLen);
155     if (errCode != E_OK) {
156         LOGE("[Combiner][CreateWork] Analyze fail, errCode=%d.", errCode);
157         return errCode;
158     }
159 
160     CombineWork work;
161 
162     work.frameInfo.SetPacketLen(inPacketInfo.GetFrameLen());
163     work.frameInfo.SetSourceId(inPacketInfo.GetSourceId());
164     work.frameInfo.SetFrameId(inPacketInfo.GetFrameId());
165     work.frameInfo.SetFrameTypeInfo(inPacketInfo.GetFrameTypeInfo());
166     work.frameInfo.SetFrameLen(inPacketInfo.GetFrameLen());
167     work.frameInfo.SetFragCount(inPacketInfo.GetFragCount());
168 
169     work.status.SetFragmentLen(fragLen);
170     work.status.SetLastFragmentLen(lastFragLen);
171     work.status.SetFragmentCount(inPacketInfo.GetFragCount());
172 
173     work.buffer = CreateNewFrameBuffer(inPacketInfo);
174     if (work.buffer == nullptr) {
175         return -E_OUT_OF_MEMORY;
176     }
177 
178     uint32_t fragOffset = work.status.GetThisFragmentOffset(inPacketInfo.GetFragNo());
179     uint32_t fragLength = work.status.GetThisFragmentLength(inPacketInfo.GetFragNo());
180     errCode = ProtocolProto::CombinePacketIntoFrame(work.buffer, bytes, length, fragOffset, fragLength);
181     if (errCode != E_OK) {
182         delete work.buffer;
183         work.buffer = nullptr;
184         return errCode;
185     }
186 
187     totalSizeByByte_ += work.buffer->GetSize();
188     work.status.UpdateProgressId(incProgressId_++);
189     work.status.CheckInFragmentNo(inPacketInfo.GetFragNo());
190     combineWorkPool_[inPacketInfo.GetSourceId()][inPacketInfo.GetFrameId()] = work;
191     return E_OK;
192 }
193 
AbortCombineWorkBySource(uint64_t inSourceId)194 void FrameCombiner::AbortCombineWorkBySource(uint64_t inSourceId)
195 {
196     if (combineWorkPool_[inSourceId].empty()) {
197         return;
198     }
199     uint32_t toBeAbortFrameId = 0;
200     uint64_t toBeAbortProgressId = UINT64_MAX;
201     for (auto &entry : combineWorkPool_[inSourceId]) {
202         if (entry.second.status.GetProgressId() < toBeAbortProgressId) {
203             toBeAbortProgressId = entry.second.status.GetProgressId();
204             toBeAbortFrameId = entry.first;
205         }
206     }
207     // Do Abort!
208     LOGW("[Combiner][AbortWork] Abort Incomplete CombineWork, sourceId=%" PRIu64 ", frameId=%" PRIu32 ".",
209         ULL(inSourceId), toBeAbortFrameId);
210     delete combineWorkPool_[inSourceId][toBeAbortFrameId].buffer;
211     combineWorkPool_[inSourceId][toBeAbortFrameId].buffer = nullptr;
212     combineWorkPool_[inSourceId].erase(toBeAbortFrameId);
213 }
214 
CheckPacketWithOriWork(const ParseResult & inPacketInfo,const CombineWork & inWork)215 bool FrameCombiner::CheckPacketWithOriWork(const ParseResult &inPacketInfo, const CombineWork &inWork)
216 {
217     if (inPacketInfo.GetFrameLen() != inWork.frameInfo.GetFrameLen()) {
218         LOGE("[Combiner][CheckPacket] FrameLen mismatch %" PRIu32 " vs %" PRIu32 ".", inPacketInfo.GetFrameLen(),
219             inWork.frameInfo.GetFrameLen());
220         return false;
221     }
222     if (inPacketInfo.GetFragCount() != inWork.frameInfo.GetFragCount()) {
223         LOGE("[Combiner][CheckPacket] FragCount mismatch %" PRIu32 " vs %" PRIu32 ".", inPacketInfo.GetFragCount(),
224             inWork.frameInfo.GetFragCount());
225         return false;
226     }
227     if (inPacketInfo.GetFragNo() >= inPacketInfo.GetFragCount()) {
228         LOGE("[Combiner][CheckPacket] FragNo=%" PRIu32 " illegal vs FragCount=%" PRIu32 ".", inPacketInfo.GetFragNo(),
229             inPacketInfo.GetFragCount());
230         return false;
231     }
232     if (inWork.status.IsFragNoAlreadyExist(inPacketInfo.GetFragNo())) {
233         LOGE("[Combiner][CheckPacket] FragNo=%" PRIu32 " already exist.", inPacketInfo.GetFragNo());
234         return false;
235     }
236     return true;
237 }
238 
CreateNewFrameBuffer(const ParseResult & inInfo)239 SerialBuffer *FrameCombiner::CreateNewFrameBuffer(const ParseResult &inInfo)
240 {
241     SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
242     if (buffer == nullptr) {
243         return nullptr;
244     }
245     uint32_t frameHeaderLength = (inInfo.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) ?
246         ProtocolProto::GetCommLayerFrameHeaderLength() : ProtocolProto::GetAppLayerFrameHeaderLength();
247     int errCode = buffer->AllocBufferByTotalLength(inInfo.GetFrameLen(), frameHeaderLength);
248     if (errCode != E_OK) {
249         LOGE("[Combiner][CreateBuffer] Alloc Buffer Fail.");
250         delete buffer;
251         buffer = nullptr;
252         return nullptr;
253     }
254     return buffer;
255 }
256 } // namespace DistributedDB