1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.internal.util;
18 
19 import android.util.ArraySet;
20 import android.util.Pair;
21 
22 import junit.framework.TestCase;
23 
24 import java.util.ArrayList;
25 import java.util.Comparator;
26 import java.util.HashMap;
27 import java.util.List;
28 import java.util.Map;
29 import java.util.Random;
30 import java.util.stream.Collectors;
31 
32 /**
33  * Tests for {@link HeavyHitterSketch}.
34  */
35 public final class HeavyHitterSketchTest extends TestCase {
36 
37     private static final float EPSILON = 0.00001f;
38 
39     /**
40      * A naive counter based heavy hitter sketch, tracks every single input. To be used to validate
41      * the correctness of {@link HeavyHitterSketch}.
42      */
43     private class CounterBased<T> {
44         private final HashMap<T, Integer> mData = new HashMap<>();
45         private int mTotalInput = 0;
46 
add(final T newInstance)47         public void add(final T newInstance) {
48             int val = mData.getOrDefault(newInstance, 0);
49             mData.put(newInstance, val + 1);
50             mTotalInput++;
51         }
52 
getTopHeavyHitters(final int k)53         public List<Pair<T, Float>> getTopHeavyHitters(final int k) {
54             final int lower = mTotalInput / (k + 1);
55             return mData.entrySet().stream()
56                     .filter(e -> e.getValue() >= lower)
57                     .limit(k)
58                     .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder()))
59                     .map((v) -> new Pair<T, Float>(v.getKey(), (float) v.getValue() / mTotalInput))
60                     .collect(Collectors.toList());
61         }
62     }
63 
getTopHeavyHitters(final int[] input, final int capacity)64     private List<Pair<Integer, Float>> getTopHeavyHitters(final int[] input, final int capacity) {
65         final CounterBased counter = new CounterBased<Integer>();
66         final HeavyHitterSketch<Integer> sketcher = HeavyHitterSketch.<Integer>newDefault();
67         final float ratio = sketcher.getRequiredValidationInputRatio();
68         final int total = (int) (input.length / (1 - ratio));
69         sketcher.setConfig(total, capacity);
70         for (int i = 0; i < input.length; i++) {
71             sketcher.add(input[i]);
72             counter.add(input[i]);
73         }
74         int validationSize = total - input.length;
75         assertTrue(validationSize <= input.length);
76         for (int i = 0; i < validationSize; i++) {
77             sketcher.add(input[i]);
78         }
79         final List<Float> freqs = new ArrayList<>();
80         final List<Integer> tops = sketcher.getTopHeavyHitters(capacity - 1, null, freqs);
81         final List<Pair<Integer, Float>> result = new ArrayList<>();
82         if (tops != null) {
83             assertEquals(freqs.size(), tops.size());
84             final List<Pair<Integer, Float>> cl = counter.getTopHeavyHitters(capacity - 1);
85             for (int i = 0; i < tops.size(); i++) {
86                 final Pair<Integer, Float> pair = cl.get(i);
87                 assertEquals(pair.first.intValue(), tops.get(i).intValue());
88                 assertTrue(Math.abs(pair.second - freqs.get(i)) < EPSILON);
89                 result.add(new Pair<>(tops.get(i), freqs.get(i)));
90             }
91         } else {
92             assertTrue(counter.getTopHeavyHitters(capacity - 1).isEmpty());
93         }
94         return result;
95     }
96 
97     private List<Integer> getCandidates(final int[] input, final int capacity) {
98         final HeavyHitterSketch<Integer> sketcher = HeavyHitterSketch.<Integer>newDefault();
99         final float ratio = sketcher.getRequiredValidationInputRatio();
100         final int total = (int) (input.length / (1 - ratio));
101         sketcher.setConfig(total, capacity);
102         for (int i = 0; i < input.length; i++) {
103             sketcher.add(input[i]);
104         }
105         return sketcher.getCandidates(null);
106     }
107 
108     private void verify(final int[] input, final int capacity, final int[] expected,
109             final float[] freqs) throws Exception {
110         final List<Integer> candidates = getCandidates(input, capacity);
111         final List<Pair<Integer, Float>> result = getTopHeavyHitters(input, capacity);
112         if (expected != null) {
113             assertTrue(candidates != null);
114             for (int i = 0; i < expected.length; i++) {
115                 assertTrue(candidates.contains(expected[i]));
116             }
117             assertTrue(result != null);
118             assertEquals(expected.length, result.size());
119             for (int i = 0; i < expected.length; i++) {
120                 final Pair<Integer, Float> pair = result.get(i);
121                 assertEquals(expected[i], pair.first.intValue());
122                 assertTrue(Math.abs(freqs[i] - pair.second) < EPSILON);
123             }
124         } else {
125             assertEquals(null, result);
126         }
127     }
128 
129     private void verifyNotExpected(final int[] input, final int capacity, final int[] notExpected)
130             throws Exception {
131         final List<Pair<Integer, Float>> result = getTopHeavyHitters(input, capacity);
132         if (result != null) {
133             final ArraySet<Integer> set = new ArraySet<>();
134             for (Pair<Integer, Float> p : result) {
135                 set.add(p.first);
136             }
137             for (int i = 0; i < notExpected.length; i++) {
138                 assertFalse(set.contains(notExpected[i]));
139             }
140         }
141     }
142 
143     private int[] generateRandomInput(final int size, final int[] hitters) {
144         final Random random = new Random();
145         final Random random2 = new Random();
146         final int[] input = new int[size];
147         // 80% of them would be hitters, 20% will be random numbers
148         final int numOfRandoms = size / 5;
149         final int numOfHitters = size - numOfRandoms;
150         for (int i = 0, j = 0, m = numOfRandoms, n = numOfHitters; i < size; i++) {
151             int r = m > 0 && n > 0 ? random2.nextInt(size) : (m > 0 ? 0 : numOfRandoms);
152             if (r < numOfRandoms) {
153                 input[i] = random.nextInt(size);
154                 m--;
155             } else {
156                 input[i] = hitters[j++];
157                 if (j == hitters.length) {
158                     j = 0;
159                 }
160                 n--;
161             }
162         }
163         return input;
164     }
165 
166     public void testPositive() throws Exception {
167         // Simple case
168         verify(new int[]{2, 9, 9, 9, 7, 6, 4, 9, 9, 9, 3, 9}, 2, new int[]{9},
169                 new float[]{0.583333f});
170 
171         // Two heavy hitters
172         verify(new int[]{2, 3, 9, 3, 9, 3, 9, 7, 6, 4, 9, 9, 3, 9, 3, 9}, 3, new int[]{9, 3},
173                 new float[]{0.4375f, 0.3125f});
174 
175         // Create a random data set and insert some numbers
176         final int[] input = generateRandomInput(100,
177                 new int[]{1001, 1002, 1002, 1003, 1003, 1003, 1004, 1004, 1004, 1004});
178         verify(input, 12, new int[]{1004, 1003, 1002, 1001},
179                 new float[]{0.32f, 0.24f, 0.16f, 0.08f});
180     }
181 
182     public void testNegative() throws Exception {
183         // Simple case
184         verifyNotExpected(new int[]{2, 9, 9, 9, 7, 6, 4, 9, 9, 9, 3, 9}, 2, new int[]{0, 1, 2});
185 
186         // Two heavy hitters
187         verifyNotExpected(new int[]{2, 3, 9, 3, 9, 3, 9, 7, 6, 4, 9, 9, 3, 9, 3, 9}, 3,
188                 new int[]{0, 1, 2});
189 
190         // Create a random data set and insert some numbers
191         final int[] input = generateRandomInput(100,
192                 new int[]{1001, 1002, 1002, 1003, 1003, 1003, 1004, 1004, 1004, 1004});
193         verifyNotExpected(input, 12, new int[]{0, 1, 2, 1000, 1005});
194     }
195 
196     public void testFalsePositive() throws Exception {
197         // Simple case
198         verifyNotExpected(new int[]{2, 9, 2, 2, 7, 6, 4, 9, 9, 9, 3, 9}, 2, new int[]{9});
199 
200         // One heavy hitter
201         verifyNotExpected(new int[]{2, 3, 9, 3, 9, 3, 9, 7, 6, 4, 9, 9, 3, 9, 2, 9}, 3,
202                 new int[]{3});
203 
204         // Create a random data set and insert some numbers
205         final int[] input = generateRandomInput(100,
206                 new int[]{1001, 1002, 1002, 1003, 1003, 1003, 1004, 1004, 1004, 1004});
207         verifyNotExpected(input, 11, new int[]{1001});
208     }
209 }
210