1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.networkstack
18 
19 import android.app.Notification
20 import android.app.NotificationChannel
21 import android.app.NotificationManager
22 import android.app.NotificationManager.IMPORTANCE_DEFAULT
23 import android.app.NotificationManager.IMPORTANCE_NONE
24 import android.app.PendingIntent
25 import android.app.PendingIntent.FLAG_IMMUTABLE
26 import android.content.Context
27 import android.content.Intent
28 import android.content.res.Resources
29 import android.net.CaptivePortalData
30 import android.net.ConnectivityManager
31 import android.net.ConnectivityManager.EXTRA_NETWORK
32 import android.net.ConnectivityManager.NetworkCallback
33 import android.net.LinkProperties
34 import android.net.Network
35 import android.net.NetworkCapabilities
36 import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
37 import android.net.NetworkCapabilities.TRANSPORT_WIFI
38 import android.net.Uri
39 import android.os.Handler
40 import android.os.UserHandle
41 import android.provider.Settings
42 import android.testing.AndroidTestingRunner
43 import android.testing.TestableLooper
44 import android.testing.TestableLooper.RunWithLooper
45 import androidx.test.filters.SmallTest
46 import androidx.test.platform.app.InstrumentationRegistry
47 import com.android.dx.mockito.inline.extended.ExtendedMockito.verify
48 import com.android.networkstack.NetworkStackNotifier.CHANNEL_CONNECTED
49 import com.android.networkstack.NetworkStackNotifier.CHANNEL_VENUE_INFO
50 import com.android.networkstack.NetworkStackNotifier.CONNECTED_NOTIFICATION_TIMEOUT_MS
51 import com.android.networkstack.NetworkStackNotifier.Dependencies
52 import com.android.networkstack.apishim.NetworkInformationShimImpl
53 import org.junit.Assume.assumeTrue
54 import org.junit.Before
55 import org.junit.Test
56 import org.junit.runner.RunWith
57 import org.mockito.ArgumentCaptor
58 import org.mockito.ArgumentMatchers.anyInt
59 import org.mockito.ArgumentMatchers.eq
60 import org.mockito.ArgumentMatchers.intThat
61 import org.mockito.Captor
62 import org.mockito.Mock
63 import org.mockito.Mockito.any
64 import org.mockito.Mockito.doReturn
65 import org.mockito.Mockito.never
66 import org.mockito.MockitoAnnotations
67 import kotlin.reflect.KClass
68 import kotlin.test.assertEquals
69 
/**
 * Unit tests for [NetworkStackNotifier].
 *
 * The notifier is built on the test's [TestableLooper] with mocked system services. The network
 * callbacks it registers with [ConnectivityManager] are captured in [setUp]; each test then posts
 * callback events to the notifier's handler, drains the looper, and verifies the calls made to
 * the mocked [NotificationManager].
 */
@RunWith(AndroidTestingRunner::class)
@SmallTest
@RunWithLooper
class NetworkStackNotifierTest {
    // Context handed to the notifier under test.
    @Mock
    private lateinit var mContext: Context
    // Context returned by mContext.createPackageContextAsUser(..., UserHandle.CURRENT).
    @Mock
    private lateinit var mCurrentUserContext: Context
    // Context returned by mContext.createPackageContextAsUser(..., UserHandle.ALL).
    @Mock
    private lateinit var mAllUserContext: Context
    @Mock
    private lateinit var mDependencies: Dependencies
    // NotificationManager served by mAllUserContext: notify()/cancel() calls are verified on it.
    @Mock
    private lateinit var mNm: NotificationManager
    // NotificationManager served by mContext: notification channel lookups are stubbed on it.
    @Mock
    private lateinit var mNotificationChannelsNm: NotificationManager
    @Mock
    private lateinit var mCm: ConnectivityManager
    @Mock
    private lateinit var mResources: Resources
    @Mock
    private lateinit var mPendingIntent: PendingIntent
    @Captor
    private lateinit var mNoteCaptor: ArgumentCaptor<Notification>
    @Captor
    private lateinit var mNoteIdCaptor: ArgumentCaptor<Int>
    @Captor
    private lateinit var mIntentCaptor: ArgumentCaptor<Intent>
    private lateinit var mLooper: TestableLooper
    private lateinit var mHandler: Handler
    private lateinit var mNotifier: NetworkStackNotifier

