1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "resolv"
18
19 #include <arpa/inet.h>
20
21 #include <chrono>
22
23 #include <android-base/logging.h>
24 #include <android-base/macros.h>
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 #include <netdutils/Slice.h>
28
29 #include "DnsTlsDispatcher.h"
30 #include "DnsTlsQueryMap.h"
31 #include "DnsTlsServer.h"
32 #include "DnsTlsSessionCache.h"
33 #include "DnsTlsSocket.h"
34 #include "DnsTlsTransport.h"
35 #include "Experiments.h"
36 #include "IDnsTlsSocket.h"
37 #include "IDnsTlsSocketFactory.h"
38 #include "IDnsTlsSocketObserver.h"
39 #include "tests/dns_responder/dns_tls_frontend.h"
40
41 namespace android {
42 namespace net {
43
44 using netdutils::makeSlice;
45 using netdutils::Slice;
46
// Experiment flag read (with DnsTlsQueryMap::kMaxTries as its default) to
// compute the expected number of connection attempts in retry tests.
static const std::string DOT_MAXTRIES_FLAG = "dot_maxtries";

// Byte buffer used for fake DNS queries and responses.
using bytevec = std::vector<uint8_t>;
50
parseServer(const char * server,in_port_t port,sockaddr_storage * parsed)51 static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
52 sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
53 if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
54 // IPv4 parse succeeded, so it's IPv4
55 sin->sin_family = AF_INET;
56 sin->sin_port = htons(port);
57 return;
58 }
59 sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
60 if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
61 // IPv6 parse succeeded, so it's IPv6.
62 sin6->sin6_family = AF_INET6;
63 sin6->sin6_port = htons(port);
64 return;
65 }
66 LOG(ERROR) << "Failed to parse server address: " << server;
67 }
68
// Example TLS server hostnames for the tests (SERVERNAME1 is assigned to
// BaseTest's SERVER1.name).
std::string SERVERNAME1("dns.example.com");
std::string SERVERNAME2("dns.example.org");
71
72 // BaseTest just provides constants that are useful for the tests.
73 class BaseTest : public ::testing::Test {
74 protected:
BaseTest()75 BaseTest() {
76 parseServer("192.0.2.1", 853, &V4ADDR1);
77 parseServer("192.0.2.2", 853, &V4ADDR2);
78 parseServer("2001:db8::1", 853, &V6ADDR1);
79 parseServer("2001:db8::2", 853, &V6ADDR2);
80
81 SERVER1 = DnsTlsServer(V4ADDR1);
82 SERVER1.name = SERVERNAME1;
83 }
84
85 sockaddr_storage V4ADDR1;
86 sockaddr_storage V4ADDR2;
87 sockaddr_storage V6ADDR1;
88 sockaddr_storage V6ADDR2;
89
90 DnsTlsServer SERVER1;
91 };
92
// Builds a fake DNS query of exactly `size` bytes (the same type as bytevec).
// Bytes 0-1 hold `id` in big-endian order. The remaining bytes are filled with
// the low byte of (id + index), so queries with different IDs have different
// bodies.
// Sizes smaller than the 2-byte header are truncated safely: only the header
// bytes that fit are written, instead of writing out of bounds.
std::vector<uint8_t> make_query(uint16_t id, size_t size) {
    std::vector<uint8_t> vec(size);
    if (size > 0) {
        vec[0] = static_cast<uint8_t>(id >> 8);
    }
    if (size > 1) {
        vec[1] = static_cast<uint8_t>(id);
    }
    // Arbitrarily fill the query body with unique data.
    for (size_t i = 2; i < size; ++i) {
        vec[i] = static_cast<uint8_t>(id + i);
    }
    return vec;
}
103
104 // Query constants
105 const unsigned NETID = 123;
106 const unsigned MARK = 123;
107 const uint16_t ID = 52;
108 const uint16_t SIZE = 22;
109 const bytevec QUERY = make_query(ID, SIZE);
110
111 template <class T>
112 class FakeSocketFactory : public IDnsTlsSocketFactory {
113 public:
FakeSocketFactory()114 FakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)115 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
116 const DnsTlsServer& server ATTRIBUTE_UNUSED,
117 unsigned mark ATTRIBUTE_UNUSED,
118 IDnsTlsSocketObserver* observer,
119 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
120 return std::make_unique<T>(observer);
121 }
122 };
123
make_echo(uint16_t id,const Slice query)124 bytevec make_echo(uint16_t id, const Slice query) {
125 bytevec response(query.size() + 2);
126 response[0] = id >> 8;
127 response[1] = id;
128 // Echo the query as the fake response.
129 memcpy(response.data() + 2, query.base(), query.size());
130 return response;
131 }
132
133 // Simplest possible fake server. This just echoes the query as the response.
134 class FakeSocketEcho : public IDnsTlsSocket {
135 public:
FakeSocketEcho(IDnsTlsSocketObserver * observer)136 explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query)137 bool query(uint16_t id, const Slice query) override {
138 // Return the response immediately (asynchronously).
139 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
140 return true;
141 }
startHandshake()142 bool startHandshake() override { return true; }
143
144 private:
145 IDnsTlsSocketObserver* const mObserver;
146 };
147
148 class TransportTest : public BaseTest {};
149
// A single query against an echo server succeeds, returns the query bytes
// unchanged, and uses exactly one connection.
TEST_F(TransportTest, Query) {
    FakeSocketFactory<FakeSocketEcho> echoFactory;
    DnsTlsTransport transport(SERVER1, MARK, &echoFactory);

    const auto result = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
    EXPECT_EQ(QUERY, result.response);

    EXPECT_EQ(1, transport.getConnectCounter());
}
159
160 // Fake Socket that echoes the observed query ID as the response body.
161 class FakeSocketId : public IDnsTlsSocket {
162 public:
FakeSocketId(IDnsTlsSocketObserver * observer)163 explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
query(uint16_t id,const Slice query ATTRIBUTE_UNUSED)164 bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
165 // Return the response immediately (asynchronously).
166 bytevec response(4);
167 // Echo the ID in the header to match the response to the query.
168 // This will be overwritten by DnsTlsQueryMap.
169 response[0] = id >> 8;
170 response[1] = id;
171 // Echo the ID in the body, so that the test can verify which ID was used by
172 // DnsTlsQueryMap.
173 response[2] = id >> 8;
174 response[3] = id;
175 std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
176 return true;
177 }
startHandshake()178 bool startHandshake() override { return true; }
179
180 private:
181 IDnsTlsSocketObserver* const mObserver;
182 };
183
184 // Test that IDs are properly reused
TEST_F(TransportTest,IdReuse)185 TEST_F(TransportTest, IdReuse) {
186 FakeSocketFactory<FakeSocketId> factory;
187 DnsTlsTransport transport(SERVER1, MARK, &factory);
188 for (int i = 0; i < 100; ++i) {
189 // Send a query.
190 std::future<DnsTlsTransport::Result> f = transport.query(makeSlice(QUERY));
191 // Wait for the response.
192 DnsTlsTransport::Result r = f.get();
193 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
194
195 // All queries should have an observed ID of zero, because it is returned to the ID pool
196 // after each use.
197 EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
198 }
199 EXPECT_EQ(transport.getConnectCounter(), 1);
200 }
201
202 // These queries might be handled in serial or parallel as they race the
203 // responses.
TEST_F(TransportTest,RacingQueries_10000)204 TEST_F(TransportTest, RacingQueries_10000) {
205 FakeSocketFactory<FakeSocketEcho> factory;
206 DnsTlsTransport transport(SERVER1, MARK, &factory);
207 std::vector<std::future<DnsTlsTransport::Result>> results;
208 // Fewer than 65536 queries to avoid ID exhaustion.
209 const int num_queries = 10000;
210 results.reserve(num_queries);
211 for (int i = 0; i < num_queries; ++i) {
212 results.push_back(transport.query(makeSlice(QUERY)));
213 }
214 for (auto& result : results) {
215 auto r = result.get();
216 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
217 EXPECT_EQ(QUERY, r.response);
218 }
219 EXPECT_EQ(transport.getConnectCounter(), 1);
220 }
221
// A server that waits until sDelay queries are queued before responding.
//
// Each query's echo response (make_echo) is buffered. When the number of
// buffered responses reaches sDelay, a detached thread delivers all of them to
// the observer (in reverse order if sReverse is set) and clears the buffer.
// startHandshake() returns sConnectable, so tests can simulate a handshake
// failure. The static knobs are shared by all instances, and any instance's
// destructor resets them to their defaults (1, false, true).
//
// NOTE(review): the sendResponses() thread is detached, captures `this`, and is
// not joined by the destructor. This presumably relies on the transport keeping
// the socket alive until responses are delivered. Confirm.
// NOTE(review): tests write the statics without holding mLock, and mLock is
// per-instance, so it does not serialize access to the statics across
// instances.
class FakeSocketDelay : public IDnsTlsSocket {
  public:
    explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
    // Restores the shared configuration to its defaults for subsequent tests.
    ~FakeSocketDelay() {
        std::lock_guard guard(mLock);
        sDelay = 1;
        sReverse = false;
        sConnectable = true;
    }
    // Number of queries to buffer before responding.
    inline static size_t sDelay = 1;
    // Whether buffered responses are delivered in reverse arrival order.
    inline static bool sReverse = false;
    // Value returned by startHandshake().
    inline static bool sConnectable = true;

    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketDelay got query with ID " << int(id);
        std::lock_guard guard(mLock);
        // Check for duplicate IDs.
        EXPECT_EQ(0U, mIds.count(id));
        mIds.insert(id);

