/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.testutils

import android.net.ConnectivityManager.NetworkCallback
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
import android.util.Log
import com.android.net.module.util.ArrayTrackRecord
import com.android.testutils.RecorderCallback.CallbackEntry.Available
import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatusInt
import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
import com.android.testutils.RecorderCallback.CallbackEntry.Losing
import com.android.testutils.RecorderCallback.CallbackEntry.Lost
import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import kotlin.test.fail

/**
 * Placeholder network for entries that have no network of their own (see [Unavailable]), and the
 * network expected by the KClass-based `expectCallback` when it is passed a null network.
 */
object NULL_NETWORK : Network(-1)

/**
 * Wildcard for the reified [TestableNetworkCallback.expectCallback]: when passed as the expected
 * network, the network of the received callback is not checked.
 */
object ANY_NETWORK : Network(-2)

// Human-readable name of a NetworkCapabilities.NET_CAPABILITY_* constant.
private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)

/**
 * A [NetworkCallback] that records every callback it receives as a [CallbackEntry].
 *
 * Entries are appended to a backing [ArrayTrackRecord] and read through [history], a read head
 * created by this instance. An instance built through the protected constructor from another
 * [RecorderCallback] shares that callback's backing record (each instance still creates its own
 * read head), so both observe the same recorded entries.
 */
open class RecorderCallback private constructor(
    private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
    public constructor() : this(ArrayTrackRecord())
    protected constructor(src: RecorderCallback?) : this(src?.backingRecord ?: ArrayTrackRecord())

    private val TAG = this::class.simpleName

    /** One recorded callback invocation; each subclass mirrors one [NetworkCallback] method. */
    sealed class CallbackEntry {
        // To get equals(), hashcode(), componentN() etc for free, the child classes of
        // this class are data classes. But while data classes can inherit from other classes,
        // they may only have visible members in the constructors, so they couldn't declare
        // a constructor with a non-val arg to pass to CallbackEntry. Instead, force all
        // subclasses to implement a `network' property, which can be done in a data class
        // constructor by specifying override.
        abstract val network: Network

        data class Available(override val network: Network) : CallbackEntry()
        data class CapabilitiesChanged(
            override val network: Network,
            val caps: NetworkCapabilities
        ) : CallbackEntry()
        data class LinkPropertiesChanged(
            override val network: Network,
            val lp: LinkProperties
        ) : CallbackEntry()
        data class Suspended(override val network: Network) : CallbackEntry()
        data class Resumed(override val network: Network) : CallbackEntry()
        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
        data class Lost(override val network: Network) : CallbackEntry()

        /** onUnavailable has no network argument, so this entry always reports [NULL_NETWORK]. */
        data class Unavailable private constructor(
            override val network: Network
        ) : CallbackEntry() {
            constructor() : this(NULL_NETWORK)
        }
        data class BlockedStatus(
            override val network: Network,
            val blocked: Boolean
        ) : CallbackEntry()
        data class BlockedStatusInt(
            override val network: Network,
            val blocked: Int
        ) : CallbackEntry()

        // Convenience constants for expecting a type
        companion object {
            @JvmField
            val AVAILABLE = Available::class
            @JvmField
            val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
            @JvmField
            val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
            @JvmField
            val SUSPENDED = Suspended::class
            @JvmField
            val RESUMED = Resumed::class
            @JvmField
            val LOSING = Losing::class
            @JvmField
            val LOST = Lost::class
            @JvmField
            val UNAVAILABLE = Unavailable::class
            @JvmField
            val BLOCKED_STATUS = BlockedStatus::class
            @JvmField
            val BLOCKED_STATUS_INT = BlockedStatusInt::class
        }
    }

    /** This instance's read head over the recorded callbacks; every callback is added here. */
    val history = backingRecord.newReadHead()

    /** Current position of [history]'s read head. */
    val mark get() = history.mark

    override fun onAvailable(network: Network) {
        Log.d(TAG, "onAvailable $network")
        history.add(Available(network))
    }

    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
    // expect the callbacks not to record this, do not listen to PreCheck here.

    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) {
        Log.d(TAG, "onCapabilitiesChanged $network $caps")
        history.add(CapabilitiesChanged(network, caps))
    }

    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) {
        Log.d(TAG, "onLinkPropertiesChanged $network $lp")
        history.add(LinkPropertiesChanged(network, lp))
    }

    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) {
        Log.d(TAG, "onBlockedStatusChanged $network $blocked")
        history.add(BlockedStatus(network, blocked))
    }

    // Cannot do:
    // fun onBlockedStatusChanged(network: Network, blocked: Int) {
    // because on S, that needs to be "override fun", and on R, that cannot be "override fun".
    override fun onNetworkSuspended(network: Network) {
        // Previously logged the network twice.
        Log.d(TAG, "onNetworkSuspended $network")
        history.add(Suspended(network))
    }

    override fun onNetworkResumed(network: Network) {
        // Previously logged the network both before and after the callback name.
        Log.d(TAG, "onNetworkResumed $network")
        history.add(Resumed(network))
    }

    override fun onLosing(network: Network, maxMsToLive: Int) {
        Log.d(TAG, "onLosing $network $maxMsToLive")
        history.add(Losing(network, maxMsToLive))
    }

    override fun onLost(network: Network) {
        Log.d(TAG, "onLost $network")
        history.add(Lost(network))
    }

    override fun onUnavailable() {
        Log.d(TAG, "onUnavailable")
        history.add(Unavailable())
    }
}

