1 /* 2 * Copyright (C) 2020 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.internal.util; 18 19 import android.util.ArraySet; 20 import android.util.Pair; 21 22 import junit.framework.TestCase; 23 24 import java.util.ArrayList; 25 import java.util.Comparator; 26 import java.util.HashMap; 27 import java.util.List; 28 import java.util.Map; 29 import java.util.Random; 30 import java.util.stream.Collectors; 31 32 /** 33 * Tests for {@link HeavyHitterSketch}. 34 */ 35 public final class HeavyHitterSketchTest extends TestCase { 36 37 private static final float EPSILON = 0.00001f; 38 39 /** 40 * A naive counter based heavy hitter sketch, tracks every single input. To be used to validate 41 * the correctness of {@link HeavyHitterSketch}. 42 */ 43 private class CounterBased<T> { 44 private final HashMap<T, Integer> mData = new HashMap<>(); 45 private int mTotalInput = 0; 46 add(final T newInstance)47 public void add(final T newInstance) { 48 int val = mData.getOrDefault(newInstance, 0); 49 mData.put(newInstance, val + 1); 50 mTotalInput++; 51 } 52 getTopHeavyHitters(final int k)53 public List<Pair<T, Float>> getTopHeavyHitters(final int k) { 54 final int lower = mTotalInput / (k + 1); 55 return mData.entrySet().stream() 56 .filter(e -> e.getValue() >= lower) 57 .limit(k) 58 .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())) 59 .map((v) -> new Pair<T, Float>(v.getKey(), (float) v.getValue() / mTotalInput)) 60 .collect(Collectors.toList()); 61 } 62 } 63 getTopHeavyHitters(final int[] input, final int capacity)64 private List<Pair<Integer, Float>> getTopHeavyHitters(final int[] input, final int capacity) { 65 final CounterBased counter = new CounterBased<Integer>(); 66 final HeavyHitterSketch<Integer> sketcher = HeavyHitterSketch.<Integer>newDefault(); 67 final float ratio = sketcher.getRequiredValidationInputRatio(); 68 final int total = (int) (input.length / (1 - ratio)); 69 sketcher.setConfig(total, capacity); 70 for (int i = 0; i < input.length; i++) { 71 sketcher.add(input[i]); 72 counter.add(input[i]); 73 } 74 int validationSize = total - input.length; 75 assertTrue(validationSize <= input.length); 76 for (int i = 0; i < validationSize; i++) { 77 sketcher.add(input[i]); 78 } 79 final List<Float> freqs = new ArrayList<>(); 80 final List<Integer> tops = sketcher.getTopHeavyHitters(capacity - 1, null, freqs); 81 final List<Pair<Integer, Float>> result = new ArrayList<>(); 82 if (tops != null) { 83 assertEquals(freqs.size(), tops.size()); 84 final List<Pair<Integer, Float>> cl = counter.getTopHeavyHitters(capacity - 1); 85 for (int i = 0; i < tops.size(); i++) { 86 final Pair<Integer, Float> pair = cl.get(i); 87 assertEquals(pair.first.intValue(), tops.get(i).intValue()); 88 assertTrue(Math.abs(pair.second - freqs.get(i)) < EPSILON); 89 result.add(new Pair<>(tops.get(i), freqs.get(i))); 90 } 91 } else { 92 assertTrue(counter.getTopHeavyHitters(capacity - 1).isEmpty()); 93 } 94 return result; 95 } 96 97 private List<Integer> getCandidates(final int[] input, final int capacity) { 98 final HeavyHitterSketch<Integer> sketcher = HeavyHitterSketch.<Integer>newDefault(); 99 final float ratio = sketcher.getRequiredValidationInputRatio(); 100 final int total = (int) (input.length / (1 - ratio)); 101 sketcher.setConfig(total, capacity); 102 for (int i = 0; i < input.length; i++) { 103 sketcher.add(input[i]); 104 } 105 return sketcher.getCandidates(null); 106 } 107 108 private void verify(final int[] input, final int capacity, final int[] expected, 109 final float[] freqs) throws Exception { 110 final List<Integer> candidates = getCandidates(input, capacity); 111 final List<Pair<Integer, Float>> result = getTopHeavyHitters(input, capacity); 112 if (expected != null) { 113 assertTrue(candidates != null); 114 for (int i = 0; i < expected.length; i++) { 115 assertTrue(candidates.contains(expected[i])); 116 } 117 assertTrue(result != null); 118 assertEquals(expected.length, result.size()); 119 for (int i = 0; i < expected.length; i++) { 120 final Pair<Integer, Float> pair = result.get(i); 121 assertEquals(expected[i], pair.first.intValue()); 122 assertTrue(Math.abs(freqs[i] - pair.second) < EPSILON); 123 } 124 } else { 125 assertEquals(null, result); 126 } 127 } 128 129 private void verifyNotExpected(final int[] input, final int capacity, final int[] notExpected) 130 throws Exception { 131 final List<Pair<Integer, Float>> result = getTopHeavyHitters(input, capacity); 132 if (result != null) { 133 final ArraySet<Integer> set = new ArraySet<>(); 134 for (Pair<Integer, Float> p : result) { 135 set.add(p.first); 136 } 137 for (int i = 0; i < notExpected.length; i++) { 138 assertFalse(set.contains(notExpected[i])); 139 } 140 } 141 } 142 143 private int[] generateRandomInput(final int size, final int[] hitters) { 144 final Random random = new Random(); 145 final Random random2 = new Random(); 146 final int[] input = new int[size]; 147 // 80% of them would be hitters, 20% will be random numbers 148 final int numOfRandoms = size / 5; 149 final int numOfHitters = size - numOfRandoms; 150 for (int i = 0, j = 0, m = numOfRandoms, n = numOfHitters; i < size; i++) { 151 int r = m > 0 && n > 0 ? random2.nextInt(size) : (m > 0 ? 0 : numOfRandoms); 152 if (r < numOfRandoms) { 153 input[i] = random.nextInt(size); 154 m--; 155 } else { 156 input[i] = hitters[j++]; 157 if (j == hitters.length) { 158 j = 0; 159 } 160 n--; 161 } 162 } 163 return input; 164 } 165 166 public void testPositive() throws Exception { 167 // Simple case 168 verify(new int[]{2, 9, 9, 9, 7, 6, 4, 9, 9, 9, 3, 9}, 2, new int[]{9}, 169 new float[]{0.583333f}); 170 171 // Two heavy hitters 172 verify(new int[]{2, 3, 9, 3, 9, 3, 9, 7, 6, 4, 9, 9, 3, 9, 3, 9}, 3, new int[]{9, 3}, 173 new float[]{0.4375f, 0.3125f}); 174 175 // Create a random data set and insert some numbers 176 final int[] input = generateRandomInput(100, 177 new int[]{1001, 1002, 1002, 1003, 1003, 1003, 1004, 1004, 1004, 1004}); 178 verify(input, 12, new int[]{1004, 1003, 1002, 1001}, 179 new float[]{0.32f, 0.24f, 0.16f, 0.08f}); 180 } 181 182 public void testNegative() throws Exception { 183 // Simple case 184 verifyNotExpected(new int[]{2, 9, 9, 9, 7, 6, 4, 9, 9, 9, 3, 9}, 2, new int[]{0, 1, 2}); 185 186 // Two heavy hitters 187 verifyNotExpected(new int[]{2, 3, 9, 3, 9, 3, 9, 7, 6, 4, 9, 9, 3, 9, 3, 9}, 3, 188 new int[]{0, 1, 2}); 189 190 // Create a random data set and insert some numbers 191 final int[] input = generateRandomInput(100, 192 new int[]{1001, 1002, 1002, 1003, 1003, 1003, 1004, 1004, 1004, 1004}); 193 verifyNotExpected(input, 12, new int[]{0, 1, 2, 1000, 1005}); 194 } 195 196 public void testFalsePositive() throws Exception { 197 // Simple case 198 verifyNotExpected(new int[]{2, 9, 2, 2, 7, 6, 4, 9, 9, 9, 3, 9}, 2, new int[]{9}); 199 200 // One heavy hitter 201 verifyNotExpected(new int[]{2, 3, 9, 3, 9, 3, 9, 7, 6, 4, 9, 9, 3, 9, 2, 9}, 3, 202 new int[]{3}); 203 204 // Create a random data set and insert some numbers 205 final int[] input = generateRandomInput(100, 206 new int[]{1001, 1002, 1002, 1003, 1003, 1003, 1004, 1004, 1004, 1004}); 207 verifyNotExpected(input, 11, new int[]{1001}); 208 } 209 } 210