1 /*
2 * Copyright (c) 2021 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #include "frame_combiner.h"
17 #include <set>
18 #include "log_print.h"
19 #include "protocol_proto.h"
20
21 namespace DistributedDB {
// Maximum number of unfinished CombineWork entries kept per source; when a new work pushes a
// source above this bound, the work with the lowest progress id for that source is aborted.
static constexpr uint32_t MAX_WORK_PER_SRC_TARGET = 1;
// Interval, in milliseconds, of the surveillance timer that aborts stalled combine works (10 s).
static constexpr int COMBINER_SURVAIL_PERIOD_IN_MILLISECOND = 10000;
24
Initialize()25 void FrameCombiner::Initialize()
26 {
27 RuntimeContext *context = RuntimeContext::GetInstance();
28 TimerAction action = [this](TimerId inTimerId)->int {
29 PeriodicalSurveillance();
30 return E_OK;
31 };
32 TimerFinalizer finalizer = [this]() {
33 timerRemovedIndicator_.SendSemaphore();
34 };
35 int errCode = context->SetTimer(COMBINER_SURVAIL_PERIOD_IN_MILLISECOND, action, finalizer, timerId_);
36 if (errCode != E_OK) {
37 LOGE("[Combiner][Init] Set timer fail, errCode=%d.", errCode);
38 return;
39 }
40 isTimerWork_ = true;
41 }
42
Finalize()43 void FrameCombiner::Finalize()
44 {
45 // First: Stop the timer
46 if (isTimerWork_) {
47 RuntimeContext *context = RuntimeContext::GetInstance();
48 context->RemoveTimer(timerId_);
49 timerRemovedIndicator_.WaitSemaphore();
50 }
51
52 // Second: Clear the combineWorkPool_
53 for (auto &eachSource : combineWorkPool_) {
54 for (auto &eachFrame : eachSource.second) {
55 delete eachFrame.second.buffer;
56 eachFrame.second.buffer = nullptr;
57 }
58 }
59 }
60
AssembleFrameFragment(const uint8_t * bytes,uint32_t length,const ParseResult & inPacketInfo,ParseResult & outFrameInfo,int & outErrorNo)61 SerialBuffer *FrameCombiner::AssembleFrameFragment(const uint8_t *bytes, uint32_t length,
62 const ParseResult &inPacketInfo, ParseResult &outFrameInfo, int &outErrorNo)
63 {
64 uint64_t sourceId = inPacketInfo.GetSourceId();
65 uint32_t frameId = inPacketInfo.GetFrameId();
66 std::lock_guard<std::mutex> overallLockGuard(overallMutex_);
67 if (combineWorkPool_[sourceId].count(frameId) != 0) {
68 // CombineWork already exist
69 int errCode = ContinueExistCombineWork(bytes, length, inPacketInfo);
70 if (errCode != E_OK) {
71 LOGE("[Combiner][Assemble] Continue work fail, errCode=%d.", errCode);
72 outErrorNo = errCode;
73 return nullptr;
74 }
75
76 if (combineWorkPool_[sourceId][frameId].status.IsCombineDone()) {
77 // We can parse the combined frame here, or outside this class.
78 LOGI("[Combiner][Assemble] Combine done, sourceId=%" PRIu64 ", frameId=%" PRIu32, ULL(sourceId), frameId);
79 SerialBuffer *outFrame = combineWorkPool_[sourceId][frameId].buffer;
80 outFrameInfo = combineWorkPool_[sourceId][frameId].frameInfo;
81 outErrorNo = E_OK;
82 combineWorkPool_[sourceId].erase(frameId);
83 return outFrame; // The caller is responsible for release the outFrame
84 }
85 } else {
86 // CombineWork not exist and even existing work number reaches the limitation. Try create work first.
87 int errCode = CreateNewCombineWork(bytes, length, inPacketInfo);
88 if (errCode != E_OK) {
89 LOGE("[Combiner][Assemble] Create work fail, errCode=%d.", errCode);
90 outErrorNo = errCode;
91 return nullptr;
92 }
93 // After successfully create work, the existing work number may exceed the limitation
94 // If so, choose one from works of this target with lowest progressId and abort it
95 if (combineWorkPool_[sourceId].size() > MAX_WORK_PER_SRC_TARGET) {
96 AbortCombineWorkBySource(sourceId);
97 }
98 }
99 outErrorNo = E_OK;
100 return nullptr;
101 }
102
PeriodicalSurveillance()103 void FrameCombiner::PeriodicalSurveillance()
104 {
105 std::lock_guard<std::mutex> overallLockGuard(overallMutex_);
106 for (auto &eachSource : combineWorkPool_) {
107 std::set<uint32_t> frameToAbort;
108 for (auto &eachFrame : eachSource.second) {
109 if (!eachFrame.second.status.CheckProgress()) {
110 LOGW("[Combiner][Surveil] Source=%" PRIu64 ", frame=%" PRIu32
111 " has no progress, this combine work will be aborted.", ULL(eachSource.first), eachFrame.first);
112 // Free this combine work first
113 delete eachFrame.second.buffer;
114 eachFrame.second.buffer = nullptr;
115 // Record this frame in abort list
116 frameToAbort.insert(eachFrame.first);
117 }
118 }
119 // Remove the combine work from map
120 for (auto &entry : frameToAbort) {
121 eachSource.second.erase(entry);
122 }
123 }
124 }
125
ContinueExistCombineWork(const uint8_t * bytes,uint32_t length,const ParseResult & inPacketInfo)126 int FrameCombiner::ContinueExistCombineWork(const uint8_t *bytes, uint32_t length, const ParseResult &inPacketInfo)
127 {
128 uint64_t sourceId = inPacketInfo.GetSourceId();
129 uint32_t frameId = inPacketInfo.GetFrameId();
130 CombineWork &oriWork = combineWorkPool_[sourceId][frameId]; // Be care here must be reference
131 if (!CheckPacketWithOriWork(inPacketInfo, oriWork)) {
132 LOGE("[Combiner][ContinueWork] Check packet fail, sourceId=%" PRIu64 ", frameId=%" PRIu32, sourceId, frameId);
133 return -E_COMBINE_FAIL;
134 }
135
136 uint32_t fragOffset = oriWork.status.GetThisFragmentOffset(inPacketInfo.GetFragNo());
137 uint32_t fragLength = oriWork.status.GetThisFragmentLength(inPacketInfo.GetFragNo());
138 int errCode = ProtocolProto::CombinePacketIntoFrame(oriWork.buffer, bytes, length, fragOffset, fragLength);
139 if (errCode != E_OK) {
140 // We can consider abort this work, but here we choose not to affect it
141 LOGE("[Combiner][ContinueWork] Combine packet fail, sourceId=%" PRIu64 ", frameId=%" PRIu32, sourceId, frameId);
142 return -E_COMBINE_FAIL;
143 }
144
145 oriWork.status.UpdateProgressId(incProgressId_++);
146 oriWork.status.CheckInFragmentNo(inPacketInfo.GetFragNo());
147 return E_OK;
148 }
149
CreateNewCombineWork(const uint8_t * bytes,uint32_t length,const ParseResult & inPacketInfo)150 int FrameCombiner::CreateNewCombineWork(const uint8_t *bytes, uint32_t length, const ParseResult &inPacketInfo)
151 {
152 uint32_t fragLen = 0;
153 uint32_t lastFragLen = 0;
154 int errCode = ProtocolProto::AnalyzeSplitStructure(inPacketInfo, fragLen, lastFragLen);
155 if (errCode != E_OK) {
156 LOGE("[Combiner][CreateWork] Analyze fail, errCode=%d.", errCode);
157 return errCode;
158 }
159
160 CombineWork work;
161
162 work.frameInfo.SetPacketLen(inPacketInfo.GetFrameLen());
163 work.frameInfo.SetSourceId(inPacketInfo.GetSourceId());
164 work.frameInfo.SetFrameId(inPacketInfo.GetFrameId());
165 work.frameInfo.SetFrameTypeInfo(inPacketInfo.GetFrameTypeInfo());
166 work.frameInfo.SetFrameLen(inPacketInfo.GetFrameLen());
167 work.frameInfo.SetFragCount(inPacketInfo.GetFragCount());
168
169 work.status.SetFragmentLen(fragLen);
170 work.status.SetLastFragmentLen(lastFragLen);
171 work.status.SetFragmentCount(inPacketInfo.GetFragCount());
172
173 work.buffer = CreateNewFrameBuffer(inPacketInfo);
174 if (work.buffer == nullptr) {
175 return -E_OUT_OF_MEMORY;
176 }
177
178 uint32_t fragOffset = work.status.GetThisFragmentOffset(inPacketInfo.GetFragNo());
179 uint32_t fragLength = work.status.GetThisFragmentLength(inPacketInfo.GetFragNo());
180 errCode = ProtocolProto::CombinePacketIntoFrame(work.buffer, bytes, length, fragOffset, fragLength);
181 if (errCode != E_OK) {
182 delete work.buffer;
183 work.buffer = nullptr;
184 return errCode;
185 }
186
187 totalSizeByByte_ += work.buffer->GetSize();
188 work.status.UpdateProgressId(incProgressId_++);
189 work.status.CheckInFragmentNo(inPacketInfo.GetFragNo());
190 combineWorkPool_[inPacketInfo.GetSourceId()][inPacketInfo.GetFrameId()] = work;
191 return E_OK;
192 }
193
AbortCombineWorkBySource(uint64_t inSourceId)194 void FrameCombiner::AbortCombineWorkBySource(uint64_t inSourceId)
195 {
196 if (combineWorkPool_[inSourceId].empty()) {
197 return;
198 }
199 uint32_t toBeAbortFrameId = 0;
200 uint64_t toBeAbortProgressId = UINT64_MAX;
201 for (auto &entry : combineWorkPool_[inSourceId]) {
202 if (entry.second.status.GetProgressId() < toBeAbortProgressId) {
203 toBeAbortProgressId = entry.second.status.GetProgressId();
204 toBeAbortFrameId = entry.first;
205 }
206 }
207 // Do Abort!
208 LOGW("[Combiner][AbortWork] Abort Incomplete CombineWork, sourceId=%" PRIu64 ", frameId=%" PRIu32 ".",
209 ULL(inSourceId), toBeAbortFrameId);
210 delete combineWorkPool_[inSourceId][toBeAbortFrameId].buffer;
211 combineWorkPool_[inSourceId][toBeAbortFrameId].buffer = nullptr;
212 combineWorkPool_[inSourceId].erase(toBeAbortFrameId);
213 }
214
CheckPacketWithOriWork(const ParseResult & inPacketInfo,const CombineWork & inWork)215 bool FrameCombiner::CheckPacketWithOriWork(const ParseResult &inPacketInfo, const CombineWork &inWork)
216 {
217 if (inPacketInfo.GetFrameLen() != inWork.frameInfo.GetFrameLen()) {
218 LOGE("[Combiner][CheckPacket] FrameLen mismatch %" PRIu32 " vs %" PRIu32 ".", inPacketInfo.GetFrameLen(),
219 inWork.frameInfo.GetFrameLen());
220 return false;
221 }
222 if (inPacketInfo.GetFragCount() != inWork.frameInfo.GetFragCount()) {
223 LOGE("[Combiner][CheckPacket] FragCount mismatch %" PRIu32 " vs %" PRIu32 ".", inPacketInfo.GetFragCount(),
224 inWork.frameInfo.GetFragCount());
225 return false;
226 }
227 if (inPacketInfo.GetFragNo() >= inPacketInfo.GetFragCount()) {
228 LOGE("[Combiner][CheckPacket] FragNo=%" PRIu32 " illegal vs FragCount=%" PRIu32 ".", inPacketInfo.GetFragNo(),
229 inPacketInfo.GetFragCount());
230 return false;
231 }
232 if (inWork.status.IsFragNoAlreadyExist(inPacketInfo.GetFragNo())) {
233 LOGE("[Combiner][CheckPacket] FragNo=%" PRIu32 " already exist.", inPacketInfo.GetFragNo());
234 return false;
235 }
236 return true;
237 }
238
CreateNewFrameBuffer(const ParseResult & inInfo)239 SerialBuffer *FrameCombiner::CreateNewFrameBuffer(const ParseResult &inInfo)
240 {
241 SerialBuffer *buffer = new (std::nothrow) SerialBuffer();
242 if (buffer == nullptr) {
243 return nullptr;
244 }
245 uint32_t frameHeaderLength = (inInfo.GetFrameTypeInfo() != FrameType::APPLICATION_MESSAGE) ?
246 ProtocolProto::GetCommLayerFrameHeaderLength() : ProtocolProto::GetAppLayerFrameHeaderLength();
247 int errCode = buffer->AllocBufferByTotalLength(inInfo.GetFrameLen(), frameHeaderLength);
248 if (errCode != E_OK) {
249 LOGE("[Combiner][CreateBuffer] Alloc Buffer Fail.");
250 delete buffer;
251 buffer = nullptr;
252 return nullptr;
253 }
254 return buffer;
255 }
256 } // namespace DistributedDB