1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <linux/netfilter/nfnetlink_log.h>
18 
19 #include <arpa/inet.h>
20 #include <sys/socket.h>
21 #include <netinet/in.h>
22 #include <netinet/ip.h>
23 #include <netinet/tcp.h>
24 
25 #include <gmock/gmock.h>
26 #include <gtest/gtest.h>
27 
28 #include "NetlinkManager.h"
29 #include "WakeupController.h"
30 
31 using ::testing::StrictMock;
32 using ::testing::Test;
33 using ::testing::DoAll;
34 using ::testing::SaveArg;
35 using ::testing::Return;
36 using ::testing::_;
37 
38 namespace android {
39 namespace net {
40 
41 const uint32_t kDefaultPacketCopyRange = WakeupController::kDefaultPacketCopyRange;
42 
43 using netdutils::status::ok;
44 
45 class MockNetdEventListener {
46   public:
47     MOCK_METHOD10(onWakeupEvent,
48                   void(const std::string& prefix, int uid, int ether, int ipNextHeader,
49                        const std::vector<uint8_t>& dstHw, const std::string& srcIp,
50                        const std::string& dstIp, int srcPort, int dstPort, uint64_t timestampNs));
51 };
52 
53 class MockIptablesRestore : public IptablesRestoreInterface {
54   public:
55     ~MockIptablesRestore() override = default;
56     MOCK_METHOD3(execute, int(const IptablesTarget target, const std::string& commands,
57                               std::string* output));
58 };
59 
60 class MockNFLogListener : public NFLogListenerInterface {
61   public:
62     ~MockNFLogListener() override = default;
63     MOCK_METHOD2(subscribe, netdutils::Status(uint16_t nfLogGroup, const DispatchFn& fn));
64     MOCK_METHOD3(subscribe,
65             netdutils::Status(uint16_t nfLogGroup, uint32_t copyRange, const DispatchFn& fn));
66     MOCK_METHOD1(unsubscribe, netdutils::Status(uint16_t nfLogGroup));
67 };
68 
69 class WakeupControllerTest : public Test {
70   protected:
WakeupControllerTest()71     WakeupControllerTest() {
72         EXPECT_CALL(mListener,
73             subscribe(NetlinkManager::NFLOG_WAKEUP_GROUP, kDefaultPacketCopyRange, _))
74             .WillOnce(DoAll(SaveArg<2>(&mMessageHandler), Return(ok)));
75         EXPECT_CALL(mListener,
76             unsubscribe(NetlinkManager::NFLOG_WAKEUP_GROUP)).WillOnce(Return(ok));
77         EXPECT_OK(mController.init(&mListener));
78     }
79 
80     StrictMock<MockNetdEventListener> mEventListener;
81     StrictMock<MockIptablesRestore> mIptables;
82     StrictMock<MockNFLogListener> mListener;
83     WakeupController mController{
__anonf6a63f210102() 84         [this](const WakeupController::ReportArgs& args) {
85             mEventListener.onWakeupEvent(args.prefix, args.uid, args.ethertype, args.ipNextHeader,
86                                          args.dstHw, args.srcIp, args.dstIp, args.srcPort,
87                                          args.dstPort, args.timestampNs);
88         },
89         &mIptables};
90     NFLogListenerInterface::DispatchFn mMessageHandler;
91 };
92 
// Feeds the handler only uid, gid, timestamp and prefix attributes (no packet
// header, hardware address or payload). The prefix, uid and timestamp must be
// reported; every packet-derived field must come back as its "absent" value.
TEST_F(WakeupControllerTest, msgHandlerWithPartialAttributes) {
    const char kPrefix[] = "test:prefix";
    const uid_t kUid = 8734;
    const gid_t kGid = 2222;
    const uint64_t kNsPerS = 1000000000ULL;
    const uint64_t kTsNs = 34 * kNsPerS + 9999;

    // Raw netlink message laid out field by field; the handler only sees the
    // attribute stream that starts at uidAttr.
    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr uidAttr;
        uid_t uid;
        nlattr gidAttr;
        gid_t gid;
        nlattr tsAttr;
        timespec ts;
        nlattr prefixAttr;
        char prefix[sizeof(kPrefix)];
    } msg = {};

    // Fills in an attribute header whose length covers the header plus payload.
    const auto setAttr = [](nlattr& attr, uint16_t type, size_t payloadLen) {
        attr.nla_type = type;
        attr.nla_len = sizeof(nlattr) + payloadLen;
    };

    setAttr(msg.uidAttr, NFULA_UID, sizeof(msg.uid));
    msg.uid = htonl(kUid);

    setAttr(msg.gidAttr, NFULA_GID, sizeof(msg.gid));
    msg.gid = htonl(kGid);

    setAttr(msg.tsAttr, NFULA_TIMESTAMP, sizeof(msg.ts));
    msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);

    setAttr(msg.prefixAttr, NFULA_PREFIX, sizeof(msg.prefix));
    memcpy(msg.prefix, kPrefix, sizeof(msg.prefix));

    auto attributes = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    EXPECT_CALL(mEventListener,
                onWakeupEvent(kPrefix, kUid, -1, -1, std::vector<uint8_t>(), "", "", -1, -1,
                              kTsNs));
    mMessageHandler(msg.nlmsg, msg.nfmsg, attributes);
}
135 
TEST_F(WakeupControllerTest,msgHandler)136 TEST_F(WakeupControllerTest, msgHandler) {
137     const char kPrefix[] = "test:prefix";
138     const uid_t kUid = 8734;
139     const gid_t kGid = 2222;
140     const std::vector<uint8_t> kMacAddr = {11, 22, 33, 44, 55, 66};
141     const char* kSrcIpAddr = "192.168.2.1";
142     const char* kDstIpAddr = "192.168.2.23";
143     const uint16_t kEthertype = 0x800;
144     const uint8_t kIpNextHeader = 6;
145     const uint16_t kSrcPort = 1238;
146     const uint16_t kDstPort = 4567;
147     const uint64_t kNsPerS = 1000000000ULL;
148     const uint64_t kTsNs = 9999 + (34 * kNsPerS);
149 
150     struct Msg {
151         nlmsghdr nlmsg;
152         nfgenmsg nfmsg;
153         nlattr uidAttr;
154         uid_t uid;
155         nlattr gidAttr;
156         gid_t gid;
157         nlattr tsAttr;
158         timespec ts;
159         nlattr prefixAttr;
160         char prefix[sizeof(kPrefix)];
161         nlattr packetHeaderAttr;
162         struct nfulnl_msg_packet_hdr packetHeader;
163         nlattr hardwareAddrAttr;
164         struct nfulnl_msg_packet_hw hardwareAddr;
165         nlattr packetPayloadAttr;
166         struct iphdr ipHeader;
167         struct tcphdr tcpHeader;
168     } msg = {};
169 
170     msg.prefixAttr.nla_type = NFULA_PREFIX;
171     msg.prefixAttr.nla_len = sizeof(msg.prefixAttr) + sizeof(msg.prefix);
172     memcpy(msg.prefix, kPrefix, sizeof(kPrefix));
173 
174     msg.uidAttr.nla_type = NFULA_UID;
175     msg.uidAttr.nla_len = sizeof(msg.uidAttr) + sizeof(msg.uid);
176     msg.uid = htonl(kUid);
177 
178     msg.gidAttr.nla_type = NFULA_GID;
179     msg.gidAttr.nla_len = sizeof(msg.gidAttr) + sizeof(msg.gid);
180     msg.gid = htonl(kGid);
181 
182     msg.tsAttr.nla_type = NFULA_TIMESTAMP;
183     msg.tsAttr.nla_len = sizeof(msg.tsAttr) + sizeof(msg.ts);
184     msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
185     msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);
186 
187     msg.packetHeaderAttr.nla_type = NFULA_PACKET_HDR;
188     msg.packetHeaderAttr.nla_len = sizeof(msg.packetHeaderAttr) + sizeof(msg.packetHeader);
189     msg.packetHeader.hw_protocol = htons(kEthertype);
190 
191     msg.hardwareAddrAttr.nla_type = NFULA_HWADDR;
192     msg.hardwareAddrAttr.nla_len = sizeof(msg.hardwareAddrAttr) + sizeof(msg.hardwareAddr);
193     msg.hardwareAddr.hw_addrlen = htons(kMacAddr.size());
194     std::copy(kMacAddr.begin(), kMacAddr.end(), msg.hardwareAddr.hw_addr);
195 
196     msg.packetPayloadAttr.nla_type = NFULA_PAYLOAD;
197     msg.packetPayloadAttr.nla_len =
198             sizeof(msg.packetPayloadAttr) + sizeof(msg.ipHeader) + sizeof(msg.tcpHeader);
199     msg.ipHeader.protocol = IPPROTO_TCP;
200     msg.ipHeader.ihl = sizeof(msg.ipHeader) / 4; // ipv4 IHL counts 32 bit words.
201     inet_pton(AF_INET, kSrcIpAddr, &msg.ipHeader.saddr);
202     inet_pton(AF_INET, kDstIpAddr, &msg.ipHeader.daddr);
203     msg.tcpHeader.th_sport = htons(kSrcPort);
204     msg.tcpHeader.th_dport = htons(kDstPort);
205 
206     auto payload = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
207     EXPECT_CALL(mEventListener, onWakeupEvent(kPrefix, kUid, kEthertype, kIpNextHeader, kMacAddr,
208                                               kSrcIpAddr, kDstIpAddr, kSrcPort, kDstPort, kTsNs));
209     mMessageHandler(msg.nlmsg, msg.nfmsg, payload);
210 }
211 
// Feeds the handler a stream that mixes unknown, zero-length, undersized and
// mistyped attributes. Only the gid attribute and the last attribute are
// well formed.
// - The last attribute holds the prefix bytes but is typed NFULA_UID, so the
//   reported uid is the first four bytes "test" read in network order
//   (0x74657374 == 1952805748).
// - No prefix is reported.
// - The undersized timestamp attribute yields a timestamp of 0.
TEST_F(WakeupControllerTest, badAttr) {
    const char kPrefix[] = "test:prefix";
    const uid_t kUid = 8734;
    const gid_t kGid = 2222;
    const uint64_t kNsPerS = 1000000000ULL;
    const uint64_t kTsNs = 34 * kNsPerS + 9999;

    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr uidAttr;
        uid_t uid;
        nlattr invalid0;
        nlattr invalid1;
        nlattr gidAttr;
        gid_t gid;
        nlattr tsAttr;
        timespec ts;
        nlattr prefixAttr;
        char prefix[sizeof(kPrefix)];
    } msg = {};

    // Fills in an attribute header whose length covers the header plus payload.
    const auto setAttr = [](nlattr& attr, uint16_t type, size_t payloadLen) {
        attr.nla_type = type;
        attr.nla_len = sizeof(nlattr) + payloadLen;
    };

    // Correctly sized, but of a type the handler does not know.
    setAttr(msg.uidAttr, 999, sizeof(msg.uid));
    msg.uid = htonl(kUid);

    // Headers whose length cannot even cover themselves.
    msg.invalid0.nla_type = 0;
    msg.invalid0.nla_len = 0;
    msg.invalid1.nla_type = 0;
    msg.invalid1.nla_len = 1;

    setAttr(msg.gidAttr, NFULA_GID, sizeof(msg.gid));
    msg.gid = htonl(kGid);

    // Declared length (2) is shorter than the attribute header itself.
    msg.tsAttr.nla_type = NFULA_TIMESTAMP;
    msg.tsAttr.nla_len = sizeof(nlattr) - 2;
    msg.ts.tv_sec = htonl(kTsNs / kNsPerS);
    msg.ts.tv_nsec = htonl(kTsNs % kNsPerS);

    // Prefix bytes deliberately mislabeled as a uid attribute.
    setAttr(msg.prefixAttr, NFULA_UID, sizeof(msg.prefix));
    memcpy(msg.prefix, kPrefix, sizeof(msg.prefix));

    auto attributes = drop(netdutils::makeSlice(msg), offsetof(Msg, uidAttr));
    EXPECT_CALL(mEventListener,
                onWakeupEvent("", 1952805748, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, 0));
    mMessageHandler(msg.nlmsg, msg.nfmsg, attributes);
}
261 
// A prefix attribute whose payload has no NUL terminator (every byte is 0x01).
// The handler must report the payload minus its final byte, never reading past
// the attribute.
TEST_F(WakeupControllerTest, unterminatedString) {
    const size_t kPrefixLen = 20;

    struct Msg {
        nlmsghdr nlmsg;
        nfgenmsg nfmsg;
        nlattr prefixAttr;
        char prefix[kPrefixLen];
    } msg = {};

    msg.prefixAttr.nla_type = NFULA_PREFIX;
    msg.prefixAttr.nla_len = sizeof(nlattr) + sizeof(msg.prefix);
    memset(msg.prefix, 1, sizeof(msg.prefix));

    const std::string expected(kPrefixLen - 1, '\x01');
    auto attributes = drop(netdutils::makeSlice(msg), offsetof(Msg, prefixAttr));
    EXPECT_CALL(mEventListener,
                onWakeupEvent(expected, -1, -1, -1, std::vector<uint8_t>(), "", "", -1, -1, 0));
    mMessageHandler(msg.nlmsg, msg.nfmsg, attributes);
}
283 
// addInterface(ifName, prefix, mark, mask) must issue a single iptables-restore
// script to both IPv4 and IPv6 (V4V6). The script appends one rule to
// wakeupctrl_mangle_INPUT that:
// - matches packets arriving on ifName (-i) whose mark & mask equals mark;
// - logs them via NFLOG with the given prefix, rate-limited.
// The interface name and the log prefix are passed in that order. The previous
// version of this test swapped them, so it pinned a rule for an interface named
// "test:prefix" that logged with prefix "wlan8".
TEST_F(WakeupControllerTest, addInterface) {
    const char kPrefix[] = "test:prefix";
    const char kIfName[] = "wlan8";
    const uint32_t kMark = 0x12345678;
    const uint32_t kMask = 0x0F0F0F0F;
    const char kExpected[] =
        "*mangle\n-A wakeupctrl_mangle_INPUT -i wlan8"
        " -j NFLOG --nflog-prefix test:prefix --nflog-group 3 --nflog-threshold 8"
        " -m mark --mark 0x12345678/0x0f0f0f0f -m limit --limit 10/s\nCOMMIT\n";
    EXPECT_CALL(mIptables, execute(V4V6, kExpected, _)).WillOnce(Return(0));
    EXPECT_OK(mController.addInterface(kIfName, kPrefix, kMark, kMask));
}
296 
// delInterface(ifName, prefix, mark, mask) must issue the -D (delete)
// counterpart of the rule addInterface() installs, to both IPv4 and IPv6
// (V4V6). The interface name and log prefix are passed in that order; the
// previous version of this test swapped them, pinning "-i test:prefix ...
// --nflog-prefix wlan8".
TEST_F(WakeupControllerTest, delInterface) {
    const char kPrefix[] = "test:prefix";
    const char kIfName[] = "wlan8";
    const uint32_t kMark = 0x12345678;
    const uint32_t kMask = 0xF0F0F0F0;
    const char kExpected[] =
        "*mangle\n-D wakeupctrl_mangle_INPUT -i wlan8"
        " -j NFLOG --nflog-prefix test:prefix --nflog-group 3 --nflog-threshold 8"
        " -m mark --mark 0x12345678/0xf0f0f0f0 -m limit --limit 10/s\nCOMMIT\n";
    EXPECT_CALL(mIptables, execute(V4V6, kExpected, _)).WillOnce(Return(0));
    EXPECT_OK(mController.delInterface(kIfName, kPrefix, kMark, kMask));
}
309 
310 }  // namespace net
311 }  // namespace android
312