1 package android.net.testutils
2 
3 import android.net.LinkAddress
4 import android.net.LinkProperties
5 import android.net.Network
6 import android.net.NetworkCapabilities
7 import com.android.testutils.ConcurrentInterpreter
8 import com.android.testutils.InterpretMatcher
9 import com.android.testutils.RecorderCallback.CallbackEntry
10 import com.android.testutils.RecorderCallback.CallbackEntry.Available
11 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
12 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
13 import com.android.testutils.TestableNetworkCallback
14 import com.android.testutils.intArg
15 import com.android.testutils.strArg
16 import com.android.testutils.timeArg
17 import org.junit.Before
18 import org.junit.Test
19 import org.junit.runner.RunWith
20 import org.junit.runners.JUnit4
21 import kotlin.reflect.KClass
22 import kotlin.test.assertEquals
23 import kotlin.test.assertFails
24 import kotlin.test.assertNull
25 import kotlin.test.assertTrue
26 import kotlin.test.fail
27 
/**
 * Short wait, in milliseconds, used by these tests when polling for callbacks; most uses are
 * expectations that are supposed to fail, so keeping it small keeps the suite fast.
 */
const val SHORT_TIMEOUT_MS: Long = 20L

/** Linger delay passed to `onLosing` when a spec issues an `onLosing(xx)` command. */
const val DEFAULT_LINGER_DELAY_MS: Int = 30_000

/** Short alias for [NetworkCapabilities.NET_CAPABILITY_NOT_METERED]. */
const val NOT_METERED: Int = NetworkCapabilities.NET_CAPABILITY_NOT_METERED

/** Short alias for [NetworkCapabilities.TRANSPORT_WIFI]. */
const val WIFI: Int = NetworkCapabilities.TRANSPORT_WIFI

/** Short alias for [NetworkCapabilities.TRANSPORT_CELLULAR]. */
const val CELLULAR: Int = NetworkCapabilities.TRANSPORT_CELLULAR

/** Interface name placed into the [LinkProperties] built by the link-properties tests. */
const val TEST_INTERFACE_NAME = "testInterfaceName"
34 
/**
 * Unit tests for [TestableNetworkCallback]: each test feeds callbacks into a fresh instance and
 * checks that its expectation helpers succeed when a matching callback was recorded and fail
 * otherwise.
 */
@RunWith(JUnit4::class)
class TestableNetworkCallbackTest {
    private lateinit var mCallback: TestableNetworkCallback