    // Callbacks registered by the notifier, captured in setUp().
    private lateinit var mAllNetworksCb: NetworkCallback
    private lateinit var mDefaultNetworkCb: NetworkCallback

    // Lazy-init as CaptivePortalData does not exist on Q.
    private val mTestCapportLp by lazy { newVenueInfoLinkProperties() }

    // Same as mTestCapportLp, with the venue friendly name added through the API shim.
    private val mTestCapportVenueUrlWithFriendlyNameLp by lazy {
        newVenueInfoLinkProperties().apply {
            val networkShim = NetworkInformationShimImpl.newInstance()
            val captivePortalDataShim = networkShim.getCaptivePortalData(this)

            if (captivePortalDataShim != null) {
                networkShim.setCaptivePortalData(this, captivePortalDataShim
                        .withVenueFriendlyName(TEST_NETWORK_FRIENDLY_NAME))
            }
        }
    }

    private val TEST_NETWORK = Network(42)
    private val TEST_NETWORK_TAG = TEST_NETWORK.networkHandle.toString()
    private val TEST_SSID = "TestSsid"
    private val EMPTY_CAPABILITIES = NetworkCapabilities()
    private val VALIDATED_CAPABILITIES = NetworkCapabilities()
            .addTransportType(TRANSPORT_WIFI)
            .addCapability(NET_CAPABILITY_VALIDATED)

    private val TEST_CONNECTED_DESCRIPTION = "Connected"
    private val TEST_VENUE_DESCRIPTION = "Connected / Tap to view website"

    private val TEST_VENUE_INFO_URL = "https://testvenue.example.com/info"
    private val EMPTY_CAPPORT_LP = LinkProperties()
    private val TEST_NETWORK_FRIENDLY_NAME = "Network Friendly Name"

    /**
     * Builds [LinkProperties] whose captive portal data is non-captive and carries
     * [TEST_VENUE_INFO_URL] as the venue info URL.
     *
     * Must only be called on API levels where [CaptivePortalData] exists (hence the lazy
     * properties that use it).
     */
    private fun newVenueInfoLinkProperties(): LinkProperties {
        return LinkProperties().apply {
            captivePortalData = CaptivePortalData.Builder()
                    .setCaptive(false)
                    .setVenueInfoUrl(Uri.parse(TEST_VENUE_INFO_URL))
                    .build()
        }
    }

    /**
     * Wires up the mocks, creates the notifier on the test looper, and captures the two network
     * callbacks it registers with [ConnectivityManager].
     */
    @Before
    fun setUp() {
        MockitoAnnotations.initMocks(this)
        mLooper = TestableLooper.get(this)
        doReturn(mResources).`when`(mContext).resources
        doReturn(TEST_CONNECTED_DESCRIPTION).`when`(mResources).getString(R.string.connected)
        doReturn(TEST_VENUE_DESCRIPTION).`when`(mResources).getString(R.string.tap_for_info)

        // applicationInfo is used by Notification.Builder
        val realContext = InstrumentationRegistry.getInstrumentation().context
        doReturn(realContext.applicationInfo).`when`(mContext).applicationInfo
        doReturn(realContext.packageName).`when`(mContext).packageName

        doReturn(mCurrentUserContext).`when`(mContext).createPackageContextAsUser(
                realContext.packageName, 0, UserHandle.CURRENT)
        doReturn(mAllUserContext).`when`(mContext).createPackageContextAsUser(
                realContext.packageName, 0, UserHandle.ALL)

        mAllUserContext.mockService(Context.NOTIFICATION_SERVICE, NotificationManager::class, mNm)
        mContext.mockService(Context.NOTIFICATION_SERVICE, NotificationManager::class,
                mNotificationChannelsNm)
        mContext.mockService(Context.CONNECTIVITY_SERVICE, ConnectivityManager::class, mCm)

        // By default the venue info channel exists and is enabled.
        doReturn(NotificationChannel(CHANNEL_VENUE_INFO, "TestChannel", IMPORTANCE_DEFAULT))
                .`when`(mNotificationChannelsNm).getNotificationChannel(CHANNEL_VENUE_INFO)

        doReturn(mPendingIntent).`when`(mDependencies).getActivityPendingIntent(
                any(), any(), anyInt())
        mNotifier = NetworkStackNotifier(mContext, mLooper.looper, mDependencies)
        mHandler = mNotifier.handler

        val allNetworksCbCaptor = ArgumentCaptor.forClass(NetworkCallback::class.java)
        verify(mCm).registerNetworkCallback(any() /* request */, allNetworksCbCaptor.capture(),
                eq(mHandler))
        mAllNetworksCb = allNetworksCbCaptor.value

        val defaultNetworkCbCaptor = ArgumentCaptor.forClass(NetworkCallback::class.java)
        verify(mCm).registerDefaultNetworkCallback(defaultNetworkCbCaptor.capture(), eq(mHandler))
        mDefaultNetworkCb = defaultNetworkCbCaptor.value
    }