        // Store response.
        mResponses.push_back(make_echo(id, query));

        LOG(DEBUG) << "Up to " << mResponses.size() << " out of " << sDelay << " queries";
        // Flush only when exactly sDelay responses are buffered. The flush runs on
        // another thread because sendResponses() takes mLock, which is held here.
        if (mResponses.size() == sDelay) {
            std::thread(&FakeSocketDelay::sendResponses, this).detach();
        }
        return true;
    }
    bool startHandshake() override { return sConnectable; }

  private:
    // Delivers every buffered response to the observer (while holding mLock),
    // then clears the buffered responses and the set of seen IDs.
    void sendResponses() {
        std::lock_guard guard(mLock);
        if (sReverse) {
            std::reverse(std::begin(mResponses), std::end(mResponses));
        }
        for (auto& response : mResponses) {
            mObserver->onResponse(response);
        }
        mIds.clear();
        mResponses.clear();
    }

    std::mutex mLock;
    IDnsTlsSocketObserver* const mObserver;
    // IDs seen since the last flush; used to detect duplicate IDs.
    std::set<uint16_t> mIds GUARDED_BY(mLock);
    // Echo responses waiting to be delivered.
    std::vector<bytevec> mResponses GUARDED_BY(mLock);
};
272
// Queue sDelay identical queries (all with the same original ID) before the
// fake server answers any of them. Every caller must still get its echoed
// response back.
TEST_F(TransportTest, ParallelColliding) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    // Fewer than 65536 queries to avoid ID exhaustion.
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t sent = 0; sent < FakeSocketDelay::sDelay; ++sent) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& pendingResult : pending) {
        const auto result = pendingResult.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
291
// Exactly 65536 queries should still be possible in parallel, even if they
// all have the same original ID.
TEST_F(TransportTest, ParallelColliding_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t sent = 0; sent < FakeSocketDelay::sDelay; ++sent) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }

    for (auto& pendingResult : pending) {
        const auto result = pendingResult.get();
        EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
        EXPECT_EQ(QUERY, result.response);
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
311
// Queue sDelay queries with distinct IDs before the server answers any of
// them. Each response must come back to the query that produced it.
TEST_F(TransportTest, ParallelUnique) {
    FakeSocketDelay::sDelay = 10;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    std::vector<bytevec> sent(FakeSocketDelay::sDelay);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t idx = 0; idx < FakeSocketDelay::sDelay; ++idx) {
        sent[idx] = make_query(idx, SIZE);
        pending.push_back(transport.query(makeSlice(sent[idx])));
    }

    for (size_t idx = 0; idx < FakeSocketDelay::sDelay; ++idx) {
        const auto reply = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, reply.code);
        EXPECT_EQ(sent[idx], reply.response);
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
331
// Exactly 65536 queries should still be possible in parallel, and each should
// be mapped back correctly to its original ID.
TEST_F(TransportTest, ParallelUnique_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    std::vector<bytevec> sent(FakeSocketDelay::sDelay);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t idx = 0; idx < FakeSocketDelay::sDelay; ++idx) {
        sent[idx] = make_query(idx, SIZE);
        pending.push_back(transport.query(makeSlice(sent[idx])));
    }

    for (size_t idx = 0; idx < FakeSocketDelay::sDelay; ++idx) {
        const auto reply = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, reply.code);
        EXPECT_EQ(sent[idx], reply.response);
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
353
// Fill the entire 16-bit ID space with outstanding queries. One more query must
// then fail immediately with internal_error, while the others stay pending.
TEST_F(TransportTest, IdExhaustion) {
    constexpr int kMaxOutstanding = 65536;
    // sDelay is one more than the number of queries that can ever be
    // outstanding, so the fake server never responds.
    FakeSocketDelay::sDelay = kMaxOutstanding + 1;
    FakeSocketDelay::sReverse = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    std::vector<std::future<DnsTlsTransport::Result>> outstanding;
    outstanding.reserve(kMaxOutstanding);
    for (int sent = 0; sent < kMaxOutstanding; ++sent) {
        outstanding.push_back(transport.query(makeSlice(QUERY)));
    }

    // The ID space is now full, so an extra query should fail right away.
    const auto overflow = transport.query(makeSlice(QUERY)).get();
    EXPECT_EQ(DnsTlsTransport::Response::internal_error, overflow.code);
    EXPECT_TRUE(overflow.response.empty());

    // None of the earlier queries may have completed.
    for (auto& pendingResult : outstanding) {
        EXPECT_EQ(std::future_status::timeout,
                  pendingResult.wait_for(std::chrono::seconds::zero()));
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
381
382 // Responses can come back from the server in any order. This should have no
383 // effect on Transport's observed behavior.
TEST_F(TransportTest,ReverseOrder)384 TEST_F(TransportTest, ReverseOrder) {
385 FakeSocketDelay::sDelay = 10;
386 FakeSocketDelay::sReverse = true;
387 FakeSocketFactory<FakeSocketDelay> factory;
388 DnsTlsTransport transport(SERVER1, MARK, &factory);
389 std::vector<bytevec> queries(FakeSocketDelay::sDelay);
390 std::vector<std::future<DnsTlsTransport::Result>> results;
391 results.reserve(FakeSocketDelay::sDelay);
392 for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
393 queries[i] = make_query(i, SIZE);
394 results.push_back(transport.query(makeSlice(queries[i])));
395 }
396 for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
397 auto r = results[i].get();
398 EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
399 EXPECT_EQ(queries[i], r.response);
400 }
401 EXPECT_EQ(transport.getConnectCounter(), 1);
402 }
403
// Same as ReverseOrder, but with the full 65536 outstanding queries.
TEST_F(TransportTest, ReverseOrder_Max) {
    FakeSocketDelay::sDelay = 65536;
    FakeSocketDelay::sReverse = true;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport transport(SERVER1, MARK, &delayFactory);

    std::vector<bytevec> sent(FakeSocketDelay::sDelay);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketDelay::sDelay);
    for (size_t idx = 0; idx < FakeSocketDelay::sDelay; ++idx) {
        sent[idx] = make_query(idx, SIZE);
        pending.push_back(transport.query(makeSlice(sent[idx])));
    }

    for (size_t idx = 0; idx < FakeSocketDelay::sDelay; ++idx) {
        const auto reply = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, reply.code);
        EXPECT_EQ(sent[idx], reply.response);
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
423
424 // Returning null from the factory indicates a connection failure.
425 class NullSocketFactory : public IDnsTlsSocketFactory {
426 public:
NullSocketFactory()427 NullSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server ATTRIBUTE_UNUSED,unsigned mark ATTRIBUTE_UNUSED,IDnsTlsSocketObserver * observer ATTRIBUTE_UNUSED,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)428 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
429 const DnsTlsServer& server ATTRIBUTE_UNUSED,
430 unsigned mark ATTRIBUTE_UNUSED,
431 IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
432 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
433 return nullptr;
434 }
435 };
436
// A connection can fail because the factory cannot create a socket, or because
// the handshake fails. Both must surface as network_error with an empty
// response, after exactly one connection attempt.
TEST_F(TransportTest, ConnectFail) {
    // Case 1: the factory returns no socket.
    NullSocketFactory nullFactory;
    DnsTlsTransport noSocketTransport(SERVER1, MARK, &nullFactory);
    auto result = noSocketTransport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
    EXPECT_EQ(1, noSocketTransport.getConnectCounter());

    // Case 2: the socket is created but startHandshake() fails.
    FakeSocketDelay::sConnectable = false;
    FakeSocketFactory<FakeSocketDelay> delayFactory;
    DnsTlsTransport handshakeFailTransport(SERVER1, MARK, &delayFactory);
    result = handshakeFailTransport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());
    EXPECT_EQ(1, handshakeFailTransport.getConnectCounter());
}
457
458 // Simulate a socket that connects but then immediately receives a server
459 // close notification.
460 class FakeSocketClose : public IDnsTlsSocket {
461 public:
FakeSocketClose(IDnsTlsSocketObserver * observer)462 explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
463 : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
~FakeSocketClose()464 ~FakeSocketClose() { mCloser.join(); }
query(uint16_t id ATTRIBUTE_UNUSED,const Slice query ATTRIBUTE_UNUSED)465 bool query(uint16_t id ATTRIBUTE_UNUSED,
466 const Slice query ATTRIBUTE_UNUSED) override {
467 return true;
468 }
startHandshake()469 bool startHandshake() override { return true; }
470
471 private:
472 std::thread mCloser;
473 };
474
// Every socket closes immediately, so the query is never answered and ends in
// network_error with an empty response.
TEST_F(TransportTest, CloseRetryFail) {
    FakeSocketFactory<FakeSocketClose> closingFactory;
    DnsTlsTransport transport(SERVER1, MARK, &closingFactory);
    const auto result = transport.query(makeSlice(QUERY)).get();

    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
    EXPECT_TRUE(result.response.empty());

    // Reconnections might be triggered depending on the flag. The expected
    // connection count is the dot_maxtries flag value, which defaults to
    // DnsTlsQueryMap::kMaxTries.
    const auto expectedConnects =
            Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries);
    EXPECT_EQ(transport.getConnectCounter(), expectedConnects);
}
487
// Simulate a server that occasionally closes the connection and silently
// drops some queries.
//
// Each socket echoes back (on its own thread) at most the first sLimit queries
// it receives, skipping any whose size exceeds sMaxSize. The sLimit-th query
// starts a closer thread. That thread joins all pending response threads and
// then calls observer->onClosed(). query() returns false for every query past
// sLimit.
class FakeSocketLimited : public IDnsTlsSocket {
  public:
    static int sLimit;  // Number of queries to answer per socket.
    static size_t sMaxSize;  // Silently discard queries greater than this size.
    explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
        : mObserver(observer), mQueries(0) {}
    // Joins all response threads while holding mLock. Then joins the closer
    // thread (if one was started) after releasing mLock, because sendClose()
    // also acquires mLock.
    // NOTE(review): mCloser is read here without mLock despite GUARDED_BY;
    // presumably fine because only query() assigns it. Confirm.
    ~FakeSocketLimited() {
        {
            LOG(DEBUG) << "~FakeSocketLimited acquiring mLock";
            std::lock_guard guard(mLock);
            LOG(DEBUG) << "~FakeSocketLimited acquired mLock";
            for (auto& thread : mThreads) {
                LOG(DEBUG) << "~FakeSocketLimited joining response thread";
                thread.join();
                LOG(DEBUG) << "~FakeSocketLimited joined response thread";
            }
            mThreads.clear();
        }

        if (mCloser) {
            LOG(DEBUG) << "~FakeSocketLimited joining closer thread";
            mCloser->join();
            LOG(DEBUG) << "~FakeSocketLimited joined closer thread";
        }
    }
    bool query(uint16_t id, const Slice query) override {
        LOG(DEBUG) << "FakeSocketLimited::query acquiring mLock";
        std::lock_guard guard(mLock);
        LOG(DEBUG) << "FakeSocketLimited::query acquired mLock";
        ++mQueries;

        // Only the first sLimit queries on this socket are eligible for an answer.
        if (mQueries <= sLimit) {
            LOG(DEBUG) << "size " << query.size() << " vs. limit of " << sMaxSize;
            // Queries larger than sMaxSize are silently dropped (no response thread).
            if (query.size() <= sMaxSize) {
                // Return the response immediately (asynchronously).
                mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
            }
        }
        // The sLimit-th query triggers the simulated server-side close.
        if (mQueries == sLimit) {
            mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
        }
        // Report failure for every query beyond the limit.
        return mQueries <= sLimit;
    }
    bool startHandshake() override { return true; }