    /** Returns a [TestableNetworkCallback.HasNetwork] whose network is `Network(netId)`. */
    private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
        override val network: Network = Network(netId)
    }

    @Before
    fun setUp() {
        mCallback = TestableNetworkCallback()
    }

    @Test
    fun testLastAvailableNetwork() {
        // Make sure there is no last available network at first, then the last available network
        // is returned after onAvailable is called.
        val net2097 = Network(2097)
        assertNull(mCallback.lastAvailableNetwork)
        mCallback.onAvailable(net2097)
        // kotlin.test's assertEquals takes (expected, actual); keep that order so failure
        // messages report the values the right way round.
        assertEquals(net2097, mCallback.lastAvailableNetwork)

        // Make sure calling onCapsChanged/onLinkPropertiesChanged don't affect the last available
        // network.
        mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
        mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
        assertEquals(net2097, mCallback.lastAvailableNetwork)

        // Make sure onLost clears the last available network.
        mCallback.onLost(net2097)
        assertNull(mCallback.lastAvailableNetwork)

        // Do the same but with a different network after onLost: make sure the last available
        // network is the new one, not the original one.
        val net2098 = Network(2098)
        mCallback.onAvailable(net2098)
        mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
        mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
        assertEquals(net2098, mCallback.lastAvailableNetwork)

        // Make sure onAvailable changes the last available network even if onLost was not called.
        val net2099 = Network(2099)
        mCallback.onAvailable(net2099)
        assertEquals(net2099, mCallback.lastAvailableNetwork)

        // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily
        // the last available one. Check that behavior.
        mCallback.onLost(net2098)
        assertNull(mCallback.lastAvailableNetwork)

        // Make sure that losing the really last available one still results in null.
        mCallback.onLost(net2099)
        assertNull(mCallback.lastAvailableNetwork)

        // Make sure multiple onAvailable in a row then onLost still results in null.
        mCallback.onAvailable(net2097)
        mCallback.onAvailable(net2098)
        mCallback.onAvailable(net2099)
        mCallback.onLost(net2097)
        assertNull(mCallback.lastAvailableNetwork)
    }

    @Test
    fun testAssertNoCallback() {
        mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
        mCallback.onAvailable(Network(100))
        assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
    }

    @Test
    fun testCapabilitiesWithAndWithout() {
        val net = Network(101)
        val matcher = makeHasNetwork(101)
        val meteredNc = NetworkCapabilities()
        val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
        // Check that expecting caps (with or without) fails when no callback has been received.
        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }

        // Add NOT_METERED and check that With succeeds and Without fails.
        mCallback.onCapabilitiesChanged(net, unmeteredNc)
        mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
        mCallback.onCapabilitiesChanged(net, unmeteredNc)
        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }

        // Don't add NOT_METERED and check that With fails and Without succeeds.
        mCallback.onCapabilitiesChanged(net, meteredNc)
        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
        mCallback.onCapabilitiesChanged(net, meteredNc)
        mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
    }

    @Test
    fun testExpectCallbackThat() {
        val net = Network(193)
        val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
        // Check that expecting callbackThat anything fails when no callback has been received.
        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { true } }

        // Basic test for true and false
        mCallback.onAvailable(net)
        mCallback.expectCallbackThat { true }
        mCallback.onAvailable(net)
        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { false } }

        // Try a positive and a negative case
        mCallback.onBlockedStatusChanged(net, true)
        mCallback.expectCallbackThat { cb -> cb is BlockedStatus && cb.blocked }
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { cb ->
            cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
        } }
    }

    @Test
    fun testCapabilitiesThat() {
        val net = Network(101)
        val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
        // Check that expecting capabilitiesThat anything fails when no callback has been received.
        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }

        // Basic test for true and false
        mCallback.onCapabilitiesChanged(net, netCaps)
        mCallback.expectCapabilitiesThat(net) { true }
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }

        // Try a positive and a negative case
        mCallback.onCapabilitiesChanged(net, netCaps)
        mCallback.expectCapabilitiesThat(net) { caps ->
            caps.hasCapability(NOT_METERED) &&
                    caps.hasTransport(WIFI) &&
                    !caps.hasTransport(CELLULAR)
        }
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
            caps.hasTransport(CELLULAR)
        } }

        // Try a matching callback on the wrong network
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
    }

    @Test
    fun testLinkPropertiesThat() {
        val net = Network(112)
        val linkAddress = LinkAddress("fe80::ace:d00d/64")
        val mtu = 1984
        val linkProps = LinkProperties().apply {
            this.mtu = mtu
            interfaceName = TEST_INTERFACE_NAME
            addLinkAddress(linkAddress)
        }

        // Check that expecting linkPropsThat anything fails when no callback has been received.
        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }

        // Basic test for true and false
        mCallback.onLinkPropertiesChanged(net, linkProps)
        mCallback.expectLinkPropertiesThat(net) { true }
        mCallback.onLinkPropertiesChanged(net, linkProps)
        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }

        // Try a positive and negative case
        mCallback.onLinkPropertiesChanged(net, linkProps)
        mCallback.expectLinkPropertiesThat(net) { lp ->
            lp.interfaceName == TEST_INTERFACE_NAME &&
                    lp.linkAddresses.contains(linkAddress) &&
                    lp.mtu == mtu
        }
        mCallback.onLinkPropertiesChanged(net, linkProps)
        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
            lp.interfaceName != TEST_INTERFACE_NAME
        } }

        // Try a matching callback on the wrong network
        mCallback.onLinkPropertiesChanged(net, linkProps)
        assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
            lp.interfaceName == TEST_INTERFACE_NAME
        } }
    }

    @Test
    fun testExpectCallback() {
        val net = Network(103)
        // Test expectCallback fails when nothing was sent.
        assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }

        // Test onAvailable is seen and can be expected
        mCallback.onAvailable(net)
        mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS)

        // Test expectCallback doesn't match a callback of the right type on a different network
        mCallback.onAvailable(Network(106))
        assertFails { mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS) }

        // Test expectCallback doesn't match a callback of a different type on the right network
        mCallback.onAvailable(net)
        assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
    }

    @Test
    fun testPollForNextCallback() {
        assertFails { mCallback.pollForNextCallback(SHORT_TIMEOUT_MS) }
        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
            sleep; onAvailable(133)    | poll(2) = Available(133) time 1..4
                                       | poll(1) fails
            onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
            onBlockedStatus(199)       | poll(1) = BlockedStatus(199) time 0..3
        """)
    }
}
249 
/**
 * Interpreter for [TestableNetworkCallback] test specs (see `interpretTestSpec` in
 * `testPollForNextCallback`), driven by the rules in `interpretTable`.
 */
private object TNCInterpreter :
    ConcurrentInterpreter<TestableNetworkCallback>(interpretTable)
251 
/**
 * Simple names of every sealed subclass of [CallbackEntry], separated by "|" so the string can
 * be spliced directly into a regex alternation group.
 */
val EntryList = CallbackEntry::class.sealedSubclasses
    .joinToString(separator = "|") { entryClass -> entryClass.simpleName.toString() }
/**
 * Returns the sealed subclass of [CallbackEntry] whose simple name equals [name].
 *
 * @throws NoSuchElementException if no sealed subclass has that simple name.
 */
private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> =
    CallbackEntry::class.sealedSubclasses.first { entryClass -> entryClass.simpleName == name }
256 
/**
 * Rules used by `TNCInterpreter` to run spec lines against a [TestableNetworkCallback]:
 * - `<expr> = Entry(netId)`: evaluates `<expr>` and checks the result is a [CallbackEntry] of
 *   that subclass carrying a network with that netId;
 * - `onEntry(netId)`: invokes the matching callback method on the callback;
 * - `poll(t)`: calls `pollForNextCallback` with the time argument `t`.
 */
private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
    // Interpret "Available(xx)" as "call to onAvailable with netId xx", and likewise for
    // all callback types. The set of accepted names comes from EntryList, which enumerates
    // the sealed subclasses of CallbackEntry by their simpleName.
    Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
        val record = i.interpret(t.strArg(1), cb)
        val expectedType = callbackEntryFromString(t.strArg(2))
        assertTrue(expectedType.isInstance(record),
                "Expected a ${expectedType.simpleName} but got $record")
        // Strictly speaking testing for is CallbackEntry is useless as it's been tested above
        // but the compiler can't figure things out from the isInstance call. It does understand
        // from the assertTrue(is CallbackEntry) that this is true, which allows to access
        // the 'network' member below.
        assertTrue(record is CallbackEntry)
        // kotlin.test's assertEquals takes (expected, actual).
        assertEquals(t.intArg(3), record.network.netId, "Unexpected netId in $record")
    },
    // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
    // all callback types. NetworkCapabilities and LinkProperties just get an empty object
    // as their argument. Losing gets the default linger timer. Blocked gets false.
    Regex("""on($EntryList)\((\d+)\)""") to { _, cb, t ->
        val net = Network(t.intArg(2))
        when (val callbackName = t.strArg(1)) {
            "Available" -> cb.onAvailable(net)
            // PreCheck not used in tests. Add it here if it becomes useful.
            "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
            "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
            "Suspended" -> cb.onNetworkSuspended(net)
            "Resumed" -> cb.onNetworkResumed(net)
            "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
            "Lost" -> cb.onLost(net)
            "Unavailable" -> cb.onUnavailable()
            "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
            else -> fail("Unknown callback type: $callbackName")
        }
    },
    // Interpret "poll(t)" as a call to pollForNextCallback with time argument t.
    Regex("""poll\((\d+)\)""") to { _, cb, t ->
        cb.pollForNextCallback(t.timeArg(1))
    }
)
294