1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.net.ConnectivityManager.NetworkCallback
20 import android.net.LinkProperties
21 import android.net.Network
22 import android.net.NetworkCapabilities
23 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
24 import android.util.Log
25 import com.android.net.module.util.ArrayTrackRecord
26 import com.android.testutils.RecorderCallback.CallbackEntry.Available
27 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
28 import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatusInt
29 import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
30 import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
31 import com.android.testutils.RecorderCallback.CallbackEntry.Losing
32 import com.android.testutils.RecorderCallback.CallbackEntry.Lost
33 import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
34 import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
35 import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
36 import kotlin.reflect.KClass
37 import kotlin.test.assertEquals
38 import kotlin.test.assertNotNull
39 import kotlin.test.assertTrue
40 import kotlin.test.fail
41 
/**
 * Placeholder network stored in entries that have no network of their own (e.g.
 * [RecorderCallback.CallbackEntry.Unavailable]). Also substituted for a null network by the
 * KClass-based `TestableNetworkCallback.expectCallback`.
 */
object NULL_NETWORK : Network(
    -1
)

/**
 * Wildcard accepted by the reified `TestableNetworkCallback.expectCallback`: when passed as the
 * expected network, the callback's network is not checked.
 */
object ANY_NETWORK : Network(
    -2
)
44 
// Name of a capability constant as returned by NetworkCapabilities.capabilityNameOf.
// NOTE(review): nothing in this file reads this property; consider removing it.
private val Int.capabilityName
    get() = NetworkCapabilities.capabilityNameOf(this)
