1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package android.ext.services.resolver; 18 19 import android.content.Context; 20 import android.content.Intent; 21 import android.content.SharedPreferences; 22 import android.os.Environment; 23 import android.os.IBinder; 24 import android.os.UserManager; 25 import android.os.storage.StorageManager; 26 import android.service.resolver.ResolverRankerService; 27 import android.service.resolver.ResolverTarget; 28 import android.util.ArrayMap; 29 import android.util.Log; 30 31 import androidx.annotation.VisibleForTesting; 32 import androidx.core.util.Preconditions; 33 34 import java.io.File; 35 import java.util.Collection; 36 import java.util.List; 37 import java.util.Map; 38 39 /** 40 * A Logistic Regression based {@link android.service.resolver.ResolverRankerService}, to be used 41 * in {@link ResolverComparator}. 42 */ 43 public final class LRResolverRankerService extends ResolverRankerService { 44 private static final String TAG = "LRResolverRankerService"; 45 46 private static final boolean DEBUG = false; 47 48 private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params"; 49 private static final String BIAS_PREF_KEY = "bias"; 50 private static final String VERSION_PREF_KEY = "version"; 51 52 private static final String LAUNCH_SCORE = "launch"; 53 private static final String TIME_SPENT_SCORE = "timeSpent"; 54 private static final String RECENCY_SCORE = "recency"; 55 private static final String CHOOSER_SCORE = "chooser"; 56 57 // parameters for a pre-trained model, to initialize the app ranker. When updating the 58 // pre-trained model, please update these params, as well as initModel(). 59 private static final int CURRENT_VERSION = 1; 60 private static final float LEARNING_RATE = 0.0001f; 61 private static final float REGULARIZER_PARAM = 0.0001f; 62 63 private SharedPreferences mParamSharedPref; 64 private float mBias; 65 private boolean mInitModelDone; 66 67 @VisibleForTesting 68 ArrayMap<String, Float> mFeatureWeights; 69 70 @Override onBind(Intent intent)71 public IBinder onBind(Intent intent) { 72 initModel(); 73 return super.onBind(intent); 74 } 75 76 @Override onPredictSharingProbabilities(List<ResolverTarget> targets)77 public void onPredictSharingProbabilities(List<ResolverTarget> targets) { 78 Preconditions.checkState(initModel(), "Service is not ready yet"); 79 80 final int size = targets.size(); 81 for (int i = 0; i < size; ++i) { 82 ResolverTarget target = targets.get(i); 83 ArrayMap<String, Float> features = getFeatures(target); 84 target.setSelectProbability(predict(features)); 85 } 86 } 87 88 @Override onTrainRankingModel(List<ResolverTarget> targets, int selectedPosition)89 public void onTrainRankingModel(List<ResolverTarget> targets, int selectedPosition) { 90 Preconditions.checkState(initModel(), "Service is not ready yet"); 91 92 final int size = targets.size(); 93 if (selectedPosition < 0 || selectedPosition >= size) { 94 if (DEBUG) { 95 Log.d(TAG, "Invalid Position of Selected App " + selectedPosition); 96 } 97 return; 98 } 99 final ArrayMap<String, Float> positive = getFeatures(targets.get(selectedPosition)); 100 final float positiveProbability = targets.get(selectedPosition).getSelectProbability(); 101 final int targetSize = targets.size(); 102 for (int i = 0; i < targetSize; ++i) { 103 if (i == selectedPosition) { 104 continue; 105 } 106 final ArrayMap<String, Float> negative = getFeatures(targets.get(i)); 107 final float negativeProbability = targets.get(i).getSelectProbability(); 108 if (negativeProbability > positiveProbability) { 109 update(negative, negativeProbability, false); 110 update(positive, positiveProbability, true); 111 } 112 } 113 commitUpdate(); 114 } 115 116 // This is not thread safe, but ResolverRankerService has added the protection to call into it 117 // in the same Handler. initModel()118 private boolean initModel() { 119 if (mInitModelDone) { 120 return true; 121 } 122 final UserManager userManager = (UserManager) getSystemService(Context.USER_SERVICE); 123 if (userManager == null || !userManager.isUserUnlocked()) { 124 return false; 125 } 126 mParamSharedPref = getParamSharedPref(); 127 mFeatureWeights = new ArrayMap<>(4); 128 if (mParamSharedPref == null || 129 mParamSharedPref.getInt(VERSION_PREF_KEY, 0) < CURRENT_VERSION) { 130 // Initializing the app ranker to a pre-trained model. When updating the pre-trained 131 // model, please increment CURRENT_VERSION, and update LEARNING_RATE and 132 // REGULARIZER_PARAM. 133 mBias = -1.6568f; 134 mFeatureWeights.put(LAUNCH_SCORE, 2.5543f); 135 mFeatureWeights.put(TIME_SPENT_SCORE, 2.8412f); 136 mFeatureWeights.put(RECENCY_SCORE, 0.269f); 137 mFeatureWeights.put(CHOOSER_SCORE, 4.2222f); 138 } else { 139 mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f); 140 mFeatureWeights.put(LAUNCH_SCORE, mParamSharedPref.getFloat(LAUNCH_SCORE, 0.0f)); 141 mFeatureWeights.put( 142 TIME_SPENT_SCORE, mParamSharedPref.getFloat(TIME_SPENT_SCORE, 0.0f)); 143 mFeatureWeights.put(RECENCY_SCORE, mParamSharedPref.getFloat(RECENCY_SCORE, 0.0f)); 144 mFeatureWeights.put(CHOOSER_SCORE, mParamSharedPref.getFloat(CHOOSER_SCORE, 0.0f)); 145 } 146 mInitModelDone = true; 147 return true; 148 } 149 getFeatures(ResolverTarget target)150 private ArrayMap<String, Float> getFeatures(ResolverTarget target) { 151 ArrayMap<String, Float> features = new ArrayMap<>(4); 152 features.put(RECENCY_SCORE, target.getRecencyScore()); 153 features.put(TIME_SPENT_SCORE, target.getTimeSpentScore()); 154 features.put(LAUNCH_SCORE, target.getLaunchScore()); 155 features.put(CHOOSER_SCORE, target.getChooserScore()); 156 return features; 157 } 158 predict(ArrayMap<String, Float> target)159 private float predict(ArrayMap<String, Float> target) { 160 if (target == null) { 161 return 0.0f; 162 } 163 final int featureSize = target.size(); 164 float sum = 0.0f; 165 for (int i = 0; i < featureSize; i++) { 166 String featureName = target.keyAt(i); 167 float weight = mFeatureWeights.getOrDefault(featureName, 0.0f); 168 sum += weight * target.valueAt(i); 169 } 170 return (float) (1.0 / (1.0 + Math.exp(-mBias - sum))); 171 } 172 update(ArrayMap<String, Float> target, float predict, boolean isSelected)173 private void update(ArrayMap<String, Float> target, float predict, boolean isSelected) { 174 if (target == null) { 175 return; 176 } 177 final int featureSize = target.size(); 178 float error = isSelected ? 1.0f - predict : -predict; 179 for (int i = 0; i < featureSize; i++) { 180 String featureName = target.keyAt(i); 181 float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f); 182 mBias += LEARNING_RATE * error; 183 currentWeight = currentWeight - LEARNING_RATE * REGULARIZER_PARAM * currentWeight + 184 LEARNING_RATE * error * target.valueAt(i); 185 mFeatureWeights.put(featureName, currentWeight); 186 } 187 if (DEBUG) { 188 Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias); 189 } 190 } 191 commitUpdate()192 private void commitUpdate() { 193 try { 194 SharedPreferences.Editor editor = mParamSharedPref.edit(); 195 editor.putFloat(BIAS_PREF_KEY, mBias); 196 final int size = mFeatureWeights.size(); 197 for (int i = 0; i < size; i++) { 198 editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i)); 199 } 200 editor.putInt(VERSION_PREF_KEY, CURRENT_VERSION); 201 editor.apply(); 202 } catch (Exception e) { 203 Log.e(TAG, "Failed to commit update" + e); 204 } 205 } 206 getParamSharedPref()207 private SharedPreferences getParamSharedPref() { 208 // NOTE: EXtServices sets android:defaultToDeviceProtectedStorage="true" so we need this 209 // to make sure we're upgrading these preferences correctly. 210 return createCredentialProtectedStorageContext() 211 .getSharedPreferences(PARAM_SHARED_PREF_NAME + ".xml", Context.MODE_PRIVATE); 212 } 213 }