1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.server.companion.transport; 18 19 import static android.Manifest.permission.DELIVER_COMPANION_MESSAGES; 20 21 import static com.android.server.companion.transport.Transport.MESSAGE_REQUEST_PERMISSION_RESTORE; 22 23 import android.annotation.NonNull; 24 import android.annotation.SuppressLint; 25 import android.companion.AssociationInfo; 26 import android.companion.IOnMessageReceivedListener; 27 import android.companion.IOnTransportsChangedListener; 28 import android.content.Context; 29 import android.os.Build; 30 import android.os.ParcelFileDescriptor; 31 import android.os.RemoteCallbackList; 32 import android.os.RemoteException; 33 import android.util.Slog; 34 import android.util.SparseArray; 35 36 import com.android.internal.annotations.GuardedBy; 37 import com.android.server.companion.AssociationStore; 38 39 import java.io.FileDescriptor; 40 import java.io.IOException; 41 import java.nio.charset.StandardCharsets; 42 import java.util.ArrayList; 43 import java.util.List; 44 import java.util.concurrent.CompletableFuture; 45 import java.util.concurrent.Future; 46 47 @SuppressLint("LongLogTag") 48 public class CompanionTransportManager { 49 private static final String TAG = "CDM_CompanionTransportManager"; 50 private static final boolean DEBUG = false; 51 52 private boolean mSecureTransportEnabled = true; 53 54 private final Context mContext; 55 private final AssociationStore mAssociationStore; 56 57 /** Association id -> Transport */ 58 @GuardedBy("mTransports") 59 private final SparseArray<Transport> mTransports = new SparseArray<>(); 60 @NonNull 61 private final RemoteCallbackList<IOnTransportsChangedListener> mTransportsListeners = 62 new RemoteCallbackList<>(); 63 /** Message type -> IOnMessageReceivedListener */ 64 @NonNull 65 private final SparseArray<IOnMessageReceivedListener> mMessageListeners = new SparseArray<>(); 66 CompanionTransportManager(Context context, AssociationStore associationStore)67 public CompanionTransportManager(Context context, AssociationStore associationStore) { 68 mContext = context; 69 mAssociationStore = associationStore; 70 } 71 72 /** 73 * Add a listener to receive callbacks when a message is received for the message type 74 */ addListener(int message, @NonNull IOnMessageReceivedListener listener)75 public void addListener(int message, @NonNull IOnMessageReceivedListener listener) { 76 mMessageListeners.put(message, listener); 77 synchronized (mTransports) { 78 for (int i = 0; i < mTransports.size(); i++) { 79 mTransports.valueAt(i).addListener(message, listener); 80 } 81 } 82 } 83 84 /** 85 * Add a listener to receive callbacks when any of the transports is changed 86 */ addListener(IOnTransportsChangedListener listener)87 public void addListener(IOnTransportsChangedListener listener) { 88 Slog.i(TAG, "Registering OnTransportsChangedListener"); 89 mTransportsListeners.register(listener); 90 List<AssociationInfo> associations = new ArrayList<>(); 91 synchronized (mTransports) { 92 for (int i = 0; i < mTransports.size(); i++) { 93 AssociationInfo association = mAssociationStore.getAssociationById( 94 mTransports.keyAt(i)); 95 if (association != null) { 96 associations.add(association); 97 } 98 } 99 } 100 mTransportsListeners.broadcast(listener1 -> { 101 // callback to the current listener with all the associations of the transports 102 // immediately 103 if (listener1 == listener) { 104 try { 105 listener.onTransportsChanged(associations); 106 } catch (RemoteException ignored) { 107 } 108 } 109 }); 110 } 111 112 /** 113 * Remove the listener for receiving callbacks when any of the transports is changed 114 */ removeListener(IOnTransportsChangedListener listener)115 public void removeListener(IOnTransportsChangedListener listener) { 116 mTransportsListeners.unregister(listener); 117 } 118 119 /** 120 * Remove the listener to stop receiving calbacks when a message is received for the given type 121 */ removeListener(int messageType, IOnMessageReceivedListener listener)122 public void removeListener(int messageType, IOnMessageReceivedListener listener) { 123 mMessageListeners.remove(messageType); 124 } 125 126 /** 127 * Send a message to remote devices through the transports 128 */ sendMessage(int message, byte[] data, int[] associationIds)129 public void sendMessage(int message, byte[] data, int[] associationIds) { 130 Slog.i(TAG, "Sending message 0x" + Integer.toHexString(message) 131 + " data length " + data.length); 132 synchronized (mTransports) { 133 for (int i = 0; i < associationIds.length; i++) { 134 if (mTransports.contains(associationIds[i])) { 135 mTransports.get(associationIds[i]).requestForResponse(message, data); 136 } 137 } 138 } 139 } 140 attachSystemDataTransport(String packageName, int userId, int associationId, ParcelFileDescriptor fd)141 public void attachSystemDataTransport(String packageName, int userId, int associationId, 142 ParcelFileDescriptor fd) { 143 mContext.enforceCallingOrSelfPermission(DELIVER_COMPANION_MESSAGES, TAG); 144 synchronized (mTransports) { 145 if (mTransports.contains(associationId)) { 146 detachSystemDataTransport(packageName, userId, associationId); 147 } 148 149 // TODO: Implement new API to pass a PSK 150 initializeTransport(associationId, fd, null); 151 152 notifyOnTransportsChanged(); 153 } 154 } 155 detachSystemDataTransport(String packageName, int userId, int associationId)156 public void detachSystemDataTransport(String packageName, int userId, int associationId) { 157 mContext.enforceCallingOrSelfPermission(DELIVER_COMPANION_MESSAGES, TAG); 158 synchronized (mTransports) { 159 final Transport transport = mTransports.get(associationId); 160 if (transport != null) { 161 mTransports.delete(associationId); 162 transport.stop(); 163 } 164 165 notifyOnTransportsChanged(); 166 } 167 } 168 notifyOnTransportsChanged()169 private void notifyOnTransportsChanged() { 170 List<AssociationInfo> associations = new ArrayList<>(); 171 synchronized (mTransports) { 172 for (int i = 0; i < mTransports.size(); i++) { 173 AssociationInfo association = mAssociationStore.getAssociationById( 174 mTransports.keyAt(i)); 175 if (association != null) { 176 associations.add(association); 177 } 178 } 179 } 180 mTransportsListeners.broadcast(listener -> { 181 try { 182 listener.onTransportsChanged(associations); 183 } catch (RemoteException ignored) { 184 } 185 }); 186 } 187 initializeTransport(int associationId, ParcelFileDescriptor fd, byte[] preSharedKey)188 private void initializeTransport(int associationId, 189 ParcelFileDescriptor fd, 190 byte[] preSharedKey) { 191 Slog.i(TAG, "Initializing transport"); 192 Transport transport; 193 if (!isSecureTransportEnabled()) { 194 // If secure transport is explicitly disabled for testing, use raw transport 195 Slog.i(TAG, "Secure channel is disabled. Creating raw transport"); 196 transport = new RawTransport(associationId, fd, mContext); 197 } else if (Build.isDebuggable()) { 198 // If device is debug build, use hardcoded test key for authentication 199 Slog.d(TAG, "Creating an unauthenticated secure channel"); 200 final byte[] testKey = "CDM".getBytes(StandardCharsets.UTF_8); 201 transport = new SecureTransport(associationId, fd, mContext, testKey, null); 202 } else if (preSharedKey != null) { 203 // If either device is not Android, then use app-specific pre-shared key 204 Slog.d(TAG, "Creating a PSK-authenticated secure channel"); 205 transport = new SecureTransport(associationId, fd, mContext, preSharedKey, null); 206 } else { 207 // If none of the above applies, then use secure channel with attestation verification 208 Slog.d(TAG, "Creating a secure channel"); 209 transport = new SecureTransport(associationId, fd, mContext); 210 } 211 212 addMessageListenersToTransport(transport); 213 transport.setOnTransportClosedListener(this::detachSystemDataTransport); 214 transport.start(); 215 synchronized (mTransports) { 216 mTransports.put(associationId, transport); 217 } 218 219 } 220 requestPermissionRestore(int associationId, byte[] data)221 public Future<?> requestPermissionRestore(int associationId, byte[] data) { 222 synchronized (mTransports) { 223 final Transport transport = mTransports.get(associationId); 224 if (transport == null) { 225 return CompletableFuture.failedFuture(new IOException("Missing transport")); 226 } 227 return transport.requestForResponse(MESSAGE_REQUEST_PERMISSION_RESTORE, data); 228 } 229 } 230 231 /** 232 * @hide 233 */ enableSecureTransport(boolean enabled)234 public void enableSecureTransport(boolean enabled) { 235 this.mSecureTransportEnabled = enabled; 236 } 237 238 /** 239 * For testing purpose only. 240 * 241 * Create an emulated RawTransport and notify onTransportChanged listeners. 242 */ createEmulatedTransport(int associationId)243 public EmulatedTransport createEmulatedTransport(int associationId) { 244 synchronized (mTransports) { 245 FileDescriptor fd = new FileDescriptor(); 246 ParcelFileDescriptor pfd = new ParcelFileDescriptor(fd); 247 EmulatedTransport transport = new EmulatedTransport(associationId, pfd, mContext); 248 addMessageListenersToTransport(transport); 249 mTransports.put(associationId, transport); 250 notifyOnTransportsChanged(); 251 return transport; 252 } 253 } 254 255 /** 256 * For testing purposes only. 257 * 258 * Emulates a transport for incoming messages but black-holes all messages sent back through it. 259 */ 260 public static class EmulatedTransport extends RawTransport { 261 EmulatedTransport(int associationId, ParcelFileDescriptor fd, Context context)262 EmulatedTransport(int associationId, ParcelFileDescriptor fd, Context context) { 263 super(associationId, fd, context); 264 } 265 266 /** Process an incoming message for testing purposes. */ processMessage(int messageType, int sequence, byte[] data)267 public void processMessage(int messageType, int sequence, byte[] data) throws IOException { 268 handleMessage(messageType, sequence, data); 269 } 270 271 @Override sendMessage(int messageType, int sequence, @NonNull byte[] data)272 protected void sendMessage(int messageType, int sequence, @NonNull byte[] data) 273 throws IOException { 274 Slog.e(TAG, "Black-holing emulated message type 0x" + Integer.toHexString(messageType) 275 + " sequence " + sequence + " length " + data.length 276 + " to association " + mAssociationId); 277 } 278 } 279 isSecureTransportEnabled()280 private boolean isSecureTransportEnabled() { 281 boolean enabled = !Build.IS_DEBUGGABLE || mSecureTransportEnabled; 282 283 return enabled; 284 } 285 addMessageListenersToTransport(Transport transport)286 private void addMessageListenersToTransport(Transport transport) { 287 for (int i = 0; i < mMessageListeners.size(); i++) { 288 transport.addListener(mMessageListeners.keyAt(i), mMessageListeners.valueAt(i)); 289 } 290 } 291 detachSystemDataTransport(Transport transport)292 void detachSystemDataTransport(Transport transport) { 293 int associationId = transport.mAssociationId; 294 AssociationInfo association = mAssociationStore.getAssociationById(associationId); 295 if (association != null) { 296 detachSystemDataTransport(association.getPackageName(), 297 association.getUserId(), 298 association.getId()); 299 } 300 } 301 } 302