1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.companion.transport;
18 
19 import static android.Manifest.permission.DELIVER_COMPANION_MESSAGES;
20 
21 import static com.android.server.companion.transport.Transport.MESSAGE_REQUEST_PERMISSION_RESTORE;
22 
23 import android.annotation.NonNull;
24 import android.annotation.SuppressLint;
25 import android.companion.AssociationInfo;
26 import android.companion.IOnMessageReceivedListener;
27 import android.companion.IOnTransportsChangedListener;
28 import android.content.Context;
29 import android.os.Build;
30 import android.os.ParcelFileDescriptor;
31 import android.os.RemoteCallbackList;
32 import android.os.RemoteException;
33 import android.util.Slog;
34 import android.util.SparseArray;
35 
36 import com.android.internal.annotations.GuardedBy;
37 import com.android.server.companion.AssociationStore;
38 
39 import java.io.FileDescriptor;
40 import java.io.IOException;
41 import java.nio.charset.StandardCharsets;
42 import java.util.ArrayList;
43 import java.util.List;
44 import java.util.concurrent.CompletableFuture;
45 import java.util.concurrent.Future;
46 
47 @SuppressLint("LongLogTag")
48 public class CompanionTransportManager {
49     private static final String TAG = "CDM_CompanionTransportManager";
50     private static final boolean DEBUG = false;
51 
52     private boolean mSecureTransportEnabled = true;
53 
54     private final Context mContext;
55     private final AssociationStore mAssociationStore;
56 
57     /** Association id -> Transport */
58     @GuardedBy("mTransports")
59     private final SparseArray<Transport> mTransports = new SparseArray<>();
60     @NonNull
61     private final RemoteCallbackList<IOnTransportsChangedListener> mTransportsListeners =
62             new RemoteCallbackList<>();
63     /** Message type -> IOnMessageReceivedListener */
64     @NonNull
65     private final SparseArray<IOnMessageReceivedListener> mMessageListeners = new SparseArray<>();
66 
CompanionTransportManager(Context context, AssociationStore associationStore)67     public CompanionTransportManager(Context context, AssociationStore associationStore) {
68         mContext = context;
69         mAssociationStore = associationStore;
70     }
71 
72     /**
73      * Add a listener to receive callbacks when a message is received for the message type
74      */
addListener(int message, @NonNull IOnMessageReceivedListener listener)75     public void addListener(int message, @NonNull IOnMessageReceivedListener listener) {
76         mMessageListeners.put(message, listener);
77         synchronized (mTransports) {
78             for (int i = 0; i < mTransports.size(); i++) {
79                 mTransports.valueAt(i).addListener(message, listener);
80             }
81         }
82     }
83 
84     /**
85      * Add a listener to receive callbacks when any of the transports is changed
86      */
addListener(IOnTransportsChangedListener listener)87     public void addListener(IOnTransportsChangedListener listener) {
88         Slog.i(TAG, "Registering OnTransportsChangedListener");
89         mTransportsListeners.register(listener);
90         List<AssociationInfo> associations = new ArrayList<>();
91         synchronized (mTransports) {
92             for (int i = 0; i < mTransports.size(); i++) {
93                 AssociationInfo association = mAssociationStore.getAssociationById(
94                         mTransports.keyAt(i));
95                 if (association != null) {
96                     associations.add(association);
97                 }
98             }
99         }
100         mTransportsListeners.broadcast(listener1 -> {
101             // callback to the current listener with all the associations of the transports
102             // immediately
103             if (listener1 == listener) {
104                 try {
105                     listener.onTransportsChanged(associations);
106                 } catch (RemoteException ignored) {
107                 }
108             }
109         });
110     }
111 
112     /**
113      * Remove the listener for receiving callbacks when any of the transports is changed
114      */
removeListener(IOnTransportsChangedListener listener)115     public void removeListener(IOnTransportsChangedListener listener) {
116         mTransportsListeners.unregister(listener);
117     }
118 
119     /**
120      * Remove the listener to stop receiving calbacks when a message is received for the given type
121      */
removeListener(int messageType, IOnMessageReceivedListener listener)122     public void removeListener(int messageType, IOnMessageReceivedListener listener) {
123         mMessageListeners.remove(messageType);
124     }
125 
126     /**
127      * Send a message to remote devices through the transports
128      */
sendMessage(int message, byte[] data, int[] associationIds)129     public void sendMessage(int message, byte[] data, int[] associationIds) {
130         Slog.i(TAG, "Sending message 0x" + Integer.toHexString(message)
131                 + " data length " + data.length);
132         synchronized (mTransports) {
133             for (int i = 0; i < associationIds.length; i++) {
134                 if (mTransports.contains(associationIds[i])) {
135                     mTransports.get(associationIds[i]).requestForResponse(message, data);
136                 }
137             }
138         }
139     }
140 
attachSystemDataTransport(String packageName, int userId, int associationId, ParcelFileDescriptor fd)141     public void attachSystemDataTransport(String packageName, int userId, int associationId,
142             ParcelFileDescriptor fd) {
143         mContext.enforceCallingOrSelfPermission(DELIVER_COMPANION_MESSAGES, TAG);
144         synchronized (mTransports) {
145             if (mTransports.contains(associationId)) {
146                 detachSystemDataTransport(packageName, userId, associationId);
147             }
148 
149             // TODO: Implement new API to pass a PSK
150             initializeTransport(associationId, fd, null);
151 
152             notifyOnTransportsChanged();
153         }
154     }
155 
detachSystemDataTransport(String packageName, int userId, int associationId)156     public void detachSystemDataTransport(String packageName, int userId, int associationId) {
157         mContext.enforceCallingOrSelfPermission(DELIVER_COMPANION_MESSAGES, TAG);
158         synchronized (mTransports) {
159             final Transport transport = mTransports.get(associationId);
160             if (transport != null) {
161                 mTransports.delete(associationId);
162                 transport.stop();
163             }
164 
165             notifyOnTransportsChanged();
166         }
167     }
168 
notifyOnTransportsChanged()169     private void notifyOnTransportsChanged() {
170         List<AssociationInfo> associations = new ArrayList<>();
171         synchronized (mTransports) {
172             for (int i = 0; i < mTransports.size(); i++) {
173                 AssociationInfo association = mAssociationStore.getAssociationById(
174                         mTransports.keyAt(i));
175                 if (association != null) {
176                     associations.add(association);
177                 }
178             }
179         }
180         mTransportsListeners.broadcast(listener -> {
181             try {
182                 listener.onTransportsChanged(associations);
183             } catch (RemoteException ignored) {
184             }
185         });
186     }
187 
initializeTransport(int associationId, ParcelFileDescriptor fd, byte[] preSharedKey)188     private void initializeTransport(int associationId,
189                                      ParcelFileDescriptor fd,
190                                      byte[] preSharedKey) {
191         Slog.i(TAG, "Initializing transport");
192         Transport transport;
193         if (!isSecureTransportEnabled()) {
194             // If secure transport is explicitly disabled for testing, use raw transport
195             Slog.i(TAG, "Secure channel is disabled. Creating raw transport");
196             transport = new RawTransport(associationId, fd, mContext);
197         } else if (Build.isDebuggable()) {
198             // If device is debug build, use hardcoded test key for authentication
199             Slog.d(TAG, "Creating an unauthenticated secure channel");
200             final byte[] testKey = "CDM".getBytes(StandardCharsets.UTF_8);
201             transport = new SecureTransport(associationId, fd, mContext, testKey, null);
202         } else if (preSharedKey != null) {
203             // If either device is not Android, then use app-specific pre-shared key
204             Slog.d(TAG, "Creating a PSK-authenticated secure channel");
205             transport = new SecureTransport(associationId, fd, mContext, preSharedKey, null);
206         } else {
207             // If none of the above applies, then use secure channel with attestation verification
208             Slog.d(TAG, "Creating a secure channel");
209             transport = new SecureTransport(associationId, fd, mContext);
210         }
211 
212         addMessageListenersToTransport(transport);
213         transport.setOnTransportClosedListener(this::detachSystemDataTransport);
214         transport.start();
215         synchronized (mTransports) {
216             mTransports.put(associationId, transport);
217         }
218 
219     }
220 
requestPermissionRestore(int associationId, byte[] data)221     public Future<?> requestPermissionRestore(int associationId, byte[] data) {
222         synchronized (mTransports) {
223             final Transport transport = mTransports.get(associationId);
224             if (transport == null) {
225                 return CompletableFuture.failedFuture(new IOException("Missing transport"));
226             }
227             return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data);
228         }
229     }
230 
231     /**
232      * @hide
233      */
enableSecureTransport(boolean enabled)234     public void enableSecureTransport(boolean enabled) {
235         this.mSecureTransportEnabled = enabled;
236     }
237 
238     /**
239      * For testing purpose only.
240      *
241      * Create an emulated RawTransport and notify onTransportChanged listeners.
242      */
createEmulatedTransport(int associationId)243     public EmulatedTransport createEmulatedTransport(int associationId) {
244         synchronized (mTransports) {
245             FileDescriptor fd = new FileDescriptor();
246             ParcelFileDescriptor pfd = new ParcelFileDescriptor(fd);
247             EmulatedTransport transport = new EmulatedTransport(associationId, pfd, mContext);
248             addMessageListenersToTransport(transport);
249             mTransports.put(associationId, transport);
250             notifyOnTransportsChanged();
251             return transport;
252         }
253     }
254 
255     /**
256      * For testing purposes only.
257      *
258      * Emulates a transport for incoming messages but black-holes all messages sent back through it.
259      */
260     public static class EmulatedTransport extends RawTransport {
261 
EmulatedTransport(int associationId, ParcelFileDescriptor fd, Context context)262         EmulatedTransport(int associationId, ParcelFileDescriptor fd, Context context) {
263             super(associationId, fd, context);
264         }
265 
266         /** Process an incoming message for testing purposes. */
processMessage(int messageType, int sequence, byte[] data)267         public void processMessage(int messageType, int sequence, byte[] data) throws IOException {
268             handleMessage(messageType, sequence, data);
269         }
270 
271         @Override
sendMessage(int messageType, int sequence, @NonNull byte[] data)272         protected void sendMessage(int messageType, int sequence, @NonNull byte[] data)
273                 throws IOException {
274             Slog.e(TAG, "Black-holing emulated message type 0x" + Integer.toHexString(messageType)
275                     + " sequence " + sequence + " length " + data.length
276                     + " to association " + mAssociationId);
277         }
278     }
279 
isSecureTransportEnabled()280     private boolean isSecureTransportEnabled() {
281         boolean enabled = !Build.IS_DEBUGGABLE || mSecureTransportEnabled;
282 
283         return enabled;
284     }
285 
addMessageListenersToTransport(Transport transport)286     private void addMessageListenersToTransport(Transport transport) {
287         for (int i = 0; i < mMessageListeners.size(); i++) {
288             transport.addListener(mMessageListeners.keyAt(i), mMessageListeners.valueAt(i));
289         }
290     }
291 
detachSystemDataTransport(Transport transport)292     void detachSystemDataTransport(Transport transport) {
293         int associationId = transport.mAssociationId;
294         AssociationInfo association = mAssociationStore.getAssociationById(associationId);
295         if (association != null) {
296             detachSystemDataTransport(association.getPackageName(),
297                     association.getUserId(),
298                     association.getId());
299         }
300     }
301 }
302