1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package android.net
18 
19 import android.net.InvalidPacketException.ERROR_INVALID_IP_ADDRESS
20 import android.net.InvalidPacketException.ERROR_INVALID_PORT
21 import android.net.NattSocketKeepalive.NATT_PORT
22 import android.os.Build
23 import androidx.test.filters.SmallTest
24 import androidx.test.runner.AndroidJUnit4
25 import com.android.testutils.assertEqualBothWays
26 import com.android.testutils.DevSdkIgnoreRule
27 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
28 import com.android.testutils.assertParcelingIsLossless
29 import com.android.testutils.parcelingRoundTrip
30 import java.net.InetAddress
31 import org.junit.Assert.assertEquals
32 import org.junit.Assert.assertNotEquals
33 import org.junit.Assert.fail
34 import org.junit.Rule
35 import org.junit.Test
36 import org.junit.runner.RunWith
37 
/**
 * Tests for [NattKeepalivePacketData]: factory validation, parceling, equality and hashing.
 *
 * Every test is annotated `@IgnoreUpTo(Build.VERSION_CODES.Q)`, which [ignoreRule] honors.
 */
@RunWith(AndroidJUnit4::class)
@SmallTest
class NattKeepalivePacketDataTest {
    /** Applies the `@IgnoreUpTo` annotations on the tests below. */
    @Rule @JvmField
    val ignoreRule: DevSdkIgnoreRule = DevSdkIgnoreRule()

    private val TEST_PORT = 4243
    private val TEST_PORT2 = 4244
    private val TEST_SRC_ADDRV4 = "198.168.0.2".address()
    private val TEST_DST_ADDRV4 = "198.168.0.1".address()
    private val TEST_ADDRV6 = "2001:db8::1".address()

    /** Converts this numeric IP literal into an [InetAddress] via [InetAddresses.parseNumericAddress]. */
    private fun String.address() = InetAddresses.parseNumericAddress(this)

    /**
     * Builds a packet through the [NattKeepalivePacketData.nattKeepalivePacket] factory.
     *
     * Every argument defaults to an IPv4 configuration that the factory accepts (see [testParcel]),
     * so each test only overrides the field it is exercising. The destination port defaults to
     * [NATT_PORT] because the factory rejects any other destination port.
     */
    private fun nattKeepalivePacket(
        srcAddress: InetAddress? = TEST_SRC_ADDRV4,
        srcPort: Int = TEST_PORT,
        dstAddress: InetAddress? = TEST_DST_ADDRV4,
        dstPort: Int = NATT_PORT
    ) = NattKeepalivePacketData.nattKeepalivePacket(srcAddress, srcPort, dstAddress, dstPort)

    /**
     * Asserts that [block] throws an [InvalidPacketException] whose `error` equals [expectedError].
     *
     * @param expectedError the `InvalidPacketException.ERROR_*` code the exception must carry.
     * @param failMessage reported when [block] completes without throwing.
     * @param block the code expected to throw.
     */
    private fun assertInvalidPacket(expectedError: Int, failMessage: String, block: () -> Unit) {
        try {
            block()
            // fail() throws AssertionError, which the catch clause below does not intercept.
            fail(failMessage)
        } catch (e: InvalidPacketException) {
            // JUnit expects (expected, actual) so that failure reports read correctly.
            assertEquals(expectedError, e.error)
        }
    }

    /** The factory rejects a non-NAT-T destination port and IPv6 addresses on either end. */
    @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
    fun testConstructor() {
        assertInvalidPacket(ERROR_INVALID_PORT,
                "Dst port is not NATT port should cause exception") {
            nattKeepalivePacket(dstPort = TEST_PORT)
        }

        assertInvalidPacket(ERROR_INVALID_IP_ADDRESS,
                "A v6 srcAddress should cause exception") {
            nattKeepalivePacket(srcAddress = TEST_ADDRV6)
        }

        assertInvalidPacket(ERROR_INVALID_IP_ADDRESS,
                "A v6 dstAddress should cause exception") {
            nattKeepalivePacket(dstAddress = TEST_ADDRV6)
        }

        // An object built directly through the public constructor (which bypasses the factory
        // checks above) must not survive a parcel round trip.
        // NOTE(review): this presumably fails because un-parcelling re-validates the packet and
        // the destination port here is not NATT_PORT — confirm against NattKeepalivePacketData.
        try {
            parcelingRoundTrip(
                    NattKeepalivePacketData(TEST_SRC_ADDRV4, TEST_PORT, TEST_DST_ADDRV4, TEST_PORT,
                            byteArrayOf(12, 31, 22, 44)))
            fail("Invalid data should cause exception")
        } catch (e: IllegalArgumentException) {
            // Expected: the round trip must reject the invalid packet.
        }
    }

    /** A valid packet survives a parcel round trip unchanged. */
    @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
    fun testParcel() {
        assertParcelingIsLossless(nattKeepalivePacket())
    }

    /** Packets are equal iff their addresses and ports are equal. */
    @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
    fun testEquals() {
        assertEqualBothWays(nattKeepalivePacket(), nattKeepalivePacket())
        assertNotEquals(nattKeepalivePacket(dstAddress = TEST_SRC_ADDRV4), nattKeepalivePacket())
        assertNotEquals(nattKeepalivePacket(srcAddress = TEST_DST_ADDRV4), nattKeepalivePacket())
        // Only the src port is varied, because the dst port has to be NATT_PORT.
        assertNotEquals(nattKeepalivePacket(srcPort = TEST_PORT2), nattKeepalivePacket())
    }

    /** Equal packets produce equal hash codes. */
    @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
    fun testHashCode() {
        assertEquals(nattKeepalivePacket().hashCode(), nattKeepalivePacket().hashCode())
    }
}