private const val DEFAULT_TIMEOUT = 200L // ms

/**
 * A [RecorderCallback] with assertion helpers over the recorded callbacks.
 *
 * The `expect*` helpers built on [pollForNextCallback] take the next recorded callback from
 * [history] and fail the test (through kotlin.test) if it is missing or does not match.
 * Timeout parameters generally default to [defaultTimeoutMs].
 *
 * @param defaultTimeoutMs timeout, in milliseconds, used when a helper is not given one.
 */
open class TestableNetworkCallback private constructor(
    src: TestableNetworkCallback?,
    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
) : RecorderCallback(src) {
    @JvmOverloads
    constructor(timeoutMs: Long = DEFAULT_TIMEOUT) : this(null, timeoutMs)

    /** Returns a new callback sharing this one's backing record and default timeout. */
    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)

    // The last available network, or null if any network was lost since the last call to
    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
    val lastAvailableNetwork: Network?
        get() = when (val entry = history.lastOrNull { it is Available || it is Lost }) {
            is Available -> entry.network
            else -> null
        }

    /**
     * Returns the next callback from [history], failing if none arrives within [timeoutMs].
     */
    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
        return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
    }

    // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
    // TODO : remove the necessity to overload this, remove the open qualifier, and give a
    // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
    open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs)

    /** Fails if any callback is received from [history] within [timeoutMs]. */
    fun assertNoCallback(timeoutMs: Long) {
        val cb = history.poll(timeoutMs)
        if (null != cb) fail("Expected no callback but got $cb")
    }

    // Expects a callback of the specified type on the specified network within the timeout.
    // If no callback arrives, or a different callback arrives, fail. Returns the callback.
    // Passing ANY_NETWORK (the default) skips the network check.
    inline fun <reified T : CallbackEntry> expectCallback(
        network: Network = ANY_NETWORK,
        timeoutMs: Long = defaultTimeoutMs
    ): T = pollForNextCallback(timeoutMs).let {
        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
        } else {
            it
        }
    }

    // Expects a callback of the specified type matching the predicate within the timeout.
    // Any callback that doesn't match the predicate will be skipped. Fails only if
    // no matching callback is received within the timeout.
    inline fun <reified T : CallbackEntry> eventuallyExpect(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ): T = eventuallyExpectOrNull(timeoutMs, from, predicate).also {
        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
    } as T

    /**
     * Non-reified variant of [eventuallyExpect]: polls [history] for the first entry that is an
     * instance of [type] and satisfies [predicate], failing if none arrives within [timeoutMs].
     */
    fun <T : CallbackEntry> eventuallyExpect(
        type: KClass<T>,
        timeoutMs: Long = defaultTimeoutMs,
        // Was spelled `(T: CallbackEntry) -> Boolean`; the parameter name `T` shadowed the type
        // parameter and was misleading. The function type itself is unchanged.
        predicate: (CallbackEntry) -> Boolean = { true }
    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it) }.also {
        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
    } as T

    // TODO (b/157405399) straighten and unify the method names
    // Returns the first entry at or after position `from` that is a T matching the predicate,
    // or null if none arrives within the timeout.
    // NOTE(review): this uses the positional poll(timeoutMs, from, ...) overload, whereas the
    // KClass eventuallyExpect uses the read head's poll(timeoutMs, ...). Confirm in
    // ArrayTrackRecord whether only the latter advances the read head.
    inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
        timeoutMs: Long = defaultTimeoutMs,
        from: Int = mark,
        crossinline predicate: (T) -> Boolean = { true }
    ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?

    /** Takes the next callback and asserts that [valid] accepts it; returns the callback. */
    fun expectCallbackThat(
        timeoutMs: Long = defaultTimeoutMs,
        valid: (CallbackEntry) -> Boolean
    ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }

    /**
     * Expects the next callback to be [CapabilitiesChanged] on [net] with capabilities accepted
     * by [valid]. Returns that callback.
     */
    fun expectCapabilitiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ): CapabilitiesChanged {
        return expectCallback<CapabilitiesChanged>(net, tmt).also {
            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
        }
    }

    /**
     * Expects the next callback to be [LinkPropertiesChanged] on [net] with link properties
     * accepted by [valid]. Returns that callback.
     */
    fun expectLinkPropertiesThat(
        net: Network,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ): LinkPropertiesChanged {
        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
        }
    }

    // Expects onAvailable and the callbacks that follow it. These are:
    // - onNetworkSuspended, iff the network was suspended when the callbacks fire.
    // - onCapabilitiesChanged.
    // - onLinkPropertiesChanged.
    // - onBlockedStatusChanged.
    //
    // @param net the network to expect the callbacks on.
    // @param suspended whether to expect a SUSPENDED callback.
    // @param validated the expected value of the VALIDATED capability in the
    //        onCapabilitiesChanged callback.
    // @param blocked the expected value of the BlockedStatus callback.
    // @param tmt how long to wait for each callback.
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean = false,
        validated: Boolean = true,
        blocked: Boolean = false,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        expectBlockedStatusCallback(blocked, net, tmt)
    }

    // Same as above, but the final callback is expected to be a BlockedStatusInt carrying
    // blockedStatus.
    fun expectAvailableCallbacks(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        blockedStatus: Int,
        tmt: Long
    ) {
        expectAvailableCallbacksCommon(net, suspended, validated, tmt)
        // Pass tmt through: previously the blocked-status wait silently used the default timeout.
        expectBlockedStatusCallback(blockedStatus, net, tmt)
    }

    // Expects Available, optionally Suspended, CapabilitiesChanged (checking VALIDATED) and
    // LinkPropertiesChanged on net, in that order.
    private fun expectAvailableCallbacksCommon(
        net: Network,
        suspended: Boolean,
        validated: Boolean,
        tmt: Long
    ) {
        expectCallback<Available>(net, tmt)
        if (suspended) {
            expectCallback<Suspended>(net, tmt)
        }
        expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
        expectCallback<LinkPropertiesChanged>(net, tmt)
    }

    // Backward compatibility for existing Java code. Use named arguments instead and remove all
    // these when there is no user left.
    fun expectAvailableAndSuspendedCallbacks(
        net: Network,
        validated: Boolean,
        tmt: Long = defaultTimeoutMs
    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)

    /** Expects the next callback to be [BlockedStatus] on [net] carrying [blocked]. */
    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatus>(net, tmt).also {
            // kotlin.test's assertEquals takes (expected, actual); the arguments were swapped.
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    /** Expects the next callback to be [BlockedStatusInt] on [net] carrying [blocked]. */
    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
        expectCallback<BlockedStatusInt>(net, tmt).also {
            // kotlin.test's assertEquals takes (expected, actual); the arguments were swapped.
            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
        }
    }

    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
    // one we just sent.
    // TODO: this is likely a bug. Fix it and remove this method.
    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        // Position before the available callbacks, used to look the first caps change back up.
        val start = history.mark
        expectAvailableCallbacks(net, tmt = tmt)
        val firstCaps = history.poll(tmt, start) { it is CapabilitiesChanged }
        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
    }

    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
    // when a network connects and satisfies a callback, and then immediately validates.
    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
        expectAvailableCallbacks(net, validated = false, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Same as above, but the blocked status is expected as a BlockedStatusInt callback.
    fun expectAvailableThenValidatedCallbacks(
        net: Network,
        blockedStatus: Int,
        tmt: Long = defaultTimeoutMs
    ) {
        expectAvailableCallbacks(net, validated = false, suspended = false,
                blockedStatus = blockedStatus, tmt = tmt)
        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
    }

    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
    // TODO: clean this up, remove this method.
    interface HasNetwork {
        val network: Network
    }

    /**
     * Expects the next callback to be an instance of [type] on network [n] (or on
     * [NULL_NETWORK] when [n] is null). Returns it cast to [T].
     */
    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: Network?,
        timeoutMs: Long = defaultTimeoutMs
    ) = pollForNextCallback(timeoutMs).also {
        val network = n ?: NULL_NETWORK
        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
        // this writing this would be the only use of this library in the tests.
        assertTrue(type.java.isInstance(it) && it.network == network,
                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
    } as T

    @JvmOverloads
    open fun <T : CallbackEntry> expectCallback(
        type: KClass<T>,
        n: HasNetwork?,
        timeoutMs: Long = defaultTimeoutMs
    ) = expectCallback(type, n?.network, timeoutMs)

    fun expectAvailableCallbacks(
        n: HasNetwork,
        suspended: Boolean,
        validated: Boolean,
        blocked: Boolean,
        timeoutMs: Long
    ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)

    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
    }

    fun expectAvailableCallbacksValidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network)
    }

    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, blocked = true)
    }

    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false)
    }

    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
        expectAvailableCallbacks(n.network, validated = false, blocked = true)
    }

    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
    }

    @JvmOverloads
    fun expectLinkPropertiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (LinkProperties) -> Boolean
    ) = expectLinkPropertiesThat(n.network, tmt, valid)

    @JvmOverloads
    fun expectCapabilitiesThat(
        n: HasNetwork,
        tmt: Long = defaultTimeoutMs,
        valid: (NetworkCapabilities) -> Boolean
    ) = expectCapabilitiesThat(n.network, tmt, valid)

    /** Expects a [CapabilitiesChanged] on [n] that has [capability]; returns its capabilities. */
    @JvmOverloads
    fun expectCapabilitiesWith(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
    }

    /** Expects a [CapabilitiesChanged] on [n] lacking [capability]; returns its capabilities. */
    @JvmOverloads
    fun expectCapabilitiesWithout(
        capability: Int,
        n: HasNetwork,
        timeoutMs: Long = defaultTimeoutMs
    ): NetworkCapabilities {
        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
    }

    fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
        expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
    }
}