1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.graphics.palette;
18 
19 import android.annotation.NonNull;
20 import android.annotation.Nullable;
21 import android.util.Log;
22 
23 import java.util.ArrayList;
24 import java.util.Arrays;
25 import java.util.HashSet;
26 import java.util.List;
27 import java.util.Map;
28 import java.util.Random;
29 import java.util.Set;
30 
31 /**
32  * A color quantizer based on the Kmeans algorithm. Prefer using QuantizerCelebi.
33  *
34  * This is an implementation of Kmeans based on Celebi's 2011 paper,
35  * "Improving the Performance of K-Means for Color Quantization". In the paper, this algorithm is
36  * referred to as "WSMeans", or, "Weighted Square Means" The main advantages of this Kmeans
37  * implementation are taking advantage of triangle properties to avoid distance calculations, as
38  * well as indexing colors by their count, thus minimizing the number of points to move around.
39  *
40  * Celebi's paper also stabilizes results and guarantees high quality by using starting centroids
41  * from Wu's quantization algorithm. See QuantizerCelebi for more info.
42  */
43 public final class WSMeansQuantizer implements Quantizer {
44     private static final String TAG = "QuantizerWsmeans";
45     private static final boolean DEBUG = false;
46     private static final int MAX_ITERATIONS = 10;
47     // Points won't be moved to a closer cluster, if the closer cluster is within
48     // this distance. 3.0 used because L*a*b* delta E < 3 is considered imperceptible.
49     private static final float MIN_MOVEMENT_DISTANCE = 3.0f;
50 
51     private final PointProvider mPointProvider;
52     private @Nullable Map<Integer, Integer> mInputPixelToCount;
53     private float[][] mClusters;
54     private int[] mClusterPopulations;
55     private float[][] mPoints;
56     private int[] mPixels;
57     private int[] mClusterIndices;
58     private int[][] mIndexMatrix = {};
59     private float[][] mDistanceMatrix = {};
60 
61     private Palette mPalette;
62 
WSMeansQuantizer(int[] inClusters, PointProvider pointProvider, @Nullable Map<Integer, Integer> inputPixelToCount)63     public WSMeansQuantizer(int[] inClusters, PointProvider pointProvider,
64             @Nullable Map<Integer, Integer> inputPixelToCount) {
65         mPointProvider = pointProvider;
66 
67         mClusters = new float[inClusters.length][3];
68         int index = 0;
69         for (int cluster : inClusters) {
70             float[] point = pointProvider.fromInt(cluster);
71             mClusters[index++] = point;
72         }
73 
74         mInputPixelToCount = inputPixelToCount;
75     }
76 
77     @Override
getQuantizedColors()78     public List<Palette.Swatch> getQuantizedColors() {
79         return mPalette.getSwatches();
80     }
81 
82     @Override
quantize(@onNull int[] pixels, int maxColors)83     public void quantize(@NonNull int[] pixels, int maxColors) {
84         assert (pixels.length > 0);
85 
86         if (mInputPixelToCount == null) {
87             QuantizerMap mapQuantizer = new QuantizerMap();
88             mapQuantizer.quantize(pixels, maxColors);
89             mInputPixelToCount = mapQuantizer.getColorToCount();
90         }
91 
92         mPoints = new float[mInputPixelToCount.size()][3];
93         mPixels = new int[mInputPixelToCount.size()];
94         int index = 0;
95         for (int pixel : mInputPixelToCount.keySet()) {
96             mPixels[index] = pixel;
97             mPoints[index] = mPointProvider.fromInt(pixel);
98             index++;
99         }
100         if (mClusters.length > 0) {
101             // This implies that the constructor was provided starting clusters. If that was the
102             // case, we limit the number of clusters to the number of starting clusters and don't
103             // initialize random clusters.
104             maxColors = Math.min(maxColors, mClusters.length);
105         }
106         maxColors = Math.min(maxColors, mPoints.length);
107 
108         initializeClusters(maxColors);
109         for (int i = 0; i < MAX_ITERATIONS; i++) {
110             calculateClusterDistances(maxColors);
111             if (!reassignPoints(maxColors)) {
112                 break;
113             }
114             recalculateClusterCenters(maxColors);
115         }
116 
117         List<Palette.Swatch> swatches = new ArrayList<>();
118         for (int i = 0; i < maxColors; i++) {
119             float[] cluster = mClusters[i];
120             int colorInt = mPointProvider.toInt(cluster);
121             swatches.add(new Palette.Swatch(colorInt, mClusterPopulations[i]));
122         }
123         mPalette = Palette.from(swatches);
124     }
125 
126 
initializeClusters(int maxColors)127     private void initializeClusters(int maxColors) {
128         boolean hadInputClusters = mClusters.length > 0;
129         if (!hadInputClusters) {
130             int additionalClustersNeeded = maxColors - mClusters.length;
131             if (DEBUG) {
132                 Log.d(TAG, "have " + mClusters.length + " clusters, want " + maxColors
133                         + " results, so need " + additionalClustersNeeded + " additional clusters");
134             }
135 
136             Random random = new Random(0x42688);
137             List<float[]> additionalClusters = new ArrayList<>(additionalClustersNeeded);
138             Set<Integer> clusterIndicesUsed = new HashSet<>();
139             for (int i = 0; i < additionalClustersNeeded; i++) {
140                 int index = random.nextInt(mPoints.length);
141                 while (clusterIndicesUsed.contains(index)
142                         && clusterIndicesUsed.size() < mPoints.length) {
143                     index = random.nextInt(mPoints.length);
144                 }
145                 clusterIndicesUsed.add(index);
146                 additionalClusters.add(mPoints[index]);
147             }
148 
149             float[][] newClusters = (float[][]) additionalClusters.toArray();
150             float[][] clusters = Arrays.copyOf(mClusters, maxColors);
151             System.arraycopy(newClusters, 0, clusters, clusters.length, newClusters.length);
152             mClusters = clusters;
153         }
154 
155         mClusterIndices = new int[mPixels.length];
156         mClusterPopulations = new int[mPixels.length];
157         Random random = new Random(0x42688);
158         for (int i = 0; i < mPixels.length; i++) {
159             int clusterIndex = random.nextInt(maxColors);
160             mClusterIndices[i] = clusterIndex;
161             mClusterPopulations[i] = mInputPixelToCount.get(mPixels[i]);
162         }
163     }
164 
calculateClusterDistances(int maxColors)165     void calculateClusterDistances(int maxColors) {
166         if (mDistanceMatrix.length != maxColors) {
167             mDistanceMatrix = new float[maxColors][maxColors];
168         }
169 
170         for (int i = 0; i <= maxColors; i++) {
171             for (int j = i + 1; j < maxColors; j++) {
172                 float distance = mPointProvider.distance(mClusters[i], mClusters[j]);
173                 mDistanceMatrix[j][i] = distance;
174                 mDistanceMatrix[i][j] = distance;
175             }
176         }
177 
178         if (mIndexMatrix.length != maxColors) {
179             mIndexMatrix = new int[maxColors][maxColors];
180         }
181 
182         for (int i = 0; i < maxColors; i++) {
183             ArrayList<Distance> distances = new ArrayList<>(maxColors);
184             for (int index = 0; index < maxColors; index++) {
185                 distances.add(new Distance(index, mDistanceMatrix[i][index]));
186             }
187             distances.sort(
188                     (a, b) -> Float.compare(a.getDistance(), b.getDistance()));
189 
190             for (int j = 0; j < maxColors; j++) {
191                 mIndexMatrix[i][j] = distances.get(j).getIndex();
192             }
193         }
194     }
195 
reassignPoints(int maxColors)196     boolean reassignPoints(int maxColors) {
197         boolean colorMoved = false;
198         for (int i = 0; i < mPoints.length; i++) {
199             float[] point = mPoints[i];
200             int previousClusterIndex = mClusterIndices[i];
201             float[] previousCluster = mClusters[previousClusterIndex];
202             float previousDistance = mPointProvider.distance(point, previousCluster);
203 
204             float minimumDistance = previousDistance;
205             int newClusterIndex = -1;
206             for (int j = 1; j < maxColors; j++) {
207                 int t = mIndexMatrix[previousClusterIndex][j];
208                 if (mDistanceMatrix[previousClusterIndex][t] >= 4 * previousDistance) {
209                     // Triangle inequality proves there's can be no closer center.
210                     break;
211                 }
212                 float distance = mPointProvider.distance(point, mClusters[t]);
213                 if (distance < minimumDistance) {
214                     minimumDistance = distance;
215                     newClusterIndex = t;
216                 }
217             }
218             if (newClusterIndex != -1) {
219                 float distanceChange = (float)
220                         Math.abs((Math.sqrt(minimumDistance) - Math.sqrt(previousDistance)));
221                 if (distanceChange > MIN_MOVEMENT_DISTANCE) {
222                     colorMoved = true;
223                     mClusterIndices[i] = newClusterIndex;
224                 }
225             }
226         }
227         return colorMoved;
228     }
229 
recalculateClusterCenters(int maxColors)230     void recalculateClusterCenters(int maxColors) {
231         mClusterPopulations = new int[maxColors];
232         float[] aSums = new float[maxColors];
233         float[] bSums = new float[maxColors];
234         float[] cSums = new float[maxColors];
235         for (int i = 0; i < mPoints.length; i++) {
236             int clusterIndex = mClusterIndices[i];
237             float[] point = mPoints[i];
238             int pixel = mPixels[i];
239             int count =  mInputPixelToCount.get(pixel);
240             mClusterPopulations[clusterIndex] += count;
241             aSums[clusterIndex] += point[0] * count;
242             bSums[clusterIndex] += point[1] * count;
243             cSums[clusterIndex] += point[2] * count;
244 
245         }
246         for (int i = 0; i < maxColors; i++) {
247             int count = mClusterPopulations[i];
248             float aSum = aSums[i];
249             float bSum = bSums[i];
250             float cSum = cSums[i];
251             mClusters[i][0] = aSum / count;
252             mClusters[i][1] = bSum / count;
253             mClusters[i][2] = cSum / count;
254         }
255     }
256 
257     private static class Distance {
258         private final int mIndex;
259         private final float mDistance;
260 
getIndex()261         int getIndex() {
262             return mIndex;
263         }
264 
getDistance()265         float getDistance() {
266             return mDistance;
267         }
268 
Distance(int index, float distance)269         Distance(int index, float distance) {
270             mIndex = index;
271             mDistance = distance;
272         }
273     }
274 }
275