1 /*
2  * Copyright 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "model/devices/link_layer_socket_device.h"

#include <gtest/gtest.h>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <functional>
#include <iterator>
#include <memory>
#include <mutex>
#include <vector>

#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include "model/setup/async_manager.h"
#include "packets/link_layer/command_view.h"
32 
// Payload bytes used for every test packet: 0x00, 0x01, ..., 0x3f (64 bytes).
std::vector<uint8_t> count = [] {
  std::vector<uint8_t> bytes;
  bytes.reserve(0x40);
  for (unsigned value = 0; value < 0x40; ++value) {
    bytes.push_back(static_cast<uint8_t>(value));
  }
  return bytes;
}();
39 
40 using test_vendor_lib::packets::CommandBuilder;
41 using test_vendor_lib::packets::CommandView;
42 using test_vendor_lib::packets::LinkLayerPacketBuilder;
43 using test_vendor_lib::packets::LinkLayerPacketView;
44 using test_vendor_lib::packets::PacketView;
45 using test_vendor_lib::packets::View;
46 
// Upper bound on client/server pairs one test may create; it sizes the
// fixture's per-pair bookkeeping arrays and device vector reservations.
static constexpr size_t kMaxConnections = 300;
48 
49 namespace test_vendor_lib {
50 
51 class LinkLayerSocketDeviceTest : public ::testing::Test {
52  public:
53   static const uint16_t kPort = 6123;
54 
55  protected:
56   class MockPhyLayer : public PhyLayer {
57    public:
MockPhyLayer(const std::function<void (std::shared_ptr<LinkLayerPacketBuilder>)> & on_receive)58     MockPhyLayer(const std::function<void(std::shared_ptr<LinkLayerPacketBuilder>)>& on_receive)
59         : PhyLayer(Phy::Type::LOW_ENERGY, 0, [](LinkLayerPacketView) {}), on_receive_(on_receive) {}
Send(const std::shared_ptr<LinkLayerPacketBuilder> packet)60     void Send(const std::shared_ptr<LinkLayerPacketBuilder> packet) override {
61       on_receive_(packet);
62     }
Receive(LinkLayerPacketView)63     void Receive(LinkLayerPacketView) override {}
TimerTick()64     void TimerTick() override {}
65 
66    private:
67     std::function<void(std::shared_ptr<LinkLayerPacketBuilder>)> on_receive_;
68   };
69 
StartServer()70   int StartServer() {
71     struct sockaddr_in serv_addr;
72     int fd = socket(AF_INET, SOCK_STREAM, 0);
73     EXPECT_FALSE(fd < 0);
74 
75     memset(&serv_addr, 0, sizeof(serv_addr));
76     serv_addr.sin_family = AF_INET;
77     serv_addr.sin_addr.s_addr = INADDR_ANY;
78     serv_addr.sin_port = htons(kPort);
79     int reuse_flag = 1;
80     EXPECT_FALSE(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse_flag, sizeof(reuse_flag)) < 0);
81     EXPECT_FALSE(bind(fd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0);
82 
83     listen(fd, 1);
84     return fd;
85   }
86 
AcceptConnection(int fd)87   int AcceptConnection(int fd) {
88     return accept(fd, NULL, NULL);
89   }
90 
ValidatePacket(size_t index,bool at_server,std::shared_ptr<LinkLayerPacketBuilder> received)91   void ValidatePacket(size_t index, bool at_server, std::shared_ptr<LinkLayerPacketBuilder> received) {
92     /* Convert the Builder into a View */
93     std::shared_ptr<std::vector<uint8_t>> packet_ptr = std::make_shared<std::vector<uint8_t>>();
94     std::back_insert_iterator<std::vector<uint8_t>> it(*packet_ptr);
95     received->Serialize(it);
96     LinkLayerPacketView received_view = LinkLayerPacketView::Create(packet_ptr);
97 
98     /* Validate received packet */
99     ASSERT_EQ(received_view.GetSourceAddress(), source_);
100     ASSERT_EQ(received_view.GetDestinationAddress(), dest_);
101     ASSERT_EQ(Link::PacketType::COMMAND, received_view.GetType());
102     CommandView command_view = CommandView::GetCommand(received_view);
103     if (at_server) {
104       ASSERT_EQ(client_opcodes_[index], command_view.GetOpcode());
105     } else {
106       ASSERT_EQ(server_opcodes_[index], command_view.GetOpcode());
107     }
108     auto args_itr = command_view.GetData();
109     ASSERT_EQ(args_itr.NumBytesRemaining(), count.size());
110     for (size_t i = 0; i < count.size(); i++) {
111       ASSERT_EQ(*args_itr++, count[i]);
112     }
113     if (at_server) {
114       validated_client_packets_[index]++;
115     } else {
116       validated_server_packets_[index]++;
117     }
118   }
119 
SetUp()120   void SetUp() override {
121     servers_.reserve(kMaxConnections);
122     clients_.reserve(kMaxConnections);
123     socket_fd_ = StartServer();
124 
125     async_manager_.WatchFdForNonBlockingReads(socket_fd_, [this](int fd) {
126       int connection_fd = AcceptConnection(fd);
127       ASSERT_GE(connection_fd, 0);
128       size_t index = servers_.size();
129       servers_.emplace_back(connection_fd, Phy::Type::LOW_ENERGY);
130       ASSERT_EQ(servers_.size() - 1, index) << "Race condition";
131       std::shared_ptr<MockPhyLayer> mock_phy = std::make_shared<MockPhyLayer>(
132           [this, index](std::shared_ptr<LinkLayerPacketBuilder> received) { ValidatePacket(index, true, received); });
133       servers_[index].RegisterPhyLayer(mock_phy);
134     });
135   }
136 
TearDown()137   void TearDown() override {
138     async_manager_.StopWatchingFileDescriptor(socket_fd_);
139     close(socket_fd_);
140   }
141 
ConnectClient()142   int ConnectClient() {
143     int socket_cli_fd = socket(AF_INET, SOCK_STREAM, 0);
144     EXPECT_FALSE(socket_cli_fd < 0);
145 
146     struct hostent* server;
147     server = gethostbyname("localhost");
148     EXPECT_FALSE(server == NULL);
149 
150     struct sockaddr_in serv_addr;
151     memset((void*)&serv_addr, 0, sizeof(serv_addr));
152     serv_addr.sin_family = AF_INET;
153     serv_addr.sin_addr.s_addr = INADDR_ANY;
154     serv_addr.sin_port = htons(kPort);
155 
156     int result = connect(socket_cli_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
157     EXPECT_FALSE(result < 0);
158 
159     EXPECT_GE(socket_cli_fd, 0);
160 
161     return socket_cli_fd;
162   }
163 
ValidateConnection(size_t pair_id)164   void ValidateConnection(size_t pair_id) {
165     ASSERT_GT(clients_.size(), pair_id);
166     ASSERT_GT(servers_.size(), pair_id);
167   }
168 
CreateConnection()169   size_t CreateConnection() {
170     int fd = ConnectClient();
171     size_t index = clients_.size();
172     clients_.emplace_back(fd, Phy::Type::LOW_ENERGY);
173     std::shared_ptr<MockPhyLayer> mock_phy = std::make_shared<MockPhyLayer>(
174         [this, index](std::shared_ptr<LinkLayerPacketBuilder> received) { ValidatePacket(index, false, received); });
175     clients_[index].RegisterPhyLayer(mock_phy);
176     for (size_t timeout = 10; timeout > 0 && clients_.size() > servers_.size(); timeout--) {
177       sleep(0);  // Wait for server to be created
178     }
179     ValidateConnection(index);
180     return index;
181   }
182 
NextPacket()183   LinkLayerPacketView NextPacket() {
184     std::shared_ptr<std::vector<uint8_t>> count_shared = std::make_shared<std::vector<uint8_t>>(count);
185     LinkLayerPacketView view = LinkLayerPacketView::Create(count_shared);
186     return view;
187   }
188 
SendFromClient(size_t pair_id)189   void SendFromClient(size_t pair_id) {
190     ASSERT_GT(clients_.size(), pair_id);
191     LinkLayerPacketView view = NextPacket();
192     client_opcodes_[pair_id] = CommandView::GetCommand(view).GetOpcode();
193     clients_[pair_id].IncomingPacket(view);
194   }
195 
SendFromServer(size_t pair_id)196   void SendFromServer(size_t pair_id) {
197     ASSERT_GT(servers_.size(), pair_id);
198     LinkLayerPacketView view = NextPacket();
199     server_opcodes_[pair_id] = CommandView::GetCommand(view).GetOpcode();
200     servers_[pair_id].IncomingPacket(view);
201   }
202 
ReadFromClient(size_t pair_id)203   void ReadFromClient(size_t pair_id) {
204     ASSERT_GT(clients_.size(), pair_id);
205     size_t validated_packets = validated_server_packets_[pair_id];
206     for (size_t tries = 0; tries < 10 && validated_server_packets_[pair_id] == validated_packets; tries++) {
207       clients_[pair_id].TimerTick();
208     }
209     ASSERT_EQ(validated_server_packets_[pair_id], validated_packets + 1);
210   }
211 
ReadFromServer(size_t pair_id)212   void ReadFromServer(size_t pair_id) {
213     ASSERT_GT(servers_.size(), pair_id);
214     size_t validated_packets = validated_client_packets_[pair_id];
215     for (size_t tries = 0; tries < 10 && validated_client_packets_[pair_id] == validated_packets; tries++) {
216       servers_[pair_id].TimerTick();
217     }
218     ASSERT_EQ(validated_client_packets_[pair_id], validated_packets + 1);
219   }
220 
221  private:
222   uint16_t packet_id_{1};
223   AsyncManager async_manager_;
224   int socket_fd_;
225   std::vector<LinkLayerSocketDevice> servers_;
226   std::vector<LinkLayerSocketDevice> clients_;
227   uint16_t server_opcodes_[kMaxConnections]{0};
228   uint16_t client_opcodes_[kMaxConnections]{0};
229   size_t validated_server_packets_[kMaxConnections]{0};
230   size_t validated_client_packets_[kMaxConnections]{0};
231   Address source_{{1, 2, 3, 4, 5, 6}};
232   Address dest_{{6, 5, 4, 3, 2, 1}};
233 };
234 
TEST_F(LinkLayerSocketDeviceTest,TestClientFirst)235 TEST_F(LinkLayerSocketDeviceTest, TestClientFirst) {
236   size_t pair_id = CreateConnection();
237   ASSERT_EQ(pair_id, 0u);
238   ValidateConnection(pair_id);
239 
240   SendFromClient(pair_id);
241   ReadFromServer(pair_id);
242 }
243 
TEST_F(LinkLayerSocketDeviceTest,TestServerFirst)244 TEST_F(LinkLayerSocketDeviceTest, TestServerFirst) {
245   size_t pair_id = CreateConnection();
246   ASSERT_EQ(pair_id, 0u);
247 
248   SendFromServer(pair_id);
249   ReadFromClient(pair_id);
250 }
251 
TEST_F(LinkLayerSocketDeviceTest,TestMultiplePackets)252 TEST_F(LinkLayerSocketDeviceTest, TestMultiplePackets) {
253   static const int num_packets = 30;
254   size_t pair_id = CreateConnection();
255   ASSERT_EQ(pair_id, 0u);
256   for (int i = 0; i < num_packets; i++) {
257     SendFromClient(pair_id);
258     SendFromServer(pair_id);
259     ReadFromServer(pair_id);
260     ReadFromClient(pair_id);
261   }
262 }
263 
TEST_F(LinkLayerSocketDeviceTest,TestMultipleConnectionsFromServer)264 TEST_F(LinkLayerSocketDeviceTest, TestMultipleConnectionsFromServer) {
265   static size_t last_pair_id = -1;
266   size_t pair_id;
267   for (size_t i = 0; i < kMaxConnections; i++) {
268     pair_id = CreateConnection();
269     ASSERT_EQ(pair_id, last_pair_id + 1);
270     last_pair_id = pair_id;
271     SendFromServer(pair_id);
272     ReadFromClient(pair_id);
273   }
274 }
275 
TEST_F(LinkLayerSocketDeviceTest,TestMultipleConnectionsFromClient)276 TEST_F(LinkLayerSocketDeviceTest, TestMultipleConnectionsFromClient) {
277   for (size_t i = 0; i < kMaxConnections; i++) {
278     size_t pair_id = CreateConnection();
279     ASSERT_EQ(pair_id, i);
280     SendFromClient(pair_id);
281     ReadFromServer(pair_id);
282   }
283 }
284 
TEST_F(LinkLayerSocketDeviceTest,TestMultipleConnections)285 TEST_F(LinkLayerSocketDeviceTest, TestMultipleConnections) {
286   for (size_t i = 0; i < kMaxConnections; i++) {
287     size_t pair_id = CreateConnection();
288     ASSERT_EQ(pair_id, i);
289     SendFromClient(pair_id);
290     SendFromServer(pair_id);
291     ReadFromClient(pair_id);
292     ReadFromServer(pair_id);
293   }
294 }
295 
296 }  // namespace test_vendor_lib
297