1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License
15  */
16 
17 package com.android.internal.ml.clustering;
18 
19 import android.annotation.NonNull;
20 import android.util.Log;
21 
22 import com.android.internal.annotations.VisibleForTesting;
23 
24 import java.util.ArrayList;
25 import java.util.Arrays;
26 import java.util.List;
27 import java.util.Random;
28 
29 /**
30  * Simple K-Means implementation
31  */
32 public class KMeans {
33 
34     private static final boolean DEBUG = false;
35     private static final String TAG = "KMeans";
36     private final Random mRandomState;
37     private final int mMaxIterations;
38     private float mSqConvergenceEpsilon;
39 
KMeans()40     public KMeans() {
41         this(new Random());
42     }
43 
KMeans(Random random)44     public KMeans(Random random) {
45         this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
46     }
KMeans(Random random, int maxIterations, float convergenceEpsilon)47     public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
48         mRandomState = random;
49         mMaxIterations = maxIterations;
50         mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
51     }
52 
53     /**
54      * Runs k-means on the input data (X) trying to find k means.
55      *
56      * K-Means is known for getting stuck into local optima, so you might
57      * want to run it multiple time and argmax on {@link KMeans#score(List)}
58      *
59      * @param k The number of points to return.
60      * @param inputData Input data.
61      * @return An array of k Means, each representing a centroid and data points that belong to it.
62      */
predict(final int k, final float[][] inputData)63     public List<Mean> predict(final int k, final float[][] inputData) {
64         checkDataSetSanity(inputData);
65         int dimension = inputData[0].length;
66 
67         final ArrayList<Mean> means = new ArrayList<>();
68         for (int i = 0; i < k; i++) {
69             Mean m = new Mean(dimension);
70             for (int j = 0; j < dimension; j++) {
71                 m.mCentroid[j] = mRandomState.nextFloat();
72             }
73             means.add(m);
74         }
75 
76         // Iterate until we converge or run out of iterations
77         boolean converged = false;
78         for (int i = 0; i < mMaxIterations; i++) {
79             converged = step(means, inputData);
80             if (converged) {
81                 if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
82                 break;
83             }
84         }
85         if (!converged && DEBUG) Log.d(TAG, "Did not converge");
86 
87         return means;
88     }
89 
90     /**
91      * Score calculates the inertia between means.
92      * This can be considered as an E step of an EM algorithm.
93      *
94      * @param means Means to use when calculating score.
95      * @return The score
96      */
score(@onNull List<Mean> means)97     public static double score(@NonNull List<Mean> means) {
98         double score = 0;
99         final int meansSize = means.size();
100         for (int i = 0; i < meansSize; i++) {
101             Mean mean = means.get(i);
102             for (int j = 0; j < meansSize; j++) {
103                 Mean compareTo = means.get(j);
104                 if (mean == compareTo) {
105                     continue;
106                 }
107                 double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
108                 score += distance;
109             }
110         }
111         return score;
112     }
113 
114     @VisibleForTesting
checkDataSetSanity(float[][] inputData)115     public void checkDataSetSanity(float[][] inputData) {
116         if (inputData == null) {
117             throw new IllegalArgumentException("Data set is null.");
118         } else if (inputData.length == 0) {
119             throw new IllegalArgumentException("Data set is empty.");
120         } else if (inputData[0] == null) {
121             throw new IllegalArgumentException("Bad data set format.");
122         }
123 
124         final int dimension = inputData[0].length;
125         final int length = inputData.length;
126         for (int i = 1; i < length; i++) {
127             if (inputData[i] == null || inputData[i].length != dimension) {
128                 throw new IllegalArgumentException("Bad data set format.");
129             }
130         }
131     }
132 
133     /**
134      * K-Means iteration.
135      *
136      * @param means Current means
137      * @param inputData Input data
138      * @return True if data set converged
139      */
step(final ArrayList<Mean> means, final float[][] inputData)140     private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
141 
142         // Clean up the previous state because we need to compute
143         // which point belongs to each mean again.
144         for (int i = means.size() - 1; i >= 0; i--) {
145             final Mean mean = means.get(i);
146             mean.mClosestItems.clear();
147         }
148         for (int i = inputData.length - 1; i >= 0; i--) {
149             final float[] current = inputData[i];
150             final Mean nearest = nearestMean(current, means);
151             nearest.mClosestItems.add(current);
152         }
153 
154         boolean converged = true;
155         // Move each mean towards the nearest data set points
156         for (int i = means.size() - 1; i >= 0; i--) {
157             final Mean mean = means.get(i);
158             if (mean.mClosestItems.size() == 0) {
159                 continue;
160             }
161 
162             // Compute the new mean centroid:
163             //   1. Sum all all points
164             //   2. Average them
165             final float[] oldCentroid = mean.mCentroid;
166             mean.mCentroid = new float[oldCentroid.length];
167             for (int j = 0; j < mean.mClosestItems.size(); j++) {
168                 // Update each centroid component
169                 for (int p = 0; p < mean.mCentroid.length; p++) {
170                     mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
171                 }
172             }
173             for (int j = 0; j < mean.mCentroid.length; j++) {
174                 mean.mCentroid[j] /= mean.mClosestItems.size();
175             }
176 
177             // We converged if the centroid didn't move for any of the means.
178             if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
179                 converged = false;
180             }
181         }
182         return converged;
183     }
184 
185     @VisibleForTesting
nearestMean(float[] point, List<Mean> means)186     public static Mean nearestMean(float[] point, List<Mean> means) {
187         Mean nearest = null;
188         float nearestDistance = Float.MAX_VALUE;
189 
190         final int meanCount = means.size();
191         for (int i = 0; i < meanCount; i++) {
192             Mean next = means.get(i);
193             // We don't need the sqrt when comparing distances in euclidean space
194             // because they exist on both sides of the equation and cancel each other out.
195             float nextDistance = sqDistance(point, next.mCentroid);
196             if (nextDistance < nearestDistance) {
197                 nearest = next;
198                 nearestDistance = nextDistance;
199             }
200         }
201         return nearest;
202     }
203 
204     @VisibleForTesting
sqDistance(float[] a, float[] b)205     public static float sqDistance(float[] a, float[] b) {
206         float dist = 0;
207         final int length = a.length;
208         for (int i = 0; i < length; i++) {
209             dist += (a[i] - b[i]) * (a[i] - b[i]);
210         }
211         return dist;
212     }
213 
214     /**
215      * Definition of a mean, contains a centroid and points on its cluster.
216      */
217     public static class Mean {
218         float[] mCentroid;
219         final ArrayList<float[]> mClosestItems = new ArrayList<>();
220 
Mean(int dimension)221         public Mean(int dimension) {
222             mCentroid = new float[dimension];
223         }
224 
Mean(float ...centroid)225         public Mean(float ...centroid) {
226             mCentroid = centroid;
227         }
228 
getCentroid()229         public float[] getCentroid() {
230             return mCentroid;
231         }
232 
getItems()233         public List<float[]> getItems() {
234             return mClosestItems;
235         }
236 
237         @Override
toString()238         public String toString() {
239             return "Mean(centroid: " + Arrays.toString(mCentroid) + ", size: "
240                     + mClosestItems.size() + ")";
241         }
242     }
243 }
244