1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.internal.ml.clustering; 18 19 import static org.junit.Assert.assertEquals; 20 import static org.junit.Assert.assertTrue; 21 22 import android.annotation.SuppressLint; 23 24 import androidx.test.filters.SmallTest; 25 import androidx.test.runner.AndroidJUnit4; 26 27 import org.junit.Assert; 28 import org.junit.Before; 29 import org.junit.Test; 30 import org.junit.runner.RunWith; 31 32 import java.util.Arrays; 33 import java.util.List; 34 import java.util.Random; 35 36 @SmallTest 37 @RunWith(AndroidJUnit4.class) 38 public class KMeansTest { 39 40 // Error tolerance (epsilon) 41 private static final double EPS = 0.01; 42 43 private KMeans mKMeans; 44 45 @Before setUp()46 public void setUp() { 47 // Setup with a random seed to have predictable results 48 mKMeans = new KMeans(new Random(0), 30, 0); 49 } 50 51 @Test getCheckDataSanityTest()52 public void getCheckDataSanityTest() { 53 try { 54 mKMeans.checkDataSetSanity(new float[][] { 55 {0, 1, 2}, 56 {1, 2, 3} 57 }); 58 } catch (IllegalArgumentException e) { 59 Assert.fail("Valid data didn't pass sanity check"); 60 } 61 62 try { 63 mKMeans.checkDataSetSanity(new float[][] { 64 null, 65 {1, 2, 3} 66 }); 67 Assert.fail("Data has null items and passed"); 68 } catch (IllegalArgumentException e) {} 69 70 try { 71 mKMeans.checkDataSetSanity(new float[][] { 72 {0, 1, 2, 4}, 73 {1, 2, 3} 74 }); 75 Assert.fail("Data has invalid shape and passed"); 76 } catch (IllegalArgumentException e) {} 77 78 try { 79 mKMeans.checkDataSetSanity(null); 80 Assert.fail("Null data should throw exception"); 81 } catch (IllegalArgumentException e) {} 82 } 83 84 @Test sqDistanceTest()85 public void sqDistanceTest() { 86 float a[] = {4, 10}; 87 float b[] = {5, 2}; 88 float sqDist = (float) (Math.pow(a[0] - b[0], 2) + Math.pow(a[1] - b[1], 2)); 89 90 assertEquals("Squared distance not valid", mKMeans.sqDistance(a, b), sqDist, EPS); 91 } 92 93 @Test nearestMeanTest()94 public void nearestMeanTest() { 95 KMeans.Mean meanA = new KMeans.Mean(0, 1); 96 KMeans.Mean meanB = new KMeans.Mean(1, 1); 97 List<KMeans.Mean> means = Arrays.asList(meanA, meanB); 98 99 KMeans.Mean nearest = mKMeans.nearestMean(new float[] {1, 1}, means); 100 101 assertEquals("Unexpected nearest mean for point {1, 1}", nearest, meanB); 102 } 103 104 @SuppressLint("DefaultLocale") 105 @Test scoreTest()106 public void scoreTest() { 107 List<KMeans.Mean> closeMeans = Arrays.asList(new KMeans.Mean(0, 0.1f, 0.1f), 108 new KMeans.Mean(0, 0.1f, 0.15f), 109 new KMeans.Mean(0.1f, 0.2f, 0.1f)); 110 List<KMeans.Mean> farMeans = Arrays.asList(new KMeans.Mean(0, 0, 0), 111 new KMeans.Mean(0, 0.5f, 0.5f), 112 new KMeans.Mean(1, 0.9f, 0.9f)); 113 114 double closeScore = KMeans.score(closeMeans); 115 double farScore = KMeans.score(farMeans); 116 assertTrue(String.format("Score of well distributed means should be greater than " 117 + "close means but got: %f, %f", farScore, closeScore), farScore > closeScore); 118 } 119 120 @Test predictTest()121 public void predictTest() { 122 float[] expectedCentroid1 = {1, 1, 1}; 123 float[] expectedCentroid2 = {0, 0, 0}; 124 float[][] X = new float[][] { 125 {1, 1, 1}, 126 {1, 1, 1}, 127 {1, 1, 1}, 128 {0, 0, 0}, 129 {0, 0, 0}, 130 {0, 0, 0}, 131 }; 132 133 final int numClusters = 2; 134 135 // Here we assume that we won't get stuck into a local optima. 136 // It's fine because we're seeding a random, we won't ever have 137 // unstable results but in real life we need multiple initialization 138 // and score comparison 139 List<KMeans.Mean> means = mKMeans.predict(numClusters, X); 140 141 assertEquals("Expected number of clusters is invalid", numClusters, means.size()); 142 143 boolean exists1 = false, exists2 = false; 144 for (KMeans.Mean mean : means) { 145 if (Arrays.equals(mean.getCentroid(), expectedCentroid1)) { 146 exists1 = true; 147 } else if (Arrays.equals(mean.getCentroid(), expectedCentroid2)) { 148 exists2 = true; 149 } else { 150 throw new AssertionError("Unexpected mean: " + mean); 151 } 152 } 153 assertTrue("Expected means were not predicted, got: " + means, 154 exists1 && exists2); 155 } 156 } 157