1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.pm;
18 
19 import android.app.ActivityManager;
20 import android.app.IActivityManager;
21 import android.app.IUserSwitchObserver;
22 import android.app.UserSwitchObserver;
23 import android.os.RemoteException;
24 import android.util.Log;
25 
26 import com.android.internal.util.FunctionalUtils;
27 
28 import java.io.Closeable;
29 import java.io.IOException;
30 import java.util.Map;
31 import java.util.concurrent.ConcurrentHashMap;
32 import java.util.concurrent.Semaphore;
33 import java.util.concurrent.TimeUnit;
34 
35 public class UserSwitchWaiter implements Closeable {
36 
37     private final String mTag;
38     private final int mTimeoutInSecond;
39     private final IActivityManager mActivityManager;
40     private final IUserSwitchObserver mUserSwitchObserver = new UserSwitchObserver() {
41         @Override
42         public void onUserSwitchComplete(int newUserId) {
43             getSemaphoreSwitchComplete(newUserId).release();
44         }
45 
46         @Override
47         public void onLockedBootComplete(int newUserId) {
48             getSemaphoreBootComplete(newUserId).release();
49         }
50     };
51 
52     private final Map<Integer, Semaphore> mSemaphoresMapSwitchComplete = new ConcurrentHashMap<>();
getSemaphoreSwitchComplete(final int userId)53     private Semaphore getSemaphoreSwitchComplete(final int userId) {
54         return mSemaphoresMapSwitchComplete.computeIfAbsent(userId,
55                 (Integer absentKey) -> new Semaphore(0));
56     }
57 
58     private final Map<Integer, Semaphore> mSemaphoresMapBootComplete = new ConcurrentHashMap<>();
getSemaphoreBootComplete(final int userId)59     private Semaphore getSemaphoreBootComplete(final int userId) {
60         return mSemaphoresMapBootComplete.computeIfAbsent(userId,
61                 (Integer absentKey) -> new Semaphore(0));
62     }
63 
UserSwitchWaiter(String tag, int timeoutInSecond)64     public UserSwitchWaiter(String tag, int timeoutInSecond) throws RemoteException {
65         mTag = tag;
66         mTimeoutInSecond = timeoutInSecond;
67         mActivityManager = ActivityManager.getService();
68 
69         mActivityManager.registerUserSwitchObserver(mUserSwitchObserver, mTag);
70     }
71 
72     @Override
close()73     public void close() throws IOException {
74         try {
75             mActivityManager.unregisterUserSwitchObserver(mUserSwitchObserver);
76         } catch (RemoteException e) {
77             Log.e(mTag, "Failed to unregister user switch observer", e);
78         }
79     }
80 
runThenWaitUntilSwitchCompleted(int userId, FunctionalUtils.ThrowingRunnable runnable, Runnable onFail)81     public void runThenWaitUntilSwitchCompleted(int userId,
82             FunctionalUtils.ThrowingRunnable runnable, Runnable onFail) {
83         final Semaphore semaphore = getSemaphoreSwitchComplete(userId);
84         semaphore.drainPermits();
85         runnable.run();
86         waitForSemaphore(semaphore, onFail);
87     }
88 
runThenWaitUntilBootCompleted(int userId, FunctionalUtils.ThrowingRunnable runnable, Runnable onFail)89     public void runThenWaitUntilBootCompleted(int userId,
90             FunctionalUtils.ThrowingRunnable runnable, Runnable onFail) {
91         final Semaphore semaphore = getSemaphoreBootComplete(userId);
92         semaphore.drainPermits();
93         runnable.run();
94         waitForSemaphore(semaphore, onFail);
95     }
96 
waitForSemaphore(Semaphore semaphore, Runnable onFail)97     private void waitForSemaphore(Semaphore semaphore, Runnable onFail) {
98         boolean success = false;
99         try {
100             success = semaphore.tryAcquire(mTimeoutInSecond, TimeUnit.SECONDS);
101         } catch (InterruptedException e) {
102             Log.e(mTag, "Thread interrupted unexpectedly.", e);
103         }
104         if (!success && onFail != null) {
105             onFail.run();
106         }
107     }
108 }
109