    /**
     * Makes this (mock) context return [service] for lookups by [name] and by [clazz], and
     * [name] as the service name of [clazz].
     */
    private fun <T : Any> Context.mockService(name: String, clazz: KClass<T>, service: T) {
        doReturn(service).`when`(this).getSystemService(name)
        doReturn(name).`when`(this).getSystemServiceName(clazz.java)
        doReturn(service).`when`(this).getSystemService(clazz.java)
    }

    @Test
    fun testNoNotification() {
        onCapabilitiesChanged(EMPTY_CAPABILITIES)
        onCapabilitiesChanged(VALIDATED_CAPABILITIES)

        mLooper.processAllMessages()
        verify(mNm, never()).notify(any(), anyInt(), any())
    }

    /**
     * Verifies that exactly one notification was posted for [TEST_NETWORK_TAG] on the
     * [CHANNEL_CONNECTED] channel with the expected [timeout] and content intent.
     *
     * Captures the notification, its ID and the content intent into the class-level captors.
     */
    private fun verifyConnectedNotification(timeout: Long = CONNECTED_NOTIFICATION_TIMEOUT_MS) {
        verify(mNm).notify(eq(TEST_NETWORK_TAG), mNoteIdCaptor.capture(), mNoteCaptor.capture())
        val note = mNoteCaptor.value
        assertEquals(mPendingIntent, note.contentIntent)
        assertEquals(CHANNEL_CONNECTED, note.channelId)
        assertEquals(timeout, note.timeoutAfter)
        verifyImmutablePendingIntentRequested()
    }

    /**
     * Verifies that the content [PendingIntent] was requested from [Dependencies] with the
     * current-user context and with [FLAG_IMMUTABLE] set, capturing the intent into
     * [mIntentCaptor].
     */
    private fun verifyImmutablePendingIntentRequested() {
        verify(mDependencies).getActivityPendingIntent(
                eq(mCurrentUserContext), mIntentCaptor.capture(),
                // Must be a bitwise AND: `it or FLAG_IMMUTABLE` is non-zero for every value, so
                // an OR-based matcher would accept flags that lack FLAG_IMMUTABLE.
                intThat { (it and FLAG_IMMUTABLE) != 0 })
    }

    /** Loses [TEST_NETWORK] and verifies the previously captured notification is canceled. */
    private fun verifyCanceledNotificationAfterNetworkLost() {
        onLost(TEST_NETWORK)
        mLooper.processAllMessages()
        verify(mNm).cancel(TEST_NETWORK_TAG, mNoteIdCaptor.value)
    }

    /**
     * Loses [TEST_NETWORK] as the default network and verifies the previously captured
     * notification is canceled.
     */
    private fun verifyCanceledNotificationAfterDefaultNetworkLost() {
        onDefaultNetworkLost(TEST_NETWORK)
        mLooper.processAllMessages()
        verify(mNm).cancel(TEST_NETWORK_TAG, mNoteIdCaptor.value)
    }

    @Test
    fun testConnectedNotification_NoSsid() {
        onCapabilitiesChanged(EMPTY_CAPABILITIES)
        mNotifier.notifyCaptivePortalValidationPending(TEST_NETWORK)
        onCapabilitiesChanged(VALIDATED_CAPABILITIES)
        mLooper.processAllMessages()
        // There is no notification when SSID is not set.
        verify(mNm, never()).notify(any(), anyInt(), any())
    }

