1 /* 2 * Copyright (C) 2021 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.internal.graphics.palette; 18 19 import android.annotation.NonNull; 20 import android.annotation.Nullable; 21 import android.util.Log; 22 23 import java.util.ArrayList; 24 import java.util.Arrays; 25 import java.util.HashSet; 26 import java.util.List; 27 import java.util.Map; 28 import java.util.Random; 29 import java.util.Set; 30 31 /** 32 * A color quantizer based on the Kmeans algorithm. Prefer using QuantizerCelebi. 33 * 34 * This is an implementation of Kmeans based on Celebi's 2011 paper, 35 * "Improving the Performance of K-Means for Color Quantization". In the paper, this algorithm is 36 * referred to as "WSMeans", or, "Weighted Square Means" The main advantages of this Kmeans 37 * implementation are taking advantage of triangle properties to avoid distance calculations, as 38 * well as indexing colors by their count, thus minimizing the number of points to move around. 39 * 40 * Celebi's paper also stabilizes results and guarantees high quality by using starting centroids 41 * from Wu's quantization algorithm. See QuantizerCelebi for more info. 42 */ 43 public final class WSMeansQuantizer implements Quantizer { 44 private static final String TAG = "QuantizerWsmeans"; 45 private static final boolean DEBUG = false; 46 private static final int MAX_ITERATIONS = 10; 47 // Points won't be moved to a closer cluster, if the closer cluster is within 48 // this distance. 3.0 used because L*a*b* delta E < 3 is considered imperceptible. 49 private static final float MIN_MOVEMENT_DISTANCE = 3.0f; 50 51 private final PointProvider mPointProvider; 52 private @Nullable Map<Integer, Integer> mInputPixelToCount; 53 private float[][] mClusters; 54 private int[] mClusterPopulations; 55 private float[][] mPoints; 56 private int[] mPixels; 57 private int[] mClusterIndices; 58 private int[][] mIndexMatrix = {}; 59 private float[][] mDistanceMatrix = {}; 60 61 private Palette mPalette; 62 WSMeansQuantizer(int[] inClusters, PointProvider pointProvider, @Nullable Map<Integer, Integer> inputPixelToCount)63 public WSMeansQuantizer(int[] inClusters, PointProvider pointProvider, 64 @Nullable Map<Integer, Integer> inputPixelToCount) { 65 mPointProvider = pointProvider; 66 67 mClusters = new float[inClusters.length][3]; 68 int index = 0; 69 for (int cluster : inClusters) { 70 float[] point = pointProvider.fromInt(cluster); 71 mClusters[index++] = point; 72 } 73 74 mInputPixelToCount = inputPixelToCount; 75 } 76 77 @Override getQuantizedColors()78 public List<Palette.Swatch> getQuantizedColors() { 79 return mPalette.getSwatches(); 80 } 81 82 @Override quantize(@onNull int[] pixels, int maxColors)83 public void quantize(@NonNull int[] pixels, int maxColors) { 84 assert (pixels.length > 0); 85 86 if (mInputPixelToCount == null) { 87 QuantizerMap mapQuantizer = new QuantizerMap(); 88 mapQuantizer.quantize(pixels, maxColors); 89 mInputPixelToCount = mapQuantizer.getColorToCount(); 90 } 91 92 mPoints = new float[mInputPixelToCount.size()][3]; 93 mPixels = new int[mInputPixelToCount.size()]; 94 int index = 0; 95 for (int pixel : mInputPixelToCount.keySet()) { 96 mPixels[index] = pixel; 97 mPoints[index] = mPointProvider.fromInt(pixel); 98 index++; 99 } 100 if (mClusters.length > 0) { 101 // This implies that the constructor was provided starting clusters. If that was the 102 // case, we limit the number of clusters to the number of starting clusters and don't 103 // initialize random clusters. 104 maxColors = Math.min(maxColors, mClusters.length); 105 } 106 maxColors = Math.min(maxColors, mPoints.length); 107 108 initializeClusters(maxColors); 109 for (int i = 0; i < MAX_ITERATIONS; i++) { 110 calculateClusterDistances(maxColors); 111 if (!reassignPoints(maxColors)) { 112 break; 113 } 114 recalculateClusterCenters(maxColors); 115 } 116 117 List<Palette.Swatch> swatches = new ArrayList<>(); 118 for (int i = 0; i < maxColors; i++) { 119 float[] cluster = mClusters[i]; 120 int colorInt = mPointProvider.toInt(cluster); 121 swatches.add(new Palette.Swatch(colorInt, mClusterPopulations[i])); 122 } 123 mPalette = Palette.from(swatches); 124 } 125 126 initializeClusters(int maxColors)127 private void initializeClusters(int maxColors) { 128 boolean hadInputClusters = mClusters.length > 0; 129 if (!hadInputClusters) { 130 int additionalClustersNeeded = maxColors - mClusters.length; 131 if (DEBUG) { 132 Log.d(TAG, "have " + mClusters.length + " clusters, want " + maxColors 133 + " results, so need " + additionalClustersNeeded + " additional clusters"); 134 } 135 136 Random random = new Random(0x42688); 137 List<float[]> additionalClusters = new ArrayList<>(additionalClustersNeeded); 138 Set<Integer> clusterIndicesUsed = new HashSet<>(); 139 for (int i = 0; i < additionalClustersNeeded; i++) { 140 int index = random.nextInt(mPoints.length); 141 while (clusterIndicesUsed.contains(index) 142 && clusterIndicesUsed.size() < mPoints.length) { 143 index = random.nextInt(mPoints.length); 144 } 145 clusterIndicesUsed.add(index); 146 additionalClusters.add(mPoints[index]); 147 } 148 149 float[][] newClusters = (float[][]) additionalClusters.toArray(); 150 float[][] clusters = Arrays.copyOf(mClusters, maxColors); 151 System.arraycopy(newClusters, 0, clusters, clusters.length, newClusters.length); 152 mClusters = clusters; 153 } 154 155 mClusterIndices = new int[mPixels.length]; 156 mClusterPopulations = new int[mPixels.length]; 157 Random random = new Random(0x42688); 158 for (int i = 0; i < mPixels.length; i++) { 159 int clusterIndex = random.nextInt(maxColors); 160 mClusterIndices[i] = clusterIndex; 161 mClusterPopulations[i] = mInputPixelToCount.get(mPixels[i]); 162 } 163 } 164 calculateClusterDistances(int maxColors)165 void calculateClusterDistances(int maxColors) { 166 if (mDistanceMatrix.length != maxColors) { 167 mDistanceMatrix = new float[maxColors][maxColors]; 168 } 169 170 for (int i = 0; i <= maxColors; i++) { 171 for (int j = i + 1; j < maxColors; j++) { 172 float distance = mPointProvider.distance(mClusters[i], mClusters[j]); 173 mDistanceMatrix[j][i] = distance; 174 mDistanceMatrix[i][j] = distance; 175 } 176 } 177 178 if (mIndexMatrix.length != maxColors) { 179 mIndexMatrix = new int[maxColors][maxColors]; 180 } 181 182 for (int i = 0; i < maxColors; i++) { 183 ArrayList<Distance> distances = new ArrayList<>(maxColors); 184 for (int index = 0; index < maxColors; index++) { 185 distances.add(new Distance(index, mDistanceMatrix[i][index])); 186 } 187 distances.sort( 188 (a, b) -> Float.compare(a.getDistance(), b.getDistance())); 189 190 for (int j = 0; j < maxColors; j++) { 191 mIndexMatrix[i][j] = distances.get(j).getIndex(); 192 } 193 } 194 } 195 reassignPoints(int maxColors)196 boolean reassignPoints(int maxColors) { 197 boolean colorMoved = false; 198 for (int i = 0; i < mPoints.length; i++) { 199 float[] point = mPoints[i]; 200 int previousClusterIndex = mClusterIndices[i]; 201 float[] previousCluster = mClusters[previousClusterIndex]; 202 float previousDistance = mPointProvider.distance(point, previousCluster); 203 204 float minimumDistance = previousDistance; 205 int newClusterIndex = -1; 206 for (int j = 1; j < maxColors; j++) { 207 int t = mIndexMatrix[previousClusterIndex][j]; 208 if (mDistanceMatrix[previousClusterIndex][t] >= 4 * previousDistance) { 209 // Triangle inequality proves there's can be no closer center. 210 break; 211 } 212 float distance = mPointProvider.distance(point, mClusters[t]); 213 if (distance < minimumDistance) { 214 minimumDistance = distance; 215 newClusterIndex = t; 216 } 217 } 218 if (newClusterIndex != -1) { 219 float distanceChange = (float) 220 Math.abs((Math.sqrt(minimumDistance) - Math.sqrt(previousDistance))); 221 if (distanceChange > MIN_MOVEMENT_DISTANCE) { 222 colorMoved = true; 223 mClusterIndices[i] = newClusterIndex; 224 } 225 } 226 } 227 return colorMoved; 228 } 229 recalculateClusterCenters(int maxColors)230 void recalculateClusterCenters(int maxColors) { 231 mClusterPopulations = new int[maxColors]; 232 float[] aSums = new float[maxColors]; 233 float[] bSums = new float[maxColors]; 234 float[] cSums = new float[maxColors]; 235 for (int i = 0; i < mPoints.length; i++) { 236 int clusterIndex = mClusterIndices[i]; 237 float[] point = mPoints[i]; 238 int pixel = mPixels[i]; 239 int count = mInputPixelToCount.get(pixel); 240 mClusterPopulations[clusterIndex] += count; 241 aSums[clusterIndex] += point[0] * count; 242 bSums[clusterIndex] += point[1] * count; 243 cSums[clusterIndex] += point[2] * count; 244 245 } 246 for (int i = 0; i < maxColors; i++) { 247 int count = mClusterPopulations[i]; 248 float aSum = aSums[i]; 249 float bSum = bSums[i]; 250 float cSum = cSums[i]; 251 mClusters[i][0] = aSum / count; 252 mClusters[i][1] = bSum / count; 253 mClusters[i][2] = cSum / count; 254 } 255 } 256 257 private static class Distance { 258 private final int mIndex; 259 private final float mDistance; 260 getIndex()261 int getIndex() { 262 return mIndex; 263 } 264 getDistance()265 float getDistance() { 266 return mDistance; 267 } 268 Distance(int index, float distance)269 Distance(int index, float distance) { 270 mIndex = index; 271 mDistance = distance; 272 } 273 } 274 } 275