/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.net.module.util

import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.Condition
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.locks.StampedLock
import kotlin.concurrent.withLock

/**
 * A List that additionally offers the ability to append via the add() method, and to retrieve
 * an element by its index optionally waiting for it to become available.
 */
interface TrackRecord<E> : List<E> {
    /**
     * Adds an element to this record, waking up threads waiting for one. Returns true, as
     * per the contract for List.
     */
    fun add(e: E): Boolean

    /**
     * Returns the first element at or after [pos], possibly blocking until one is available, or
     * null if no such element can be found within the timeout.
     * If a predicate is given, only elements matching the predicate are returned.
     *
     * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
     * @param pos the position at which to start polling.
     * @param predicate an optional predicate to filter elements to be returned.
     * @return an element matching the predicate, or null if timeout.
     */
    fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean = { true }): E?
}

/**
 * A thread-safe implementation of TrackRecord that is backed by an ArrayList.
 *
 * This class also supports the creation of a read-head for easier single-thread access.
 * Refer to the documentation of [ArrayTrackRecord.ReadHead].
 */
class ArrayTrackRecord<E> : TrackRecord<E> {
    private val lock = ReentrantLock()
    // Signalled under [lock] every time an element is appended, to wake up pollers.
    private val condition: Condition = lock.newCondition()
    // Backing store. This stores the elements in this ArrayTrackRecord. It is only ever
    // appended to.
    private val elements = ArrayList<E>()

    /**
     * The list iterator for ArrayTrackRecord. It iterates over a snapshot of the collection:
     * [end] is fixed when the iterator is created, and elements appended afterwards are never
     * returned. Because an ArrayTrackRecord is only ever mutated by appending, the elements
     * in [0, end) never change while the iterator is in use.
     *
     * NOTE(review): next()/previous() read the backing ArrayList without holding the record's
     * lock, relying on the append-only property above. Confirm this is acceptable under the
     * Java memory model when other threads append concurrently.
     */
    class ArrayTrackRecordIterator<E>(
        private val list: ArrayList<E>,
        start: Int,
        private val end: Int
    ) : ListIterator<E> {
        var index = start

        override fun hasNext() = index < end

        override fun next(): E {
            // Enforce the snapshot bound and the Iterator contract, rather than returning
            // elements appended after creation or throwing IndexOutOfBoundsException.
            if (!hasNext()) throw NoSuchElementException()
            return list[index++]
        }

        override fun hasPrevious() = index > 0

        // Index of the element that the next call to next() would return.
        override fun nextIndex() = index

        override fun previous(): E {
            if (!hasPrevious()) throw NoSuchElementException()
            return list[--index]
        }

        override fun previousIndex() = index - 1
    }

    // List<E> implementation
    override val size: Int get() = lock.withLock { elements.size }
    override fun contains(element: E) = lock.withLock { elements.contains(element) }
    override fun containsAll(elements: Collection<E>) = lock.withLock {
        this.elements.containsAll(elements)
    }
    override operator fun get(index: Int) = lock.withLock { elements[index] }
    override fun indexOf(element: E): Int = lock.withLock { elements.indexOf(element) }
    override fun lastIndexOf(element: E): Int = lock.withLock { elements.lastIndexOf(element) }
    override fun isEmpty() = lock.withLock { elements.isEmpty() }
    override fun listIterator(index: Int) = ArrayTrackRecordIterator(elements, index, size)
    override fun listIterator() = listIterator(0)
    override fun iterator() = listIterator()

    /**
     * Returns a snapshot copy of the elements in [fromIndex, toIndex).
     *
     * A copy is returned rather than a view of the backing ArrayList. An ArrayList sub-list
     * view throws ConcurrentModificationException once the backing list is structurally
     * modified, which here would happen on the very next add(). It would also be read without
     * the lock. Because elements are only ever appended, the copy holds exactly the elements
     * the view would have held.
     */
    override fun subList(fromIndex: Int, toIndex: Int): List<E> = lock.withLock {
        elements.subList(fromIndex, toIndex).toList()
    }

    // TrackRecord<E> implementation
    override fun add(e: E): Boolean {
        lock.withLock {
            elements.add(e)
            condition.signalAll()
        }
        return true
    }

    override fun poll(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): E? = lock.withLock {
        // pollForIndexReadLocked returns -1 on timeout, for which getOrNull yields null.
        elements.getOrNull(pollForIndexReadLocked(timeoutMs, pos, predicate))
    }

