1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.flags;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static com.google.common.truth.Truth.assertWithMessage;
21 
22 import static org.junit.Assert.fail;
23 import static org.mockito.Mockito.any;
24 import static org.mockito.Mockito.doThrow;
25 import static org.mockito.Mockito.eq;
26 import static org.mockito.Mockito.verify;
27 import static org.mockito.Mockito.when;
28 
29 import android.flags.IFeatureFlagsCallback;
30 import android.flags.SyncableFlag;
31 import android.os.IBinder;
32 import android.os.RemoteException;
33 import android.platform.test.annotations.Presubmit;
34 
35 import androidx.test.filters.SmallTest;
36 
37 import org.junit.Before;
38 import org.junit.Rule;
39 import org.junit.Test;
40 import org.mockito.Mock;
41 import org.mockito.junit.MockitoJUnit;
42 import org.mockito.junit.MockitoRule;
43 
44 import java.util.List;
45 
46 @Presubmit
47 @SmallTest
48 public class FeatureFlagsServiceTest {
49     private static final String NS = "ns";
50     private static final String NAME = "name";
51     private static final String PROP_NAME = FlagOverrideStore.getPropName(NS, NAME);
52 
53     @Rule
54     public final MockitoRule mockito = MockitoJUnit.rule();
55 
56     @Mock
57     private FlagOverrideStore mFlagStore;
58     @Mock
59     private FlagsShellCommand mFlagCommand;
60     @Mock
61     private IFeatureFlagsCallback mIFeatureFlagsCallback;
62     @Mock
63     private IBinder mIFeatureFlagsCallbackAsBinder;
64     @Mock
65     private FeatureFlagsService.PermissionsChecker mPermissionsChecker;
66 
67     private FeatureFlagsBinder mFeatureFlagsService;
68 
69     @Before
setup()70     public void setup() {
71         when(mIFeatureFlagsCallback.asBinder()).thenReturn(mIFeatureFlagsCallbackAsBinder);
72         mFeatureFlagsService = new FeatureFlagsBinder(
73                 mFlagStore, mFlagCommand, mPermissionsChecker);
74     }
75 
76     @Test
testRegisterCallback()77     public void testRegisterCallback() {
78         mFeatureFlagsService.registerCallback(mIFeatureFlagsCallback);
79         try {
80             verify(mIFeatureFlagsCallbackAsBinder).linkToDeath(any(), eq(0));
81         } catch (RemoteException e) {
82             fail("Our mock threw a Remote Exception?");
83         }
84     }
85 
86     @Test
testOverrideFlag_requiresWritePermission()87     public void testOverrideFlag_requiresWritePermission() {
88         SecurityException exc = new SecurityException("not allowed");
89         doThrow(exc).when(mPermissionsChecker).assertWritePermission();
90 
91         SyncableFlag f = new SyncableFlag(NS, "a", "false", false);
92 
93         try {
94             mFeatureFlagsService.overrideFlag(f);
95             fail("Should have thrown exception");
96         } catch (SecurityException e) {
97             assertThat(exc).isEqualTo(e);
98         } catch (Exception e) {
99             fail("should have thrown a security exception");
100         }
101     }
102 
103     @Test
testResetFlag_requiresWritePermission()104     public void testResetFlag_requiresWritePermission() {
105         SecurityException exc = new SecurityException("not allowed");
106         doThrow(exc).when(mPermissionsChecker).assertWritePermission();
107 
108         SyncableFlag f = new SyncableFlag(NS, "a", "false", false);
109 
110         try {
111             mFeatureFlagsService.resetFlag(f);
112             fail("Should have thrown exception");
113         } catch (SecurityException e) {
114             assertThat(exc).isEqualTo(e);
115         } catch (Exception e) {
116             fail("should have thrown a security exception");
117         }
118     }
119 
120     @Test
testSyncFlags_noOverrides()121     public void testSyncFlags_noOverrides() {
122         List<SyncableFlag> inputFlags = List.of(
123                 new SyncableFlag(NS, "a", "false", false),
124                 new SyncableFlag(NS, "b", "true", false),
125                 new SyncableFlag(NS, "c", "false", false)
126         );
127 
128         List<SyncableFlag> outputFlags = mFeatureFlagsService.syncFlags(inputFlags);
129 
130         assertThat(inputFlags.size()).isEqualTo(outputFlags.size());
131 
132         for (SyncableFlag inpF: inputFlags) {
133             boolean found = false;
134             for (SyncableFlag outF : outputFlags) {
135                 if (compareSyncableFlagsNames(inpF, outF)) {
136                     found = true;
137                     break;
138                 }
139             }
140             assertWithMessage("Failed to find input flag " + inpF + " in the output")
141                     .that(found).isTrue();
142         }
143     }
144 
145     @Test
testSyncFlags_withSomeOverrides()146     public void testSyncFlags_withSomeOverrides() {
147         List<SyncableFlag> inputFlags = List.of(
148                 new SyncableFlag(NS, "a", "false", false),
149                 new SyncableFlag(NS, "b", "true", false),
150                 new SyncableFlag(NS, "c", "false", false)
151         );
152 
153         assertThat(mFlagStore).isNotNull();
154         when(mFlagStore.get(NS, "c")).thenReturn("true");
155         List<SyncableFlag> outputFlags = mFeatureFlagsService.syncFlags(inputFlags);
156 
157         assertThat(inputFlags.size()).isEqualTo(outputFlags.size());
158 
159         for (SyncableFlag inpF: inputFlags) {
160             boolean found = false;
161             for (SyncableFlag outF : outputFlags) {
162                 if (compareSyncableFlagsNames(inpF, outF)) {
163                     found = true;
164 
165                     // Once we've found "c", do an extra check
166                     if (outF.getName().equals("c")) {
167                         assertWithMessage("Flag " + outF + "was not returned with an override")
168                                 .that(outF.getValue()).isEqualTo("true");
169                     }
170                     break;
171                 }
172             }
173             assertWithMessage("Failed to find input flag " + inpF + " in the output")
174                     .that(found).isTrue();
175         }
176     }
177 
178     @Test
testSyncFlags_twoCallsWithDifferentDefaults()179     public void testSyncFlags_twoCallsWithDifferentDefaults() {
180         List<SyncableFlag> inputFlagsFirst = List.of(
181                 new SyncableFlag(NS, "a", "false", false)
182         );
183         List<SyncableFlag> inputFlagsSecond = List.of(
184                 new SyncableFlag(NS, "a", "true", false),
185                 new SyncableFlag(NS, "b", "false", false)
186         );
187 
188         List<SyncableFlag> outputFlagsFirst = mFeatureFlagsService.syncFlags(inputFlagsFirst);
189         List<SyncableFlag> outputFlagsSecond = mFeatureFlagsService.syncFlags(inputFlagsSecond);
190 
191         assertThat(inputFlagsFirst.size()).isEqualTo(outputFlagsFirst.size());
192         assertThat(inputFlagsSecond.size()).isEqualTo(outputFlagsSecond.size());
193 
194         // This test only cares that the "a" flag passed in the second time came out with the
195         // same value that was passed in the first time.
196 
197         boolean found = false;
198         for (SyncableFlag second : outputFlagsSecond) {
199             if (compareSyncableFlagsNames(second, inputFlagsFirst.get(0))) {
200                 found = true;
201                 assertThat(second.getValue()).isEqualTo(inputFlagsFirst.get(0).getValue());
202                 break;
203             }
204         }
205 
206         assertWithMessage(
207                 "Failed to find flag " + inputFlagsFirst.get(0) + " in the second calls output")
208                 .that(found).isTrue();
209     }
210 
211     @Test
testQueryFlags_onlyOnce()212     public void testQueryFlags_onlyOnce() {
213         List<SyncableFlag> inputFlags = List.of(
214                 new SyncableFlag(NS, "a", "false", false),
215                 new SyncableFlag(NS, "b", "true", false),
216                 new SyncableFlag(NS, "c", "false", false)
217         );
218 
219         List<SyncableFlag> outputFlags = mFeatureFlagsService.queryFlags(inputFlags);
220 
221         assertThat(inputFlags.size()).isEqualTo(outputFlags.size());
222 
223         for (SyncableFlag inpF: inputFlags) {
224             boolean found = false;
225             for (SyncableFlag outF : outputFlags) {
226                 if (compareSyncableFlagsNames(inpF, outF)) {
227                     found = true;
228                     break;
229                 }
230             }
231             assertWithMessage("Failed to find input flag " + inpF + " in the output")
232                     .that(found).isTrue();
233         }
234     }
235 
236     @Test
testQueryFlags_twoCallsWithDifferentDefaults()237     public void testQueryFlags_twoCallsWithDifferentDefaults() {
238         List<SyncableFlag> inputFlagsFirst = List.of(
239                 new SyncableFlag(NS, "a", "false", false)
240         );
241         List<SyncableFlag> inputFlagsSecond = List.of(
242                 new SyncableFlag(NS, "a", "true", false),
243                 new SyncableFlag(NS, "b", "false", false)
244         );
245 
246         List<SyncableFlag> outputFlagsFirst = mFeatureFlagsService.queryFlags(inputFlagsFirst);
247         List<SyncableFlag> outputFlagsSecond = mFeatureFlagsService.queryFlags(inputFlagsSecond);
248 
249         assertThat(inputFlagsFirst.size()).isEqualTo(outputFlagsFirst.size());
250         assertThat(inputFlagsSecond.size()).isEqualTo(outputFlagsSecond.size());
251 
252         // This test only cares that the "a" flag passed in the second time came out with the
253         // same value that was passed in (i.e. it wasn't cached).
254 
255         boolean found = false;
256         for (SyncableFlag second : outputFlagsSecond) {
257             if (compareSyncableFlagsNames(second, inputFlagsSecond.get(0))) {
258                 found = true;
259                 assertThat(second.getValue()).isEqualTo(inputFlagsSecond.get(0).getValue());
260                 break;
261             }
262         }
263 
264         assertWithMessage(
265                 "Failed to find flag " + inputFlagsSecond.get(0) + " in the second calls output")
266                 .that(found).isTrue();
267     }
268 
269     @Test
testOverrideFlag()270     public void testOverrideFlag() {
271         SyncableFlag f = new SyncableFlag(NS, "a", "false", false);
272 
273         mFeatureFlagsService.overrideFlag(f);
274 
275         verify(mFlagStore).set(f.getNamespace(), f.getName(), f.getValue());
276     }
277 
278     @Test
testResetFlag()279     public void testResetFlag() {
280         SyncableFlag f = new SyncableFlag(NS, "a", "false", false);
281 
282         mFeatureFlagsService.resetFlag(f);
283 
284         verify(mFlagStore).erase(f.getNamespace(), f.getName());
285     }
286 
287 
compareSyncableFlagsNames(SyncableFlag a, SyncableFlag b)288     private static boolean compareSyncableFlagsNames(SyncableFlag a, SyncableFlag b) {
289         return a.getNamespace().equals(b.getNamespace())
290                 && a.getName().equals(b.getName())
291                 && a.isDynamic() == b.isDynamic();
292     }
293 }
294