    @Test
    fun testConnectedNotification_WithSsid() {
        // NetworkCapabilities#getSSID is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        val capabilities = NetworkCapabilities(VALIDATED_CAPABILITIES).setSSID(TEST_SSID)

        onCapabilitiesChanged(EMPTY_CAPABILITIES)
        mNotifier.notifyCaptivePortalValidationPending(TEST_NETWORK)
        onCapabilitiesChanged(capabilities)
        mLooper.processAllMessages()

        verifyConnectedNotification()
        verify(mResources).getString(R.string.connected)
        verifyWifiSettingsIntent(mIntentCaptor.value)
        verifyCanceledNotificationAfterNetworkLost()
    }

    @Test
    fun testConnectedVenueInfoNotification() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        mNotifier.notifyCaptivePortalValidationPending(TEST_NETWORK)
        onLinkPropertiesChanged(mTestCapportLp)
        onDefaultNetworkAvailable(TEST_NETWORK)
        val capabilities = NetworkCapabilities(VALIDATED_CAPABILITIES).setSSID(TEST_SSID)
        onCapabilitiesChanged(capabilities)

        mLooper.processAllMessages()

        verifyConnectedNotification(timeout = 0)
        verifyVenueInfoIntent(mIntentCaptor.value)
        verify(mResources).getString(R.string.tap_for_info)
        verifyCanceledNotificationAfterDefaultNetworkLost()
    }

    @Test
    fun testConnectedVenueInfoNotification_VenueInfoDisabled() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        val channel = NotificationChannel(CHANNEL_VENUE_INFO, "test channel", IMPORTANCE_NONE)
        doReturn(channel).`when`(mNotificationChannelsNm).getNotificationChannel(CHANNEL_VENUE_INFO)
        mNotifier.notifyCaptivePortalValidationPending(TEST_NETWORK)
        onLinkPropertiesChanged(mTestCapportLp)
        onDefaultNetworkAvailable(TEST_NETWORK)
        val capabilities = NetworkCapabilities(VALIDATED_CAPABILITIES).setSSID(TEST_SSID)
        onCapabilitiesChanged(capabilities)
        mLooper.processAllMessages()

        verifyConnectedNotification()
        verifyWifiSettingsIntent(mIntentCaptor.value)
        verify(mResources, never()).getString(R.string.tap_for_info)
        verifyCanceledNotificationAfterNetworkLost()
    }

    @Test
    fun testVenueInfoNotification() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        onLinkPropertiesChanged(mTestCapportLp)
        onDefaultNetworkAvailable(TEST_NETWORK)
        val capabilities = NetworkCapabilities(VALIDATED_CAPABILITIES).setSSID(TEST_SSID)
        onCapabilitiesChanged(capabilities)
        mLooper.processAllMessages()

        verify(mNm).notify(eq(TEST_NETWORK_TAG), mNoteIdCaptor.capture(), mNoteCaptor.capture())
        verifyImmutablePendingIntentRequested()
        verifyVenueInfoIntent(mIntentCaptor.value)
        verifyCanceledNotificationAfterDefaultNetworkLost()
    }

    @Test
    fun testVenueInfoNotification_VenueInfoDisabled() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        // Channel lookups are stubbed on mNotificationChannelsNm (see setUp); stubbing mNm here
        // would leave the enabled channel from setUp in place. Apart from the disabled channel,
        // the scenario matches testVenueInfoNotification so that the channel state is the only
        // reason for the missing notification.
        val channel = NotificationChannel(CHANNEL_VENUE_INFO, "test channel", IMPORTANCE_NONE)
        doReturn(channel).`when`(mNotificationChannelsNm).getNotificationChannel(CHANNEL_VENUE_INFO)
        onLinkPropertiesChanged(mTestCapportLp)
        onDefaultNetworkAvailable(TEST_NETWORK)
        val capabilities = NetworkCapabilities(VALIDATED_CAPABILITIES).setSSID(TEST_SSID)
        onCapabilitiesChanged(capabilities)
        mLooper.processAllMessages()

        verify(mNm, never()).notify(any(), anyInt(), any())
    }

    @Test
    fun testNonDefaultVenueInfoNotification() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        onLinkPropertiesChanged(mTestCapportLp)
        onCapabilitiesChanged(VALIDATED_CAPABILITIES)
        mLooper.processAllMessages()

        verify(mNm, never()).notify(eq(TEST_NETWORK_TAG), anyInt(), any())
    }

    @Test
    fun testEmptyCaptivePortalDataVenueInfoNotification() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        onLinkPropertiesChanged(EMPTY_CAPPORT_LP)
        onCapabilitiesChanged(VALIDATED_CAPABILITIES)
        mLooper.processAllMessages()

        verify(mNm, never()).notify(eq(TEST_NETWORK_TAG), anyInt(), any())
    }

    @Test
    fun testUnvalidatedNetworkVenueInfoNotification() {
        // Venue info (CaptivePortalData) is not available for API <= Q
        assumeTrue(NetworkInformationShimImpl.useApiAboveQ())
        onLinkPropertiesChanged(mTestCapportLp)
        onCapabilitiesChanged(EMPTY_CAPABILITIES)
        mLooper.processAllMessages()

        verify(mNm, never()).notify(eq(TEST_NETWORK_TAG), anyInt(), any())
    }

    @Test
    fun testConnectedVenueInfoWithFriendlyNameNotification() {
        // Venue info (CaptivePortalData) with friendly name is not available for API <= R
        assumeTrue(NetworkInformationShimImpl.useApiAboveR())
        mNotifier.notifyCaptivePortalValidationPending(TEST_NETWORK)
        onLinkPropertiesChanged(mTestCapportVenueUrlWithFriendlyNameLp)
        onDefaultNetworkAvailable(TEST_NETWORK)
        val capabilities = NetworkCapabilities(VALIDATED_CAPABILITIES).setSSID(TEST_SSID)
        onCapabilitiesChanged(capabilities)

        mLooper.processAllMessages()

        verifyConnectedNotification(timeout = 0)
        verifyVenueInfoIntent(mIntentCaptor.value)
        verify(mResources).getString(R.string.tap_for_info)
        // verifyConnectedNotification() already captured the posted notification.
        val note = mNoteCaptor.value
        assertEquals(TEST_NETWORK_FRIENDLY_NAME, note.extras
                .getCharSequence(Notification.EXTRA_TITLE))
        verifyCanceledNotificationAfterDefaultNetworkLost()
    }

    /** Asserts that [intent] views [TEST_VENUE_INFO_URL] and carries [TEST_NETWORK]. */
    private fun verifyVenueInfoIntent(intent: Intent) {
        assertEquals(Intent.ACTION_VIEW, intent.action)
        assertEquals(Uri.parse(TEST_VENUE_INFO_URL), intent.data)
        assertEquals<Network?>(TEST_NETWORK, intent.getParcelableExtra(EXTRA_NETWORK))
    }

    /** Asserts that [intent] opens the Wi-Fi settings. */
    private fun verifyWifiSettingsIntent(intent: Intent) {
        assertEquals(Settings.ACTION_WIFI_SETTINGS, intent.action)
    }

    // The helpers below post callback events to the notifier's handler; nothing is delivered
    // until the test drains the looper with mLooper.processAllMessages().

    private fun onDefaultNetworkAvailable(network: Network) {
        mHandler.post {
            mDefaultNetworkCb.onAvailable(network)
        }
    }

    private fun onDefaultNetworkLost(network: Network) {
        mHandler.post {
            mDefaultNetworkCb.onLost(network)
        }
    }

    private fun onCapabilitiesChanged(capabilities: NetworkCapabilities) {
        mHandler.post {
            mAllNetworksCb.onCapabilitiesChanged(TEST_NETWORK, capabilities)
        }
    }

    private fun onLinkPropertiesChanged(lp: LinkProperties) {
        mHandler.post {
            mAllNetworksCb.onLinkPropertiesChanged(TEST_NETWORK, lp)
        }
    }

    private fun onLost(network: Network) {
        mHandler.post {
            mAllNetworksCb.onLost(network)
        }
    }
}