  private:
    // Runs on the closer thread. Joins every response thread started so far, so
    // their onResponse() calls finish first, then notifies the observer that
    // the connection closed. onClosed() is called after mLock is released.
    void sendClose() {
        {
            LOG(DEBUG) << "FakeSocketLimited::sendClose acquiring mLock";
            std::lock_guard guard(mLock);
            LOG(DEBUG) << "FakeSocketLimited::sendClose acquired mLock";
            for (auto& thread : mThreads) {
                LOG(DEBUG) << "FakeSocketLimited::sendClose joining response thread";
                thread.join();
                LOG(DEBUG) << "FakeSocketLimited::sendClose joined response thread";
            }
            mThreads.clear();
        }
        mObserver->onClosed();
    }
    std::mutex mLock;
    IDnsTlsSocketObserver* const mObserver;
    // Number of queries received by this socket so far.
    int mQueries GUARDED_BY(mLock);
    // Threads delivering echo responses that have not been joined yet.
    std::vector<std::thread> mThreads GUARDED_BY(mLock);
    // Closer thread; set once the sLimit-th query arrives.
    std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
};
556
557 int FakeSocketLimited::sLimit;
558 size_t FakeSocketLimited::sMaxSize;
559
// The server silently drops every query and closes the socket after sLimit of
// them. The transport keeps retrying until each query hits the retry limit, so
// every query ends in network_error.
TEST_F(TransportTest, SilentDrop) {
    FakeSocketLimited::sLimit = 10;   // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = 0;  // Silently drop all queries
    FakeSocketFactory<FakeSocketLimited> limitedFactory;
    DnsTlsTransport transport(SERVER1, MARK, &limitedFactory);

    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(FakeSocketLimited::sLimit);
    for (int sent = 0; sent < FakeSocketLimited::sLimit; ++sent) {
        pending.push_back(transport.query(makeSlice(QUERY)));
    }
    for (auto& pendingResult : pending) {
        const auto result = pendingResult.get();
        EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
        EXPECT_TRUE(result.response.empty());
    }

    // Reconnections might be triggered depending on the flag. The expected
    // connection count is the dot_maxtries flag value, which defaults to
    // DnsTlsQueryMap::kMaxTries.
    EXPECT_EQ(transport.getConnectCounter(),
              Experiments::getInstance()->getFlag(DOT_MAXTRIES_FLAG, DnsTlsQueryMap::kMaxTries));
}
584
// The server answers "short" queries, silently drops "long" ones, and closes
// the socket every sLimit queries. All short queries must still succeed.
TEST_F(TransportTest, PartialDrop) {
    FakeSocketLimited::sLimit = 10;          // Close the socket after 10 queries.
    FakeSocketLimited::sMaxSize = SIZE - 2;  // Silently drop "long" queries
    FakeSocketFactory<FakeSocketLimited> limitedFactory;
    DnsTlsTransport transport(SERVER1, MARK, &limitedFactory);

    // Queue up 10 * sLimit (= 100) queries, alternating between "short" ones
    // (even indices, SIZE bytes), which are served, and "long" ones (odd
    // indices, SIZE + 1 bytes), which are dropped.
    const int total = 10 * FakeSocketLimited::sLimit;
    std::vector<bytevec> sent(total);
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(total);
    for (int idx = 0; idx < total; ++idx) {
        const bool isLong = (idx % 2) != 0;
        sent[idx] = make_query(idx, SIZE + (isLong ? 1 : 0));
        pending.push_back(transport.query(makeSlice(sent[idx])));
    }

    // Only the short queries (even indices) are checked.
    for (int idx = 0; idx < total; idx += 2) {
        const auto reply = pending[idx].get();
        EXPECT_EQ(DnsTlsTransport::Response::success, reply.code);
        EXPECT_EQ(sent[idx], reply.response);
    }

    // TODO: transport.getConnectCounter() seems not stable in this test. Find how to check the
    // connect attempts for this test.
}
611
// The server closes the connection after every sLimit queries. The transport
// connects on demand, and the total number of connections equals
// queries / sLimit.
TEST_F(TransportTest, ConnectCounter) {
    FakeSocketLimited::sLimit = 2;       // Close the socket after 2 queries.
    FakeSocketLimited::sMaxSize = SIZE;  // No query drops.
    FakeSocketFactory<FakeSocketLimited> limitedFactory;
    DnsTlsTransport transport(SERVER1, MARK, &limitedFactory);

    // Connecting on demand: nothing is connected before the first query.
    EXPECT_EQ(0, transport.getConnectCounter());

    constexpr int kQueries = 10;
    std::vector<std::future<DnsTlsTransport::Result>> pending;
    pending.reserve(kQueries);
    for (int sent = 0; sent < kQueries; ++sent) {
        // Reconnections take place every two queries.
        pending.push_back(transport.query(makeSlice(QUERY)));
    }
    for (auto& pendingResult : pending) {
        EXPECT_EQ(DnsTlsTransport::Response::success, pendingResult.get().code);
    }

    EXPECT_EQ(kQueries / FakeSocketLimited::sLimit, transport.getConnectCounter());
}
635
636 // Simulate a malfunctioning server that injects extra miscellaneous
637 // responses to queries that were not asked. This will cause wrong answers but
638 // must not crash the Transport.
639 class FakeSocketGarbage : public IDnsTlsSocket {
640 public:
FakeSocketGarbage(IDnsTlsSocketObserver * observer)641 explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
642 // Inject a garbage event.
643 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
644 }
~FakeSocketGarbage()645 ~FakeSocketGarbage() {
646 std::lock_guard guard(mLock);
647 for (auto& thread : mThreads) {
648 thread.join();
649 }
650 }
query(uint16_t id,const Slice query)651 bool query(uint16_t id, const Slice query) override {
652 std::lock_guard guard(mLock);
653 // Return the response twice.
654 auto echo = make_echo(id, query);
655 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
656 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
657 // Also return some other garbage
658 mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
659 return true;
660 }
startHandshake()661 bool startHandshake() override { return true; }
662
663 private:
664 std::mutex mLock;
665 std::vector<std::thread> mThreads GUARDED_BY(mLock);
666 IDnsTlsSocketObserver* const mObserver;
667 };
668
// The server sends duplicate and unsolicited responses. Every query must still
// report success, and only one connection is made.
TEST_F(TransportTest, IgnoringGarbage) {
    FakeSocketFactory<FakeSocketGarbage> garbageFactory;
    DnsTlsTransport transport(SERVER1, MARK, &garbageFactory);
    for (int attempt = 0; attempt < 10; ++attempt) {
        // Don't check the response because this server is malfunctioning.
        EXPECT_EQ(DnsTlsTransport::Response::success,
                  transport.query(makeSlice(QUERY)).get().code);
    }
    EXPECT_EQ(1, transport.getConnectCounter());
}
680
681 // Dispatcher tests
682 class DispatcherTest : public BaseTest {};
683
// The first query through the dispatcher opens a connection and echoes the
// query back. A second query to the same server reuses that connection.
TEST_F(DispatcherTest, Query) {
    bytevec answer(4096);
    int answerLen = 0;
    bool connected = false;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());

    auto code = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(answer),
                                 &answerLen, &connected);
    EXPECT_EQ(DnsTlsTransport::Response::success, code);
    EXPECT_EQ(int(QUERY.size()), answerLen);
    EXPECT_TRUE(connected);
    answer.resize(answerLen);
    EXPECT_EQ(QUERY, answer);

    // Expect to reuse the connection.
    code = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY), makeSlice(answer),
                            &answerLen, &connected);
    EXPECT_EQ(DnsTlsTransport::Response::success, code);
    EXPECT_FALSE(connected);
}
706
// An answer that does not fit in the caller's buffer yields limit_error. The
// connection is still reported as triggered.
TEST_F(DispatcherTest, AnswerTooLarge) {
    // One byte too small to hold the echoed answer.
    bytevec tooSmall(SIZE - 1);
    int answerLen = 0;
    bool connected = false;

    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
    const auto code = dispatcher.query(SERVER1, NETID, MARK, makeSlice(QUERY),
                                       makeSlice(tooSmall), &answerLen, &connected);

    EXPECT_EQ(DnsTlsTransport::Response::limit_error, code);
    EXPECT_TRUE(connected);
}
720
721 template<class T>
722 class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
723 public:
TrackingFakeSocketFactory()724 TrackingFakeSocketFactory() {}
createDnsTlsSocket(const DnsTlsServer & server,unsigned mark,IDnsTlsSocketObserver * observer,DnsTlsSessionCache * cache ATTRIBUTE_UNUSED)725 std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
726 const DnsTlsServer& server,
727 unsigned mark,
728 IDnsTlsSocketObserver* observer,
729 DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
730 std::lock_guard guard(mLock);
731 keys.emplace(mark, server);
732 return std::make_unique<T>(observer);
733 }
734 std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
735
736 private:
737 std::mutex mLock;
738 };
739
// Concurrent queries spread over four (mark, server) keys must all succeed, and
// the dispatcher must create exactly one socket per key.
TEST_F(DispatcherTest, Dispatching) {
    FakeSocketDelay::sDelay = 5;
    FakeSocketDelay::sReverse = true;
    auto trackingFactory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
    // Non-owning; valid as long as `dispatcher`, which owns the factory, is in scope.
    auto* tracker = trackingFactory.get();
    DnsTlsDispatcher dispatcher(std::move(trackingFactory));

    // Two servers x two socket marks = four distinct keys.
    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
    keys.emplace_back(MARK, SERVER1);
    keys.emplace_back(MARK + 1, SERVER1);
    keys.emplace_back(MARK, V4ADDR2);
    keys.emplace_back(MARK + 1, V4ADDR2);

    // Issue several queries per key, each from its own thread.
    std::vector<std::thread> workers;
    for (size_t n = 0; n < FakeSocketDelay::sDelay * keys.size(); ++n) {
        auto key = keys[n % keys.size()];
        workers.emplace_back([key, n, &dispatcher] {
            bytevec sent = make_query(n, SIZE);
            bytevec answer(4096);
            int answerLen = 0;
            bool connected = false;
            // The mark doubles as the network ID.
            const unsigned mark = key.first;
            const unsigned netId = key.first;
            const DnsTlsServer& server = key.second;
            const auto code = dispatcher.query(server, netId, mark, makeSlice(sent),
                                               makeSlice(answer), &answerLen, &connected);
            EXPECT_EQ(DnsTlsTransport::Response::success, code);
            EXPECT_EQ(int(sent.size()), answerLen);
            answer.resize(answerLen);
            EXPECT_EQ(sent, answer);
        });
    }
    for (std::thread& worker : workers) {
        worker.join();
    }

    // We expect that the factory created one socket for each key.
    EXPECT_EQ(keys.size(), tracker->keys.size());
    for (const auto& key : keys) {
        EXPECT_EQ(1U, tracker->keys.count(key));
    }
}
784
785 // Check DnsTlsServer's comparison logic.
786 AddressComparator ADDRESS_COMPARATOR;
isAddressEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)787 bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
788 bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
789 bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
790 EXPECT_FALSE(cmp1 && cmp2);
791 return !cmp1 && !cmp2;
792 }
793
checkUnequal(const DnsTlsServer & s1,const DnsTlsServer & s2)794 void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
795 EXPECT_TRUE(s1 == s1);
796 EXPECT_TRUE(s2 == s2);
797 EXPECT_TRUE(isAddressEqual(s1, s1));
798 EXPECT_TRUE(isAddressEqual(s2, s2));
799
800 EXPECT_TRUE(s1 < s2 ^ s2 < s1);
801 EXPECT_FALSE(s1 == s2);
802 EXPECT_FALSE(s2 == s1);
803 }
804
checkEqual(const DnsTlsServer & s1,const DnsTlsServer & s2)805 void checkEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
806 EXPECT_TRUE(s1 == s1);
807 EXPECT_TRUE(s2 == s2);
808 EXPECT_TRUE(isAddressEqual(s1, s1));
809 EXPECT_TRUE(isAddressEqual(s2, s2));
810
811 EXPECT_FALSE(s1 < s2);
812 EXPECT_FALSE(s2 < s1);
813 EXPECT_TRUE(s1 == s2);
814 EXPECT_TRUE(s2 == s1);
815 }
816
817 class ServerTest : public BaseTest {};
818
// Two different IPv4 addresses are different servers and different addresses.
TEST_F(ServerTest, IPv4) {
    EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
    checkUnequal(V4ADDR1, V4ADDR2);
}
823
// Two different IPv6 addresses are different servers and different addresses.
TEST_F(ServerTest, IPv6) {
    EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
    checkUnequal(V6ADDR1, V6ADDR2);
}
828
TEST_F(ServerTest, MixedAddressFamily) {
    // An IPv6 server and an IPv4 server can never be equal, nor share an address.
    const DnsTlsServer& v6 = V6ADDR1;
    const DnsTlsServer& v4 = V4ADDR1;
    checkUnequal(v6, v4);
    EXPECT_FALSE(isAddressEqual(v6, v4));
}
833
TEST_F(ServerTest, IPv6ScopeId) {
    // Same IPv6 address but different scope IDs: both the servers and their
    // addresses must compare unequal.
    DnsTlsServer scoped1(V6ADDR1);
    DnsTlsServer scoped2(V6ADDR1);
    reinterpret_cast<sockaddr_in6*>(&scoped1.ss)->sin6_scope_id = 1;
    reinterpret_cast<sockaddr_in6*>(&scoped2.ss)->sin6_scope_id = 2;

    checkUnequal(scoped1, scoped2);
    EXPECT_FALSE(isAddressEqual(scoped1, scoped2));

    // Neither server was given a name; both report not explicitly configured.
    EXPECT_FALSE(scoped1.wasExplicitlyConfigured());
    EXPECT_FALSE(scoped2.wasExplicitlyConfigured());
}
846
TEST_F(ServerTest, IPv6FlowInfo) {
    // Same IPv6 address, different flow labels.
    DnsTlsServer flow1(V6ADDR1);
    DnsTlsServer flow2(V6ADDR1);
    reinterpret_cast<sockaddr_in6*>(&flow1.ss)->sin6_flowinfo = 1;
    reinterpret_cast<sockaddr_in6*>(&flow2.ss)->sin6_flowinfo = 2;

    // Flow info is ignored by every comparison, so the servers stay equal.
    EXPECT_EQ(flow1, flow2);
    EXPECT_TRUE(isAddressEqual(flow1, flow2));

    // Neither server was given a name; both report not explicitly configured.
    EXPECT_FALSE(flow1.wasExplicitlyConfigured());
    EXPECT_FALSE(flow2.wasExplicitlyConfigured());
}
860
TEST_F(ServerTest, Port) {
    // Same IPv4 address on different ports: distinct servers that share one address.
    DnsTlsServer s1, s2;
    parseServer("192.0.2.1", 853, &s1.ss);
    parseServer("192.0.2.1", 854, &s2.ss);
    checkUnequal(s1, s2);
    EXPECT_TRUE(isAddressEqual(s1, s2));
    EXPECT_EQ(s1.toIpString(), "192.0.2.1");
    EXPECT_EQ(s2.toIpString(), "192.0.2.1");

    // Same check for an IPv6 address.
    DnsTlsServer s3, s4;
    parseServer("2001:db8::1", 853, &s3.ss);
    parseServer("2001:db8::1", 852, &s4.ss);
    checkUnequal(s3, s4);
    EXPECT_TRUE(isAddressEqual(s3, s4));
    EXPECT_EQ(s3.toIpString(), "2001:db8::1");
    EXPECT_EQ(s4.toIpString(), "2001:db8::1");

    // None of these servers was given a name, so none should report being explicitly
    // configured. The IPv6 pair was previously left unchecked.
    EXPECT_FALSE(s1.wasExplicitlyConfigured());
    EXPECT_FALSE(s2.wasExplicitlyConfigured());
    EXPECT_FALSE(s3.wasExplicitlyConfigured());
    EXPECT_FALSE(s4.wasExplicitlyConfigured());
}
881
TEST_F(ServerTest, Name) {
    DnsTlsServer named1(V4ADDR1);
    DnsTlsServer named2(V4ADDR1);

    // Naming only one of two otherwise-identical servers makes them unequal.
    named1.name = SERVERNAME1;
    checkUnequal(named1, named2);

    // Two different names: still unequal servers, but the address is shared.
    named2.name = SERVERNAME2;
    checkUnequal(named1, named2);
    EXPECT_TRUE(isAddressEqual(named1, named2));

    // Both servers now carry a name and report being explicitly configured.
    EXPECT_TRUE(named1.wasExplicitlyConfigured());
    EXPECT_TRUE(named2.wasExplicitlyConfigured());
}
893
TEST_F(ServerTest, State) {
    DnsTlsServer lhs(V4ADDR1);
    DnsTlsServer rhs(V4ADDR1);

    // Changing the validation state or the active flag must not affect equality or
    // ordering: the two servers stay equal after every mutation below.
    checkEqual(lhs, rhs);
    lhs.setValidationState(Validation::success);
    checkEqual(lhs, rhs);
    rhs.setValidationState(Validation::fail);
    checkEqual(lhs, rhs);
    lhs.setActive(true);
    checkEqual(lhs, rhs);
    rhs.setActive(false);
    checkEqual(lhs, rhs);

    // The setters themselves must still have taken effect.
    EXPECT_EQ(lhs.validationState(), Validation::success);
    EXPECT_EQ(rhs.validationState(), Validation::fail);
    EXPECT_TRUE(lhs.active());
    EXPECT_FALSE(rhs.active());
}
911
TEST(QueryMapTest, Basic) {
    DnsTlsQueryMap map;

    EXPECT_TRUE(map.empty());

    bytevec q0 = make_query(999, SIZE);
    bytevec q1 = make_query(888, SIZE);
    bytevec q2 = make_query(777, SIZE);

    auto f0 = map.recordQuery(makeSlice(q0));
    auto f1 = map.recordQuery(makeSlice(q1));
    auto f2 = map.recordQuery(makeSlice(q2));

    // recordQuery() can return null (FillHole exercises that case). Fail here instead of
    // dereferencing a null future below.
    ASSERT_TRUE(f0);
    ASSERT_TRUE(f1);
    ASSERT_TRUE(f2);

    // Check return values of recordQuery: new IDs are 0, 1, 2 in recording order.
    EXPECT_EQ(0, f0->query.newId);
    EXPECT_EQ(1, f1->query.newId);
    EXPECT_EQ(2, f2->query.newId);

    // Check side effects of recordQuery
    EXPECT_FALSE(map.empty());

    auto all = map.getAll();
    // Fatal on mismatch: all[0..2] below would otherwise be read out of bounds.
    ASSERT_EQ(3U, all.size());

    EXPECT_EQ(0, all[0].newId);
    EXPECT_EQ(1, all[1].newId);
    EXPECT_EQ(2, all[2].newId);

    EXPECT_EQ(q0, all[0].query);
    EXPECT_EQ(q1, all[1].query);
    EXPECT_EQ(q2, all[2].query);

    bytevec a0 = make_query(0, SIZE);
    bytevec a1 = make_query(1, SIZE);
    bytevec a2 = make_query(2, SIZE);

    // Return responses out of order
    map.onResponse(a2);
    map.onResponse(a0);
    map.onResponse(a1);

    EXPECT_TRUE(map.empty());

    auto r0 = f0->result.get();
    auto r1 = f1->result.get();
    auto r2 = f2->result.get();

    EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
    EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
    EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);

    const bytevec& d0 = r0.response;
    const bytevec& d1 = r1.response;
    const bytevec& d2 = r2.response;

    // Each response must hold at least the 2-byte ID before it is indexed below.
    ASSERT_GE(d0.size(), 2U);
    ASSERT_GE(d1.size(), 2U);
    ASSERT_GE(d2.size(), 2U);

    // The ID should match the query
    EXPECT_EQ(999, d0[0] << 8 | d0[1]);
    EXPECT_EQ(888, d1[0] << 8 | d1[1]);
    EXPECT_EQ(777, d2[0] << 8 | d2[1]);
    // The body should match the answer
    EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
    EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
    EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
}
976
TEST(QueryMapTest, FillHole) {
    DnsTlsQueryMap map;
    std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
    for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
        futures[i] = map.recordQuery(makeSlice(QUERY));
        ASSERT_TRUE(futures[i]);  // futures[i] should be nonnull.
        EXPECT_EQ(i, futures[i]->query.newId);
    }

    // The map should now be full.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());

    // Trying to add another query should fail because the map is full.
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));

    // Answer one query in the middle of the ID range to free up exactly that ID.
    constexpr uint16_t kHoleId = 40000;
    auto answer = make_query(kHoleId, SIZE);
    map.onResponse(answer);
    auto result = futures[kHoleId]->result.get();
    EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
    // The response must hold at least the 2-byte ID before it is indexed below.
    ASSERT_TRUE(result.response.size() >= 2);
    EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
    EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
              bytevec(result.response.begin() + 2, result.response.end()));

    // There should now be room in the map, and the freed ID is the one reused.
    EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
    auto f = map.recordQuery(makeSlice(QUERY));
    ASSERT_TRUE(f);
    EXPECT_EQ(kHoleId, f->query.newId);

    // The map should now be full again.
    EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
    EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
}
1011
1012 class DnsTlsSocketTest : public ::testing::Test {
1013 protected:
1014 class MockDnsTlsSocketObserver : public IDnsTlsSocketObserver {
1015 public:
1016 MOCK_METHOD(void, onClosed, (), (override));
1017 MOCK_METHOD(void, onResponse, (std::vector<uint8_t>), (override));
1018 };
1019
DnsTlsSocketTest()1020 DnsTlsSocketTest() { parseServer(kTlsAddr, std::stoi(kTlsPort), &server.ss); }
1021
makeDnsTlsSocket(IDnsTlsSocketObserver * observer)1022 std::unique_ptr<DnsTlsSocket> makeDnsTlsSocket(IDnsTlsSocketObserver* observer) {
1023 return std::make_unique<DnsTlsSocket>(this->server, MARK, observer, &this->cache);
1024 }
1025
enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket> & socket)1026 void enableAsyncHandshake(const std::unique_ptr<DnsTlsSocket>& socket) {
1027 ASSERT_TRUE(socket);
1028 DnsTlsSocket* delegate = socket.get();
1029 std::lock_guard guard(delegate->mLock);
1030 delegate->mAsyncHandshake = true;
1031 }
1032
1033 static constexpr char kTlsAddr[] = "127.0.0.3";
1034 static constexpr char kTlsPort[] = "8530"; // High-numbered port so root isn't required.
1035 static constexpr char kBackendAddr[] = "192.0.2.1";
1036 static constexpr char kBackendPort[] = "8531"; // High-numbered port so root isn't required.
1037
1038 test::DnsTlsFrontend tls{kTlsAddr, kTlsPort, kBackendAddr, kBackendPort};
1039
1040 DnsTlsServer server;
1041 DnsTlsSessionCache cache;
1042 };
1043
TEST_F(DnsTlsSocketTest, SlowDestructor) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    auto socket = makeDnsTlsSocket(&observer);
    DnsTlsSocket& sock = *socket;

    ASSERT_TRUE(sock.initialize());
    ASSERT_TRUE(sock.startHandshake());

    // Time the socket destructor, which must report onClosed() and should be fast.
    const auto start = std::chrono::steady_clock::now();
    EXPECT_CALL(observer, onClosed);
    socket.reset();
    const auto elapsed = std::chrono::steady_clock::now() - start;
    LOG(DEBUG) << "Shutdown took " << elapsed / std::chrono::nanoseconds{1} << "ns";
    // Shutdown normally completes within milliseconds. If the shutdown signal were lost,
    // the destructor would instead wait for the timeout, expected to be about 20 seconds.
    EXPECT_LT(elapsed, std::chrono::seconds{5});
}
1064
TEST_F(DnsTlsSocketTest, StartHandshake) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    auto socket = makeDnsTlsSocket(&observer);
    DnsTlsSocket& sock = *socket;

    // startHandshake() is rejected before initialize() has been called.
    EXPECT_FALSE(sock.startHandshake());

    // The normal sequence: initialize(), then startHandshake().
    EXPECT_TRUE(sock.initialize());
    EXPECT_TRUE(sock.startHandshake());

    // Neither call succeeds a second time.
    EXPECT_FALSE(sock.initialize());
    EXPECT_FALSE(sock.startHandshake());

    // onClosed() is expected when |socket| is destroyed at the end of the test, while
    // its loop thread is joined.
    EXPECT_CALL(observer, onClosed);
}
1085
TEST_F(DnsTlsSocketTest, ShutdownSignal) {
    ASSERT_TRUE(tls.startServer());

    MockDnsTlsSocketObserver observer;
    std::unique_ptr<DnsTlsSocket> socket;

    // Replaces |socket| with a fresh one and starts an asynchronous handshake on it.
    const auto beginHandshake = [&]() {
        socket = makeDnsTlsSocket(&observer);
        EXPECT_TRUE(socket->initialize());
        enableAsyncHandshake(socket);
        EXPECT_TRUE(socket->startHandshake());
    };
    // Destroys |socket|, expecting onClosed(). The whole teardown must take under one second.
    const auto resetAndTime = [&](const std::string& traceLog) {
        SCOPED_TRACE(traceLog);
        const auto start = std::chrono::steady_clock::now();
        EXPECT_CALL(observer, onClosed);
        socket.reset();
        const auto elapsed = std::chrono::steady_clock::now() - start;
        LOG(INFO) << "Shutdown took " << elapsed / std::chrono::nanoseconds{1} << "ns";
        EXPECT_LT(elapsed, std::chrono::seconds{1});
    };

    // Ask the frontend to hang on handshakes, so each reset below happens while the
    // socket is still handshaking.
    tls.setHangOnHandshakeForTesting(true);

    // Case 1: destroy a socket mid-handshake with nothing queued.
    beginHandshake();
    resetAndTime("Shutdown handshake w/o query requests");

    // Case 2: destroy a socket mid-handshake that already has queries pending.
    beginHandshake();

    // DnsTlsSocket never reports on the fate of pending queries; whether they count as
    // failed is decided by DnsTlsTransport. So no onResponse() may arrive here.
    EXPECT_CALL(observer, onResponse).Times(0);
    EXPECT_TRUE(socket->query(1, makeSlice(QUERY)));
    EXPECT_TRUE(socket->query(2, makeSlice(QUERY)));
    resetAndTime("Shutdown handshake w/ query requests");
}
1125
1126 } // end of namespace net
1127 } // end of namespace android
1128