    /**
     * Returns the first element at or after [pos] matching [predicate], without blocking.
     * Returns null if there is no such element, or if [pos] is outside [0, size].
     */
    fun getOrNull(pos: Int, predicate: (E) -> Boolean): E? = lock.withLock {
        if (pos < 0 || pos > elements.size) {
            null
        } else {
            elements.subList(pos, elements.size).find(predicate)
        }
    }

    /**
     * Returns the index of the first element whose position is >= [pos] and that matches
     * [predicate]. If necessary, waits until such an element is appended, for at most
     * [timeoutMs] milliseconds. Returns -1 if no such element is found within the timeout.
     *
     * Must be called with [lock] held. The lock is released while waiting on [condition].
     */
    private fun pollForIndexReadLocked(timeoutMs: Long, pos: Int, predicate: (E) -> Boolean): Int {
        // Track the remaining time with awaitNanos, which uses a monotonic clock, so wall-clock
        // adjustments cannot shorten or lengthen the wait. toNanos saturates instead of
        // overflowing for very large timeouts.
        var remainingNanos = TimeUnit.MILLISECONDS.toNanos(timeoutMs)
        var index = pos
        do {
            while (index < elements.size) {
                if (predicate(elements[index])) return index
                ++index
            }
            // Stop only after one last scan following the deadline, so that an element
            // appended just as the wait expired is still found.
            val keepWaiting = remainingNanos > 0L
            if (keepWaiting) remainingNanos = condition.awaitNanos(remainingNanos)
        } while (keepWaiting)
        return -1
    }

    /**
     * Returns a new ReadHead over this ArrayTrackRecord, with its mark at position 0.
     * A ReadHead is meant to be used by one thread at a time; see [ReadHead] for the exact
     * threading rules.
     */
    fun newReadHead() = ReadHead()