46 
/**
 * A [NetworkCallback] that appends the callbacks it receives to an [ArrayTrackRecord] as
 * [CallbackEntry] values, so tests can read them back in order through [history].
 *
 * Recorded: onAvailable, onCapabilitiesChanged, onLinkPropertiesChanged, the boolean
 * onBlockedStatusChanged, onNetworkSuspended, onNetworkResumed, onLosing, onLost and
 * onUnavailable. Not recorded: onPreCheck and the int variant of onBlockedStatusChanged (see
 * the comments next to the overrides below).
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    /** Creates a recorder backed by a new [ArrayTrackRecord]. */
    public constructor() : this(ArrayTrackRecord())

    /**
     * Creates a recorder that appends to the same [ArrayTrackRecord] as [src], or to a new one
     * when [src] is null. Each instance still creates its own read head ([history]).
     */
    protected constructor(src: RecorderCallback?): this(src?.backingRecord ?: ArrayTrackRecord())

    private val TAG = this::class.simpleName

    sealed class CallbackEntry {
        // To get equals(), hashcode(), componentN() etc for free, the child classes of
        // this class are data classes. But while data classes can inherit from other classes,
        // they may only have visible members in the constructors, so they couldn't declare
        // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
        // subclasses to implement a `network' property, which can be done in a data class
        // constructor by specifying override.
        abstract val network: Network

        data class Available(override val network: Network) : CallbackEntry()
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()
        data class Suspended(override val network: Network) : CallbackEntry()
        data class Resumed(override val network: Network) : CallbackEntry()
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
        data class Lost(override val network: Network) : CallbackEntry()
        // onUnavailable carries no network; the private primary constructor ensures the
        // network is always NULL_NETWORK.
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()
        // Int variant of the blocked status. RecorderCallback itself never records this entry
        // (see the comment next to onBlockedStatusChanged); subclasses may add it to history.
        data class BlockedStatusInt(
            override val network: Network,
            val blocked: Int
        ) : CallbackEntry()
        // Convenience constants for expecting a type
        companion object {
            @JvmField
            val AVAILABLE = Available::class
            @JvmField
            val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
            @JvmField
            val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
            @JvmField
            val SUSPENDED = Suspended::class
            @JvmField
            val RESUMED = Resumed::class
            @JvmField
            val LOSING = Losing::class
            @JvmField
            val LOST = Lost::class
            @JvmField
            val UNAVAILABLE = Unavailable::class
            @JvmField
            val BLOCKED_STATUS = BlockedStatus::class
            @JvmField
            val BLOCKED_STATUS_INT = BlockedStatusInt::class
        }
    }

    /** This instance's read head over the backing record; every callback is added through it. */
    val history = backingRecord.newReadHead()

    /** Shorthand for [history]'s `mark`. */
    val mark get() = history.mark

    override fun onAvailable(network: Network) {
        Log.d(TAG, "onAvailable $network")
        history.add(Available(network))
    }

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
        Log.d(TAG, "onCapabilitiesChanged $network $caps")
        history.add(CapabilitiesChanged(network, caps))
    }

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
        Log.d(TAG, "onLinkPropertiesChanged $network $lp")
        history.add(LinkPropertiesChanged(network, lp))
    }

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
        Log.d(TAG, "onBlockedStatusChanged $network $blocked")
        history.add(BlockedStatus(network, blocked))
    }

    // Cannot do:
    // fun onBlockedStatusChanged(network: Network, blocked: Int) {
    // because on S, that needs to be "override fun", and on R, that cannot be "override fun".
    override fun onNetworkSuspended(network: Network) {
        Log.d(TAG, "onNetworkSuspended $network")
        history.add(Suspended(network))
    }

    override fun onNetworkResumed(network: Network) {
        Log.d(TAG, "onNetworkResumed $network")
        history.add(Resumed(network))
    }

    override fun onLosing(network: Network, maxMsToLive: Int) {
        Log.d(TAG, "onLosing $network $maxMsToLive")
        history.add(Losing(network, maxMsToLive))
    }

    override fun onLost(network: Network) {
        Log.d(TAG, "onLost $network")
        history.add(Lost(network))
    }

    override fun onUnavailable() {
        Log.d(TAG, "onUnavailable")
        history.add(Unavailable())
    }
}
168 
/** Default wait, in milliseconds, used by [TestableNetworkCallback] when no timeout is given. */
private const val DEFAULT_TIMEOUT: Long = 200
170 
/**
 * A [RecorderCallback] with assertion helpers for network callback tests.
 *
 * Helpers that wait take a timeout in milliseconds and fall back to [defaultTimeoutMs].
 * `expect*` helpers take the next recorded callback and fail if it does not match;
 * `eventuallyExpect*` helpers skip callbacks that do not match.
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
) : RecorderCallback(src) {
    @JvmOverloads
    constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)

    /** Returns a new callback sharing this one's backing record and default timeout. */
    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)

    // The last available network, or null if any network was lost since the last call to
    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
    val lastAvailableNetwork: Network?
        get() = when (val last = history.lastOrNull { it is Available || it is Lost }) {
            is Available -> last.network
            else -> null
        }

    /** Returns the next recorded callback, failing if none is available within [timeoutMs]. */
    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
        return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
    }

    // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
    // TODO : remove the necessity to overload this, remove the open qualifier, and give a
    // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
    open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs)

    /** Fails if a next callback is available within [timeoutMs]. */
    fun assertNoCallback(timeoutMs: Long) {
        val cb = history.poll(timeoutMs)
        if (null != cb) fail("Expected no callback but got $cb")
    }

    // Expects a callback of the specified type on the specified network within the timeout.
    // If no callback arrives, or a different callback arrives, fail. Returns the callback.
    // Passing ANY_NETWORK (the default) skips the network check.
    inline fun <reified T : CallbackEntry> expectCallback(
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs
    ): T = pollForNextCallback(timeoutMs).let {
        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
        } else {
            it
        }
    }

    // Expects a callback of the specified type matching the predicate within the timeout.
    // Any callback that doesn't match the predicate will be skipped. Fails only if
    // no matching callback is received within the timeout.
    inline fun <reified T : CallbackEntry> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T = eventuallyExpectOrNull(timeoutMs, from, predicate).also {
        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
    } as T

    // Variant of eventuallyExpect for callers holding a KClass instead of a reified type.
    // Unlike the reified version it takes no `from` position. The predicate receives the
    // entry as a CallbackEntry and is only invoked on entries that are instances of [type].
    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        predicate: (CallbackEntry) -> Boolean = { true }
    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    // Like eventuallyExpect, but returns null instead of failing when no matching callback
    // is received within the timeout.
    // TODO (b/157405399) straighten and unify the method names
    inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?

    // Takes the next callback and fails unless [valid] returns true for it.
    fun expectCallbackThat(
        timeoutMs: Long = defaultTimeoutMs,
        valid: (CallbackEntry) -> Boolean
    ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }

    // Expects the next callback to be onCapabilitiesChanged on [net] with capabilities
    // accepted by [valid]; returns that callback.
    fun expectCapabilitiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ): CapabilitiesChanged {
        return expectCallback<CapabilitiesChanged>(net, tmt).also {
            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
        }
    }

    // Expects the next callback to be onLinkPropertiesChanged on [net] with link properties
    // accepted by [valid]; returns that callback.
    fun expectLinkPropertiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ): LinkPropertiesChanged {
        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
        }
    }

    // Expects onAvailable and the callbacks that follow it. These are:
    // - onSuspended, iff the network was suspended when the callbacks fire.
    // - onCapabilitiesChanged.
    // - onLinkPropertiesChanged.
    // - onBlockedStatusChanged.
    //
    // @param net the network to expect the callbacks on.
    // @param suspended whether to expect a SUSPENDED callback.
    // @param validated the expected value of the VALIDATED capability in the
    //        onCapabilitiesChanged callback.
    // @param blocked the expected value of the (boolean) onBlockedStatusChanged callback.
    // @param tmt how long to wait for each of the callbacks, in milliseconds.
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean = false,
        validated: Boolean = true,
        blocked: Boolean = false,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        expectBlockedStatusCallback(blocked, net, tmt)
    }

    // Same as above, but expects the int variant of the blocked status callback with value
    // [blockedStatus]. Every callback, including the blocked status one, waits up to [tmt].
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        blockedStatus: Int,
        tmt: Long
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        // Honor the caller's timeout here too, like the boolean overload does.
        expectBlockedStatusCallback(blockedStatus, net, tmt)
    }

    // Expects everything up to and including onLinkPropertiesChanged; the blocked status
    // callback is left to the caller.
    private fun expectAvailableCallbacksCommon(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        tmt: Long
    ) {
        expectCallback<Available>(net, tmt)
        if (suspended) {
            expectCallback<Suspended>(net, tmt)
        }
        expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
        expectCallback<LinkPropertiesChanged>(net, tmt)
    }

    // Backward compatibility for existing Java code. Use named arguments instead and remove all
    // these when there is no user left.
    fun expectAvailableAndSuspendedCallbacks(
        net: Network,
        validated: Boolean,
        tmt: Long = defaultTimeoutMs
    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)

    // Expects the next callback to be the boolean onBlockedStatusChanged on [net] with value
    // [blocked].
    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatus>(net, tmt).also {
            // kotlin.test's assertEquals takes (expected, actual).
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    // Expects the next callback to be a BlockedStatusInt entry on [net] with value [blocked].
    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatusInt>(net, tmt).also {
            // kotlin.test's assertEquals takes (expected, actual).
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
    // one we just sent.
    // TODO: this is likely a bug. Fix it and remove this method.
    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        val start = history.mark
        expectAvailableCallbacks(net, tmt = tmt)
        // Search again from where this method started to retrieve the first
        // onCapabilitiesChanged, which expectAvailableCallbacks has already consumed.
        // NOTE(review): confirm how poll() with an explicit `from` moves the read head.
        val firstCaps = history.poll(tmt, start) { it is CapabilitiesChanged }
        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
    }

    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
    // when a network connects and satisfies a callback, and then immediately validates.
    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        expectAvailableCallbacks(net, validated = false, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Same as above, but expects the int variant of the blocked status callback.
    fun expectAvailableThenValidatedCallbacks(
        net: Network,
        blockedStatus: Int,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacks(net, validated = false, suspended = false,
                blockedStatus = blockedStatus, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
    // TODO: clean this up, remove this method.
    interface HasNetwork {
        val network: Network
    }

    // Expects the next callback to be an instance of [type] on network [n]; a null [n] is
    // compared as NULL_NETWORK. Note that, unlike the reified expectCallback, ANY_NETWORK is
    // not treated as a wildcard here.
    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: Network?,
        timeoutMs: Long = defaultTimeoutMs
    ) = pollForNextCallback(timeoutMs).also {
        val network = n ?: NULL_NETWORK
        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
        // this writing this would be the only use of this library in the tests.
        assertTrue(type.java.isInstance(it) && it.network == network,
                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
    } as T

    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: HasNetwork?,
        timeoutMs: Long = defaultTimeoutMs
    ) = expectCallback(type, n?.network, timeoutMs)

    fun expectAvailableCallbacks(
        n: HasNetwork,
        suspended: Boolean,
        validated: Boolean,
        blocked: Boolean,
        timeoutMs: Long
    ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)

    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
    }

    fun expectAvailableCallbacksValidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network)
    }

    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, blocked = true)
    }

    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false)
    }

    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false, blocked = true)
    }

    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    @JvmOverloads
    fun expectLinkPropertiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ) = expectLinkPropertiesThat(n.network, tmt, valid)

    @JvmOverloads
    fun expectCapabilitiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ) = expectCapabilitiesThat(n.network, tmt, valid)

    // Expects onCapabilitiesChanged on [n] whose capabilities include [capability]; returns them.
    @JvmOverloads
    fun expectCapabilitiesWith(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
    }

    // Expects onCapabilitiesChanged on [n] whose capabilities lack [capability]; returns them.
    @JvmOverloads
    fun expectCapabilitiesWithout(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
    }

    fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
        expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
    }
}
461