1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.ml.clustering;
18 
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertTrue;
21 
22 import android.annotation.SuppressLint;
23 
24 import androidx.test.filters.SmallTest;
25 import androidx.test.runner.AndroidJUnit4;
26 
27 import org.junit.Assert;
28 import org.junit.Before;
29 import org.junit.Test;
30 import org.junit.runner.RunWith;
31 
32 import java.util.Arrays;
33 import java.util.List;
34 import java.util.Random;
35 
36 @SmallTest
37 @RunWith(AndroidJUnit4.class)
38 public class KMeansTest {
39 
40     // Error tolerance (epsilon)
41     private static final double EPS = 0.01;
42 
43     private KMeans mKMeans;
44 
45     @Before
setUp()46     public void setUp() {
47         // Setup with a random seed to have predictable results
48         mKMeans = new KMeans(new Random(0), 30, 0);
49     }
50 
51     @Test
getCheckDataSanityTest()52     public void getCheckDataSanityTest() {
53         try {
54             mKMeans.checkDataSetSanity(new float[][] {
55                     {0, 1, 2},
56                     {1, 2, 3}
57             });
58         } catch (IllegalArgumentException e) {
59             Assert.fail("Valid data didn't pass sanity check");
60         }
61 
62         try {
63             mKMeans.checkDataSetSanity(new float[][] {
64                     null,
65                     {1, 2, 3}
66             });
67             Assert.fail("Data has null items and passed");
68         } catch (IllegalArgumentException e) {}
69 
70         try {
71             mKMeans.checkDataSetSanity(new float[][] {
72                     {0, 1, 2, 4},
73                     {1, 2, 3}
74             });
75             Assert.fail("Data has invalid shape and passed");
76         } catch (IllegalArgumentException e) {}
77 
78         try {
79             mKMeans.checkDataSetSanity(null);
80             Assert.fail("Null data should throw exception");
81         } catch (IllegalArgumentException e) {}
82     }
83 
84     @Test
sqDistanceTest()85     public void sqDistanceTest() {
86         float a[] = {4, 10};
87         float b[] = {5, 2};
88         float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2));
89 
90         assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS);
91     }
92 
93     @Test
nearestMeanTest()94     public void nearestMeanTest() {
95         KMeans.Mean meanA = new KMeans.Mean(0, 1);
96         KMeans.Mean meanB = new KMeans.Mean(1, 1);
97         List<KMeans.Mean> means = Arrays.asList(meanA, meanB);
98 
99         KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means);
100 
101         assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB);
102     }
103 
104     @SuppressLint("DefaultLocale")
105     @Test
scoreTest()106     public void scoreTest() {
107         List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f),
108                 new KMeans.Mean(0, 0.1f, 0.15f),
109                 new KMeans.Mean(0.1f, 0.2f, 0.1f));
110         List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0),
111                 new KMeans.Mean(0, 0.5f, 0.5f),
112                 new KMeans.Mean(1, 0.9f, 0.9f));
113 
114         double closeScore = KMeans.score(closeMeans);
115         double farScore = KMeans.score(farMeans);
116         assertTrue(String.format("Score of well distributed means should be greater than "
117                 + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore);
118     }
119 
120     @Test
predictTest()121     public void predictTest() {
122         float[] expectedCentroid1 = {1, 1, 1};
123         float[] expectedCentroid2 = {0, 0, 0};
124         float[][] X = new float[][] {
125                 {1, 1, 1},
126                 {1, 1, 1},
127                 {1, 1, 1},
128                 {0, 0, 0},
129                 {0, 0, 0},
130                 {0, 0, 0},
131         };
132 
133         final int numClusters = 2;
134 
135         // Here we assume that we won't get stuck into a local optima.
136         // It's fine because we're seeding a random, we won't ever have
137         // unstable results but in real life we need multiple initialization
138         // and score comparison
139         List<KMeans.Mean> means = mKMeans.predict(numClusters, X);
140 
141         assertEquals("Expected number of clusters is invalid", numClusters, means.size());
142 
143         boolean exists1 = false, exists2 = false;
144         for (KMeans.Mean mean : means) {
145             if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) {
146                 exists1 = true;
147             } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) {
148                 exists2 = true;
149             } else {
150                 throw new AssertionError("Unexpected mean: " + mean);
151             }
152         }
153         assertTrue("Expected means were not predicted, got: " + means,
154                 exists1 && exists2);
155     }
156 }
157