    /**
     * ReadHead is an object that helps users of ArrayTrackRecord keep track of how far
     * they have read in the ArrayTrackRecord. A ReadHead is always associated with
     * a single instance of ArrayTrackRecord. Multiple ReadHeads can be created and used
     * on the same instance of ArrayTrackRecord concurrently, and the ArrayTrackRecord
     * instance can also be used concurrently. ReadHead maintains the current index that is
     * the next to be read, and calls this the "mark".
     *
     * In a ReadHead, poll(timeoutMs, predicate) works similarly to a LinkedBlockingQueue.
     * It can be called repeatedly and will return the elements as they arrive.
     *
     * Intended usage looks something like this :
     *     val record = ArrayTrackRecord<MyObject>().newReadHead()
     *     thread {
     *         // do stuff
     *         record.add(something)
     *         // do stuff
     *     }
     *
     *     val obj1 = record.poll(timeout)
     *     // do something with obj1
     *     val obj2 = record.poll(timeout)
     *     // do something with obj2
     *
     * The point is that the caller does not have to track the mark like it would have to if
     * it was using ArrayTrackRecord directly.
     *
     * Thread safety :
     * A ReadHead delegates all TrackRecord methods to its associated ArrayTrackRecord, and
     * inherits its thread-safe properties for all the TrackRecord methods.
     *
     * poll() operates under its own set of rules that only allow execution on multiple threads
     * within constrained boundaries, and never concurrently or pseudo-concurrently. This is
     * because concurrent calls to poll() fundamentally do not make sense. poll() moves the mark
     * according to which events remain to be read by this read head. If multiple threads were
     * calling poll() concurrently on the same ReadHead, neither the mark nor the return values
     * would be useful. There is no way to guarantee that no objects are skipped, nor to
     * guarantee the mark position when poll() returns. This is even more true with a predicate
     * filtering the returned elements, because one thread might filter out the events the other
     * is interested in. For this reason, this class fails fast with
     * ConcurrentModificationException if any concurrent access is detected.
     * It is possible to use poll() on different threads as long as the following can be
     * guaranteed : one thread must call poll() for the last time, then execute a write barrier,
     * then the other thread must execute a read barrier before calling poll() for the first time.
     * This allows in particular to call poll in @Before and @After methods in JUnit unit tests,
     * because JUnit will enforce those barriers by creating the testing thread after executing
     * @Before and joining the thread after executing @After.
     *
     * peek() can be used by multiple threads concurrently, but only if no thread is calling
     * poll() outside of the boundaries above. For simplicity, it can be considered that peek()
     * is safe to call only when poll() is safe to call.
     *
     * Polling concurrently from the same ArrayTrackRecord is supported by creating multiple
     * ReadHeads on the same instance of ArrayTrackRecord (or of course by using ArrayTrackRecord
     * directly). Each ReadHead is then guaranteed to always see all events, and guarantees are
     * made on the value of the mark upon return. See poll(timeoutMs, predicate) for details.
     * Be careful to create each ReadHead on the thread it is meant to be used on, or to have a
     * clear synchronization point between creation and use.
     *
     * Users of a ReadHead can ask for the current position of the mark at any time, on a thread
     * where it's safe to call peek(). This mark can be used later to replay the history of events
     * either on this ReadHead, on the associated ArrayTrackRecord or on another ReadHead
     * associated with the same ArrayTrackRecord. It might look like this in the reader thread :
     *
     *     val markAtStart = record.mark
     *     // Start processing interesting events
     *     while (true) {
     *         val element = record.poll(timeout) { it.isInteresting() } ?: break
     *         // Do something with element
     *     }
     *     // Look for stuff that happened while searching for interesting events
     *     val firstElementReceived = record.getOrNull(markAtStart)
     *     // Get the first special element since markAtStart, without blocking
     *     val firstSpecialElement = record.poll(0, markAtStart) { it.isSpecial() }
     *     // Get the first special element since markAtStart, possibly blocking until one is available
     *     val specialElement = record.poll(timeout, markAtStart) { it.isSpecial() }
     */
    inner class ReadHead : TrackRecord<E> by this@ArrayTrackRecord {
        // This lock only controls access to the readHead member below. The ArrayTrackRecord
        // object has its own synchronization following different (and more usual) semantics.
        // See the comment on the ReadHead class for details.
        private val slock = StampedLock()
        private var readHead = 0

        /**
         * The current value of the mark, i.e. the index of the next element to be read.
         * Reading it is subject to the same rules as peek(). Setting it is equivalent to
         * calling [rewind].
         */
        var mark: Int
            get() = checkThread { readHead }
            set(v: Int) = rewind(v)

        /**
         * Moves the mark to [v]. Throws ConcurrentModificationException if the write lock of
         * this ReadHead is held at the time of the call, e.g. by a poll() on another thread.
         */
        fun rewind(v: Int) {
            val stamp = slock.tryWriteLock()
            if (0L == stamp) concurrentAccessDetected()
            readHead = v
            slock.unlockWrite(stamp)
        }

        private fun <T> checkThread(r: (Long) -> T): T {
            // tryOptimisticRead is a read barrier, guarantees writes from other threads are visible
            // after it
            val stamp = slock.tryOptimisticRead()
            val result = r(stamp)
            // validate also performs a read barrier, guaranteeing that if validate returns true,
            // then any change either happens-before tryOptimisticRead, or happens-after validate.
            if (!slock.validate(stamp)) concurrentAccessDetected()
            return result
        }

        private fun concurrentAccessDetected(): Nothing {
            throw ConcurrentModificationException(
                    "ReadHeads can't be used concurrently. Check your threading model.")
        }

        /**
         * Returns the first element at or after the mark, optionally blocking until one is
         * available, or null if no such element can be found within the timeout.
         * If a predicate is given, only elements matching the predicate are returned.
         *
         * Upon return the mark will be set to immediately after the returned element, or after
         * the last element in the record if null is returned. This means this method will always
         * skip elements that do not match the predicate, even if it returns null.
         *
         * This method is subject to the threading restrictions described in the ReadHead class
         * comment. If concurrent use of this ReadHead is detected, it throws
         * ConcurrentModificationException.
         *
         * @param timeoutMs how long, in milliseconds, to wait at most (best effort approximation).
         * @param predicate an optional predicate to filter elements to be returned.
         * @return an element matching the predicate, or null if timeout.
         */
        fun poll(timeoutMs: Long, predicate: (E) -> Boolean = { true }): E? {
            val stamp = slock.tryWriteLock()
            if (0L == stamp) concurrentAccessDetected()
            try {
                lock.withLock {
                    val index = pollForIndexReadLocked(timeoutMs, readHead, predicate)
                    readHead = if (index < 0) elements.size else index + 1
                    return elements.getOrNull(index)
                }
            } finally {
                slock.unlockWrite(stamp)
            }
        }

        /**
         * Returns the first element after the mark or null. This never blocks.
         *
         * This method is subject to threading restrictions. It can be used concurrently on
         * multiple threads but not if any other thread might be executing poll() at the same
         * time. See the class comment for details.
         */
        fun peek(): E? = checkThread { getOrNull(readHead) }
    }
}