package android.net.testutils

import android.net.LinkAddress
import android.net.LinkProperties
import android.net.Network
import android.net.NetworkCapabilities
import com.android.testutils.ConcurrentInterpreter
import com.android.testutils.InterpretMatcher
import com.android.testutils.RecorderCallback.CallbackEntry
import com.android.testutils.RecorderCallback.CallbackEntry.Available
import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
import com.android.testutils.TestableNetworkCallback
import com.android.testutils.intArg
import com.android.testutils.strArg
import com.android.testutils.timeArg
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertFails
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.test.fail

/** Short timeout (ms), mostly passed to expectations that are meant to fail so they return fast. */
const val SHORT_TIMEOUT_MS = 20L
/** Linger delay (ms) passed to `onLosing` by the test interpreter's `onLosing(n)` instruction. */
const val DEFAULT_LINGER_DELAY_MS = 30000
const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED
const val WIFI = NetworkCapabilities.TRANSPORT_WIFI
const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR
const val TEST_INTERFACE_NAME = "testInterfaceName"

/** Unit tests for the recording and expectation helpers of [TestableNetworkCallback]. */
@RunWith(JUnit4::class)
class TestableNetworkCallbackTest {
    private lateinit var mCallback: TestableNetworkCallback

    /** Returns a [TestableNetworkCallback.HasNetwork] whose `network` is `Network(netId)`. */
    private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
        override val network: Network = Network(netId)
    }

    @Before
    fun setUp() {
        mCallback = TestableNetworkCallback()
    }

    /** Checks how onAvailable / onLost and other callbacks affect `lastAvailableNetwork`. */
    @Test
    fun testLastAvailableNetwork() {
        // Make sure there is no last available network at first, then the last available network
        // is returned after onAvailable is called.
        val net2097 = Network(2097)
        assertNull(mCallback.lastAvailableNetwork)
        mCallback.onAvailable(net2097)
        assertEquals(net2097, mCallback.lastAvailableNetwork)

        // Make sure calling onCapsChanged/onLinkPropertiesChanged doesn't affect the last
        // available network.
        mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
        mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
        assertEquals(net2097, mCallback.lastAvailableNetwork)

        // Make sure onLost clears the last available network.
        mCallback.onLost(net2097)
        assertNull(mCallback.lastAvailableNetwork)

        // Do the same but with a different network after onLost: make sure the last available
        // network is the new one, not the original one.
        val net2098 = Network(2098)
        mCallback.onAvailable(net2098)
        mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
        mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
        assertEquals(net2098, mCallback.lastAvailableNetwork)

        // Make sure onAvailable changes the last available network even if onLost was not called.
        val net2099 = Network(2099)
        mCallback.onAvailable(net2099)
        assertEquals(net2099, mCallback.lastAvailableNetwork)

        // For legacy reasons, lastAvailableNetwork is null as soon as any network is lost, not
        // necessarily the last available one. Check that behavior.
        mCallback.onLost(net2098)
        assertNull(mCallback.lastAvailableNetwork)

        // Make sure that losing the really last available one still results in null.
        mCallback.onLost(net2099)
        assertNull(mCallback.lastAvailableNetwork)

        // Make sure multiple onAvailable in a row then onLost still results in null.
        mCallback.onAvailable(net2097)
        mCallback.onAvailable(net2098)
        mCallback.onAvailable(net2099)
        mCallback.onLost(net2097)
        assertNull(mCallback.lastAvailableNetwork)
    }

    /** Checks assertNoCallback passes when nothing was received and fails once one was. */
    @Test
    fun testAssertNoCallback() {
        mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
        mCallback.onAvailable(Network(100))
        assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
    }

    /** Checks expectCapabilitiesWith / expectCapabilitiesWithout against present/absent caps. */
    @Test
    fun testCapabilitiesWithAndWithout() {
        val net = Network(101)
        val matcher = makeHasNetwork(101)
        val meteredNc = NetworkCapabilities()
        val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
        // Check that expecting caps (with or without) fails when no callback has been received.
        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }

        // Add NOT_METERED and check that With succeeds and Without fails.
        mCallback.onCapabilitiesChanged(net, unmeteredNc)
        mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
        mCallback.onCapabilitiesChanged(net, unmeteredNc)
        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }

        // Don't add NOT_METERED and check that With fails and Without succeeds.
        mCallback.onCapabilitiesChanged(net, meteredNc)
        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
        mCallback.onCapabilitiesChanged(net, meteredNc)
        mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
    }

    /** Checks expectCallbackThat accepts/rejects the next callback according to the predicate. */
    @Test
    fun testExpectCallbackThat() {
        val net = Network(193)
        val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
        // Check that expecting callbackThat anything fails when no callback has been received.
        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { true } }

        // Basic test for true and false
        mCallback.onAvailable(net)
        mCallback.expectCallbackThat { true }
        mCallback.onAvailable(net)
        assertFails { mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { false } }

        // Try a positive and a negative case
        mCallback.onBlockedStatusChanged(net, true)
        mCallback.expectCallbackThat { cb -> cb is BlockedStatus && cb.blocked }
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails {
            mCallback.expectCallbackThat(SHORT_TIMEOUT_MS) { cb ->
                cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
            }
        }
    }

    /** Checks expectCapabilitiesThat applies the predicate and matches on the network. */
    @Test
    fun testCapabilitiesThat() {
        val net = Network(101)
        val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
        // Check that expecting capabilitiesThat anything fails when no callback has been received.
        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }

        // Basic test for true and false
        mCallback.onCapabilitiesChanged(net, netCaps)
        mCallback.expectCapabilitiesThat(net) { true }
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }

        // Try a positive and a negative case
        mCallback.onCapabilitiesChanged(net, netCaps)
        mCallback.expectCapabilitiesThat(net) { caps ->
            caps.hasCapability(NOT_METERED) &&
                    caps.hasTransport(WIFI) &&
                    !caps.hasTransport(CELLULAR)
        }
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails {
            mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
                caps.hasTransport(CELLULAR)
            }
        }

        // Try a matching callback on the wrong network
        mCallback.onCapabilitiesChanged(net, netCaps)
        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
    }

    /** Checks expectLinkPropertiesThat applies the predicate and matches on the network. */
    @Test
    fun testLinkPropertiesThat() {
        val net = Network(112)
        val linkAddress = LinkAddress("fe80::ace:d00d/64")
        val mtu = 1984
        val linkProps = LinkProperties().apply {
            this.mtu = mtu
            interfaceName = TEST_INTERFACE_NAME
            addLinkAddress(linkAddress)
        }

        // Check that expecting linkPropsThat anything fails when no callback has been received.
        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }

        // Basic test for true and false
        mCallback.onLinkPropertiesChanged(net, linkProps)
        mCallback.expectLinkPropertiesThat(net) { true }
        mCallback.onLinkPropertiesChanged(net, linkProps)
        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }

        // Try a positive and negative case
        mCallback.onLinkPropertiesChanged(net, linkProps)
        mCallback.expectLinkPropertiesThat(net) { lp ->
            lp.interfaceName == TEST_INTERFACE_NAME &&
                    lp.linkAddresses.contains(linkAddress) &&
                    lp.mtu == mtu
        }
        mCallback.onLinkPropertiesChanged(net, linkProps)
        assertFails {
            mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
                lp.interfaceName != TEST_INTERFACE_NAME
            }
        }

        // Try a matching callback on the wrong network
        mCallback.onLinkPropertiesChanged(net, linkProps)
        assertFails {
            mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
                lp.interfaceName == TEST_INTERFACE_NAME
            }
        }
    }

    /** Checks expectCallback matches on both the callback type and the network. */
    @Test
    fun testExpectCallback() {
        val net = Network(103)
        // Test expectCallback fails when nothing was sent.
        assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }

        // Test onAvailable is seen and can be expected
        mCallback.onAvailable(net)
        mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS)

        // Test expectCallback won't match a callback for a different network
        mCallback.onAvailable(Network(106))
        assertFails { mCallback.expectCallback<Available>(net, SHORT_TIMEOUT_MS) }

        // Test expectCallback won't match a callback of a different type
        mCallback.onAvailable(net)
        assertFails { mCallback.expectCallback<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
    }

    /**
     * Checks pollForNextCallback fails when nothing was received, then runs a two-column
     * concurrent spec through [TNCInterpreter]: the left column fires callbacks and the right
     * column polls for them, with the timing bounds given in the spec.
     */
    @Test
    fun testPollForNextCallback() {
        assertFails { mCallback.pollForNextCallback(SHORT_TIMEOUT_MS) }
        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
            sleep; onAvailable(133)    | poll(2) = Available(133) time 1..4
                                       | poll(1) fails
            onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
            onBlockedStatus(199)       | poll(1) = BlockedStatus(199) time 0..3
        """)
    }
}

/** Concurrent spec interpreter for [TestableNetworkCallback], driven by [interpretTable]. */
private object TNCInterpreter : ConcurrentInterpreter<TestableNetworkCallback>(interpretTable)

/**
 * Regex alternation ("A|B|...") of the simple names of all sealed subclasses of
 * [CallbackEntry], used to build the patterns in [interpretTable].
 */
val EntryList = CallbackEntry::class.sealedSubclasses.joinToString("|") { it.simpleName.toString() }

/**
 * Returns the sealed subclass of [CallbackEntry] whose simple name is [name].
 *
 * Fails the test with a message naming [name] if there is no such subclass.
 */
private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> =
    CallbackEntry::class.sealedSubclasses.firstOrNull { it.simpleName == name }
        ?: fail("Unknown callback entry type: $name")

/**
 * Instructions understood by [TNCInterpreter]. Each entry pairs a regex with the action run
 * when a spec instruction matches it; the action receives the interpreter, the callback being
 * driven and the regex match.
 */
private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
    // Interpret "X = Available(xx)" as "evaluate X and assert it returned an Available entry
    // for netId xx", and likewise for all callback types. The accepted type names come from
    // EntryList, which enumerates the subclasses of CallbackEntry and reads their simpleName.
    Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
        val record = i.interpret(t.strArg(1), cb)
        val expectedType = t.strArg(2)
        assertTrue(callbackEntryFromString(expectedType).isInstance(record),
                "Expected a $expectedType callback but got $record")
        // Strictly speaking testing for is CallbackEntry is redundant as it was checked above,
        // but the compiler can't infer anything from the isInstance call. It does understand
        // from assertTrue(record is CallbackEntry) that this is true, which allows accessing
        // the 'network' member below.
        assertTrue(record is CallbackEntry)
        assertEquals(t.intArg(3), record.network.netId)
    },
    // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
    // all callback types. NetworkCapabilities and LinkProperties just get an empty object
    // as their argument. Losing gets DEFAULT_LINGER_DELAY_MS. BlockedStatus gets false.
    // Unavailable ignores the netId.
    Regex("""on($EntryList)\((\d+)\)""") to { _, cb, t ->
        val net = Network(t.intArg(2))
        when (val type = t.strArg(1)) {
            "Available" -> cb.onAvailable(net)
            // PreCheck not used in tests. Add it here if it becomes useful.
            "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
            "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
            "Suspended" -> cb.onNetworkSuspended(net)
            "Resumed" -> cb.onNetworkResumed(net)
            "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
            "Lost" -> cb.onLost(net)
            "Unavailable" -> cb.onUnavailable()
            "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
            else -> fail("Unknown callback type: $type")
        }
    },
    // Interpret "poll(n)" as polling for the next callback with the timeout given by timeArg;
    // the polled entry is the value of the instruction.
    Regex("""poll\((\d+)\)""") to { _, cb, t ->
        cb.pollForNextCallback(t.timeArg(1))
    }
)