1 /*
2  * Copyright (c) 2019, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "cppbor.h"
18 #include "cppbor_parse.h"
19 
20 #define LOG_TAG "CppBor"
21 #include <android-base/logging.h>
22 
23 namespace cppbor {
24 
namespace {

/**
 * Writes the sizeof(T) bytes of an unsigned integer to an output iterator,
 * most significant byte first (big-endian).
 *
 * @param value Unsigned value to write.
 * @param pos   Output iterator that receives one uint8_t per byte.
 * @return The iterator positioned just past the last byte written.
 *
 * Note: the constraint must name enable_if<...>::type (here via enable_if_t).
 * Naming std::enable_if<cond> itself is always a valid type and would not
 * remove signed types from overload resolution.
 */
template <typename T, typename Iterator, typename = std::enable_if_t<std::is_unsigned<T>::value>>
Iterator writeBigEndian(T value, Iterator pos) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        // Emit the current top byte, then shift the next byte into the top position.
        *pos++ = static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1)));
        value = static_cast<T>(value << 8);
    }
    return pos;
}

/**
 * Callback flavour of writeBigEndian: passes the sizeof(T) bytes of value to
 * cb, most significant byte first.
 *
 * @param value Unsigned value to write.
 * @param cb    Sink invoked once per byte. It is taken by const reference
 *              because invoking a std::function does not modify it.
 */
template <typename T, typename = std::enable_if_t<std::is_unsigned<T>::value>>
void writeBigEndian(T value, const std::function<void(uint8_t)>& cb) {
    for (unsigned i = 0; i < sizeof(value); ++i) {
        cb(static_cast<uint8_t>(value >> (8 * (sizeof(value) - 1))));
        value = static_cast<T>(value << 8);
    }
}

}  // namespace
45 
headerSize(uint64_t addlInfo)46 size_t headerSize(uint64_t addlInfo) {
47     if (addlInfo < ONE_BYTE_LENGTH) return 1;
48     if (addlInfo <= std::numeric_limits<uint8_t>::max()) return 2;
49     if (addlInfo <= std::numeric_limits<uint16_t>::max()) return 3;
50     if (addlInfo <= std::numeric_limits<uint32_t>::max()) return 5;
51     return 9;
52 }
53 
encodeHeader(MajorType type,uint64_t addlInfo,uint8_t * pos,const uint8_t * end)54 uint8_t* encodeHeader(MajorType type, uint64_t addlInfo, uint8_t* pos, const uint8_t* end) {
55     size_t sz = headerSize(addlInfo);
56     if (end - pos < static_cast<ssize_t>(sz)) return nullptr;
57     switch (sz) {
58         case 1:
59             *pos++ = type | static_cast<uint8_t>(addlInfo);
60             return pos;
61         case 2:
62             *pos++ = type | ONE_BYTE_LENGTH;
63             *pos++ = static_cast<uint8_t>(addlInfo);
64             return pos;
65         case 3:
66             *pos++ = type | TWO_BYTE_LENGTH;
67             return writeBigEndian(static_cast<uint16_t>(addlInfo), pos);
68         case 5:
69             *pos++ = type | FOUR_BYTE_LENGTH;
70             return writeBigEndian(static_cast<uint32_t>(addlInfo), pos);
71         case 9:
72             *pos++ = type | EIGHT_BYTE_LENGTH;
73             return writeBigEndian(addlInfo, pos);
74         default:
75             CHECK(false);  // Impossible to get here.
76             return nullptr;
77     }
78 }
79 
encodeHeader(MajorType type,uint64_t addlInfo,EncodeCallback encodeCallback)80 void encodeHeader(MajorType type, uint64_t addlInfo, EncodeCallback encodeCallback) {
81     size_t sz = headerSize(addlInfo);
82     switch (sz) {
83         case 1:
84             encodeCallback(type | static_cast<uint8_t>(addlInfo));
85             break;
86         case 2:
87             encodeCallback(type | ONE_BYTE_LENGTH);
88             encodeCallback(static_cast<uint8_t>(addlInfo));
89             break;
90         case 3:
91             encodeCallback(type | TWO_BYTE_LENGTH);
92             writeBigEndian(static_cast<uint16_t>(addlInfo), encodeCallback);
93             break;
94         case 5:
95             encodeCallback(type | FOUR_BYTE_LENGTH);
96             writeBigEndian(static_cast<uint32_t>(addlInfo), encodeCallback);
97             break;
98         case 9:
99             encodeCallback(type | EIGHT_BYTE_LENGTH);
100             writeBigEndian(addlInfo, encodeCallback);
101             break;
102         default:
103             CHECK(false);  // Impossible to get here.
104     }
105 }
106 
operator ==(const Item & other) const107 bool Item::operator==(const Item& other) const& {
108     if (type() != other.type()) return false;
109     switch (type()) {
110         case UINT:
111             return *asUint() == *(other.asUint());
112         case NINT:
113             return *asNint() == *(other.asNint());
114         case BSTR:
115             return *asBstr() == *(other.asBstr());
116         case TSTR:
117             return *asTstr() == *(other.asTstr());
118         case ARRAY:
119             return *asArray() == *(other.asArray());
120         case MAP:
121             return *asMap() == *(other.asMap());
122         case SIMPLE:
123             return *asSimple() == *(other.asSimple());
124         case SEMANTIC:
125             return *asSemantic() == *(other.asSemantic());
126         default:
127             CHECK(false);  // Impossible to get here.
128             return false;
129     }
130 }
131 
/**
 * Constructs a negative-integer item holding v.
 *
 * Aborts via CHECK unless v is strictly negative.
 */
