1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.content.Context
20 import android.net.ConnectivityManager
21 import android.net.ConnectivityManager.NetworkCallback
22 import android.net.LinkAddress
23 import android.net.Network
24 import android.net.NetworkCapabilities
25 import android.net.NetworkRequest
26 import android.net.TestNetworkInterface
27 import android.net.TestNetworkManager
28 import android.os.Binder
29 import com.android.modules.utils.build.SdkLevel.isAtLeastS
30 import java.util.concurrent.CompletableFuture
31 import java.util.concurrent.TimeUnit
32 
33 /**
34  * Create a test network based on a TUN interface.
35  *
36  * This method will block until the test network is available. Requires
37  * [android.Manifest.permission.CHANGE_NETWORK_STATE] and
38  * [android.Manifest.permission.MANAGE_TEST_NETWORKS].
39  */
40 fun initTestNetwork(context: Context, interfaceAddr: LinkAddress, setupTimeoutMs: Long = 10_000L):
41         TestNetworkTracker {
42     val tnm = context.getSystemService(TestNetworkManager::class.java)
43     val iface = if (isAtLeastS()) tnm.createTunInterface(listOf(interfaceAddr))
44             else tnm.createTunInterface(arrayOf(interfaceAddr))
45     return TestNetworkTracker(context, iface, tnm, setupTimeoutMs)
46 }
47 
48 /**
49  * Utility class to create and track test networks.
50  *
51  * This class is not thread-safe.
52  */
53 class TestNetworkTracker internal constructor(
54     val context: Context,
55     val iface: TestNetworkInterface,
56     val tnm: TestNetworkManager,
57     setupTimeoutMs: Long
58 ) {
59     private val cm = context.getSystemService(ConnectivityManager::class.java)
60     private val binder = Binder()
61 
62     private val networkCallback: NetworkCallback
63     val network: Network
64 
65     init {
66         val networkFuture = CompletableFuture<Network>()
67         val networkRequest = NetworkRequest.Builder()
68                 .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
69                 // Test networks do not have NOT_VPN or TRUSTED capabilities by default
70                 .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
71                 .removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
72                 .setNetworkSpecifier(CompatUtil.makeTestNetworkSpecifier(iface.interfaceName))
73                 .build()
74         networkCallback = object : NetworkCallback() {
75             override fun onAvailable(network: Network) {
76                 networkFuture.complete(network)
77             }
78         }
79         cm.requestNetwork(networkRequest, networkCallback)
80 
81         try {
82             tnm.setupTestNetwork(iface.interfaceName, binder)
83             network = networkFuture.get(setupTimeoutMs, TimeUnit.MILLISECONDS)
84         } catch (e: Throwable) {
85             teardown()
86             throw e
87         }
88     }
89 
90     fun teardown() {
91         cm.unregisterNetworkCallback(networkCallback)
92         tnm.teardownTestNetwork(network)
93     }
94 }