1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License 15 */ 16 17 package com.android.internal.ml.clustering; 18 19 import android.annotation.NonNull; 20 import android.util.Log; 21 22 import com.android.internal.annotations.VisibleForTesting; 23 24 import java.util.ArrayList; 25 import java.util.Arrays; 26 import java.util.List; 27 import java.util.Random; 28 29 /** 30 * Simple K-Means implementation 31 */ 32 public class KMeans { 33 34 private static final boolean DEBUG = false; 35 private static final String TAG = "KMeans"; 36 private final Random mRandomState; 37 private final int mMaxIterations; 38 private float mSqConvergenceEpsilon; 39 KMeans()40 public KMeans() { 41 this(new Random()); 42 } 43 KMeans(Random random)44 public KMeans(Random random) { 45 this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */); 46 } KMeans(Random random, int maxIterations, float convergenceEpsilon)47 public KMeans(Random random, int maxIterations, float convergenceEpsilon) { 48 mRandomState = random; 49 mMaxIterations = maxIterations; 50 mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon; 51 } 52 53 /** 54 * Runs k-means on the input data (X) trying to find k means. 55 * 56 * K-Means is known for getting stuck into local optima, so you might 57 * want to run it multiple time and argmax on {@link KMeans#score(List)} 58 * 59 * @param k The number of points to return. 60 * @param inputData Input data. 61 * @return An array of k Means, each representing a centroid and data points that belong to it. 62 */ predict(final int k, final float[][] inputData)63 public List<Mean> predict(final int k, final float[][] inputData) { 64 checkDataSetSanity(inputData); 65 int dimension = inputData[0].length; 66 67 final ArrayList<Mean> means = new ArrayList<>(); 68 for (int i = 0; i < k; i++) { 69 Mean m = new Mean(dimension); 70 for (int j = 0; j < dimension; j++) { 71 m.mCentroid[j] = mRandomState.nextFloat(); 72 } 73 means.add(m); 74 } 75 76 // Iterate until we converge or run out of iterations 77 boolean converged = false; 78 for (int i = 0; i < mMaxIterations; i++) { 79 converged = step(means, inputData); 80 if (converged) { 81 if (DEBUG) Log.d(TAG, "Converged at iteration: " + i); 82 break; 83 } 84 } 85 if (!converged && DEBUG) Log.d(TAG, "Did not converge"); 86 87 return means; 88 } 89 90 /** 91 * Score calculates the inertia between means. 92 * This can be considered as an E step of an EM algorithm. 93 * 94 * @param means Means to use when calculating score. 95 * @return The score 96 */ score(@onNull List<Mean> means)97 public static double score(@NonNull List<Mean> means) { 98 double score = 0; 99 final int meansSize = means.size(); 100 for (int i = 0; i < meansSize; i++) { 101 Mean mean = means.get(i); 102 for (int j = 0; j < meansSize; j++) { 103 Mean compareTo = means.get(j); 104 if (mean == compareTo) { 105 continue; 106 } 107 double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid)); 108 score += distance; 109 } 110 } 111 return score; 112 } 113 114 @VisibleForTesting checkDataSetSanity(float[][] inputData)115 public void checkDataSetSanity(float[][] inputData) { 116 if (inputData == null) { 117 throw new IllegalArgumentException("Data set is null."); 118 } else if (inputData.length == 0) { 119 throw new IllegalArgumentException("Data set is empty."); 120 } else if (inputData[0] == null) { 121 throw new IllegalArgumentException("Bad data set format."); 122 } 123 124 final int dimension = inputData[0].length; 125 final int length = inputData.length; 126 for (int i = 1; i < length; i++) { 127 if (inputData[i] == null || inputData[i].length != dimension) { 128 throw new IllegalArgumentException("Bad data set format."); 129 } 130 } 131 } 132 133 /** 134 * K-Means iteration. 135 * 136 * @param means Current means 137 * @param inputData Input data 138 * @return True if data set converged 139 */ step(final ArrayList<Mean> means, final float[][] inputData)140 private boolean step(final ArrayList<Mean> means, final float[][] inputData) { 141 142 // Clean up the previous state because we need to compute 143 // which point belongs to each mean again. 144 for (int i = means.size() - 1; i >= 0; i--) { 145 final Mean mean = means.get(i); 146 mean.mClosestItems.clear(); 147 } 148 for (int i = inputData.length - 1; i >= 0; i--) { 149 final float[] current = inputData[i]; 150 final Mean nearest = nearestMean(current, means); 151 nearest.mClosestItems.add(current); 152 } 153 154 boolean converged = true; 155 // Move each mean towards the nearest data set points 156 for (int i = means.size() - 1; i >= 0; i--) { 157 final Mean mean = means.get(i); 158 if (mean.mClosestItems.size() == 0) { 159 continue; 160 } 161 162 // Compute the new mean centroid: 163 // 1. Sum all all points 164 // 2. Average them 165 final float[] oldCentroid = mean.mCentroid; 166 mean.mCentroid = new float[oldCentroid.length]; 167 for (int j = 0; j < mean.mClosestItems.size(); j++) { 168 // Update each centroid component 169 for (int p = 0; p < mean.mCentroid.length; p++) { 170 mean.mCentroid[p] += mean.mClosestItems.get(j)[p]; 171 } 172 } 173 for (int j = 0; j < mean.mCentroid.length; j++) { 174 mean.mCentroid[j] /= mean.mClosestItems.size(); 175 } 176 177 // We converged if the centroid didn't move for any of the means. 178 if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) { 179 converged = false; 180 } 181 } 182 return converged; 183 } 184 185 @VisibleForTesting nearestMean(float[] point, List<Mean> means)186 public static Mean nearestMean(float[] point, List<Mean> means) { 187 Mean nearest = null; 188 float nearestDistance = Float.MAX_VALUE; 189 190 final int meanCount = means.size(); 191 for (int i = 0; i < meanCount; i++) { 192 Mean next = means.get(i); 193 // We don't need the sqrt when comparing distances in euclidean space 194 // because they exist on both sides of the equation and cancel each other out. 195 float nextDistance = sqDistance(point, next.mCentroid); 196 if (nextDistance < nearestDistance) { 197 nearest = next; 198 nearestDistance = nextDistance; 199 } 200 } 201 return nearest; 202 } 203 204 @VisibleForTesting sqDistance(float[] a, float[] b)205 public static float sqDistance(float[] a, float[] b) { 206 float dist = 0; 207 final int length = a.length; 208 for (int i = 0; i < length; i++) { 209 dist += (a[i] - b[i]) * (a[i] - b[i]); 210 } 211 return dist; 212 } 213 214 /** 215 * Definition of a mean, contains a centroid and points on its cluster. 216 */ 217 public static class Mean { 218 float[] mCentroid; 219 final ArrayList<float[]> mClosestItems = new ArrayList<>(); 220 Mean(int dimension)221 public Mean(int dimension) { 222 mCentroid = new float[dimension]; 223 } 224 Mean(float ...centroid)225 public Mean(float ...centroid) { 226 mCentroid = centroid; 227 } 228 getCentroid()229 public float[] getCentroid() { 230 return mCentroid; 231 } 232 getItems()233 public List<float[]> getItems() { 234 return mClosestItems; 235 } 236 237 @Override toString()238 public String toString() { 239 return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: " 240 + mClosestItems.size() + ")"; 241 } 242 } 243 } 244