Nint::Nint(int64_t v) : mValue(v) {
    CHECK(v < 0) << "Only negative values allowed";
}
135 
operator ==(const Simple & other) const136 bool Simple::operator==(const Simple& other) const& {
137     if (simpleType() != other.simpleType()) return false;
138 
139     switch (simpleType()) {
140         case BOOLEAN:
141             return *asBool() == *(other.asBool());
142         case NULL_T:
143             return true;
144         default:
145             CHECK(false);  // Impossible to get here.
146             return false;
147     }
148 }
149 
encode(uint8_t * pos,const uint8_t * end) const150 uint8_t* Bstr::encode(uint8_t* pos, const uint8_t* end) const {
151     pos = encodeHeader(mValue.size(), pos, end);
152     if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
153     return std::copy(mValue.begin(), mValue.end(), pos);
154 }
155 
encodeValue(EncodeCallback encodeCallback) const156 void Bstr::encodeValue(EncodeCallback encodeCallback) const {
157     for (auto c : mValue) {
158         encodeCallback(c);
159     }
160 }
161 
encode(uint8_t * pos,const uint8_t * end) const162 uint8_t* Tstr::encode(uint8_t* pos, const uint8_t* end) const {
163     pos = encodeHeader(mValue.size(), pos, end);
164     if (!pos || end - pos < static_cast<ptrdiff_t>(mValue.size())) return nullptr;
165     return std::copy(mValue.begin(), mValue.end(), pos);
166 }
167 
encodeValue(EncodeCallback encodeCallback) const168 void Tstr::encodeValue(EncodeCallback encodeCallback) const {
169     for (auto c : mValue) {
170         encodeCallback(static_cast<uint8_t>(c));
171     }
172 }
173 
operator ==(const CompoundItem & other) const174 bool CompoundItem::operator==(const CompoundItem& other) const& {
175     return type() == other.type()             //
176            && addlInfo() == other.addlInfo()  //
177            // Can't use vector::operator== because the contents are pointers.  std::equal lets us
178            // provide a predicate that does the dereferencing.
179            && std::equal(mEntries.begin(), mEntries.end(), other.mEntries.begin(),
180                          [](auto& a, auto& b) -> bool { return *a == *b; });
181 }
182 
encode(uint8_t * pos,const uint8_t * end) const183 uint8_t* CompoundItem::encode(uint8_t* pos, const uint8_t* end) const {
184     pos = encodeHeader(addlInfo(), pos, end);
185     if (!pos) return nullptr;
186     for (auto& entry : mEntries) {
187         pos = entry->encode(pos, end);
188         if (!pos) return nullptr;
189     }
190     return pos;
191 }
192 
encode(EncodeCallback encodeCallback) const193 void CompoundItem::encode(EncodeCallback encodeCallback) const {
194     encodeHeader(addlInfo(), encodeCallback);
195     for (auto& entry : mEntries) {
196         entry->encode(encodeCallback);
197     }
198 }
199 
/**
 * Aborts via CHECK unless the map holds an even number of entries. Entries
 * are stored flat as alternating key/value items (see Map::clone).
 */
void Map::assertInvariant() const {
    CHECK(mEntries.size() % 2 == 0);
}
203 
clone() const204 std::unique_ptr<Item> Map::clone() const {
205     assertInvariant();
206     auto res = std::make_unique<Map>();
207     for (size_t i = 0; i < mEntries.size(); i += 2) {
208         res->add(mEntries[i]->clone(), mEntries[i + 1]->clone());
209     }
210     return res;
211 }
212 
clone() const213 std::unique_ptr<Item> Array::clone() const {
214     auto res = std::make_unique<Array>();
215     for (size_t i = 0; i < mEntries.size(); i++) {
216         res->add(mEntries[i]->clone());
217     }
218     return res;
219 }
220 
/**
 * Aborts via CHECK unless this Semantic item holds exactly one entry.
 */
void Semantic::assertInvariant() const {
    CHECK(mEntries.size() == 1);
}
224 
225 }  // namespace cppbor
226