Move utilities to libs/net
These files used to be in the network stack directory, but the libs
directory is a much more suitable place for them.
Also fix a typo : ConcurrentIntepreter → ConcurrentInterpreter
Also move {FdEvents,Packet}Reader to internal annotations. That's
what they should have been using in the first place anyway.
Note that this does not fix preupload issues reported by
checkstyle to make review easier. The fixes are in a followup
patch to this one.
Test: checkbuild
Change-Id: I675077fd42cbb092c0e6bd56571f2fc022e582fd
diff --git a/staticlibs/devicetests/com/android/testutils/ConcurrentInterpreter.kt b/staticlibs/devicetests/com/android/testutils/ConcurrentInterpreter.kt
new file mode 100644
index 0000000..98464eb
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/ConcurrentInterpreter.kt
@@ -0,0 +1,188 @@
+package com.android.testutils
+
+import android.os.SystemClock
+import java.util.concurrent.CyclicBarrier
+import kotlin.system.measureTimeMillis
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+
+// The table contains pairs associating a regexp with the code to run. The statement is matched
+// against each matcher in sequence and when a match is found the associated code is run, passing
+// it the interpreter, the object under test and the result of the regexp match.
+typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentInterpreter<T>, T, MatchResult) -> Any?>
+
+// The default unit of time for interpreted tests
+val INTERPRET_TIME_UNIT = 40L // ms
+
+/**
+ * A small interpreter for testing parallel code. The interpreter will read a list of lines
+ * consisting of "|"-separated statements. Each column runs in a different concurrent thread
+ * and all threads wait for each other in between lines. Each statement is split on ";" then
+ * matched with regular expressions in the interpretation table, which contains the
+ * code associated with each statement. The interpreter supports an object being passed to
+ * the interpretTestSpec() method to be passed in each lambda (think about the object under
+ * test), and an optional transform function to be executed on the object at the start of
+ * every thread.
+ *
+ * The time unit is defined in milliseconds by the interpretTimeUnit member, which has a default
+ * value but can be passed to the constructor. Whitespace is ignored.
+ *
+ * The interpretation table has to be passed as an argument. It's a table associating a regexp
+ * with the code that should execute, as a function taking three arguments : the interpreter,
+ * the object, and the regexp match. See the individual tests for the DSL of that test.
+ * Implementors for new interpreting languages are encouraged to look at the
+ * getDefaultInstructions() function below for an example of how to write an interpreting table.
+ * Some expressions already exist by default and can be used by all interpreters. They include :
+ * sleep(x) : sleeps for x time units and returns Unit ; sleep alone means sleep(1)
+ * EXPR = VALUE : asserts that EXPR equals VALUE. EXPR is interpreted. VALUE can either be the
+ * string "null" or an int. Returns Unit.
+ * EXPR time x..y : measures the time taken by EXPR and asserts it took at least x and at most
+ * y time units.
+ * EXPR // any text : comments are ignored.
+ */
+open class ConcurrentInterpreter<T>(
+ localInterpretTable: List<InterpretMatcher<T>>,
+ val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
+) {
+ private val interpretTable: List<InterpretMatcher<T>> =
+ localInterpretTable + getDefaultInstructions()
+
+ // Split the line into multiple statements separated by ";" and execute them. Return whatever
+ // the last statement returned.
+ fun interpretMultiple(instr: String, r: T): Any? {
+ return instr.split(";").map { interpret(it.trim(), r) }.last()
+ }
+
+ // Match the statement to a regex and interpret it.
+ fun interpret(instr: String, r: T): Any? {
+ val (matcher, code) =
+ interpretTable.find { instr matches it.first } ?: throw SyntaxException(instr)
+ val match = matcher.matchEntire(instr) ?: throw SyntaxException(instr)
+ return code(this, r, match)
+ }
+
+ // Spins as many threads as needed by the test spec and interpret each program concurrently,
+ // having all threads waiting on a CyclicBarrier after each line.
+ // |lineShift| says how many lines after the call the spec starts. This is used for error
+ // reporting. Unfortunately AFAICT there is no way to get the line of an argument rather
+ // than the line at which the expression starts.
+ fun interpretTestSpec(
+ spec: String,
+ initial: T,
+ lineShift: Int = 0,
+ threadTransform: (T) -> T = { it }
+ ) {
+ // For nice stack traces
+ val callSite = getCallingMethod()
+ val lines = spec.trim().trim('\n').split("\n").map { it.split("|") }
+        // |threadInstructions| contains the statements of each thread : in other words, it's
+        // a list that contains a list of statements for each column in the spec.
+ val threadCount = lines[0].size
+ assertTrue(lines.all { it.size == threadCount })
+ val threadInstructions = (0 until threadCount).map { i -> lines.map { it[i].trim() } }
+ val barrier = CyclicBarrier(threadCount)
+ var crash: InterpretException? = null
+ threadInstructions.mapIndexed { threadIndex, instructions ->
+ Thread {
+ val threadLocal = threadTransform(initial)
+ barrier.await()
+ var lineNum = 0
+ instructions.forEach {
+ if (null != crash) return@Thread
+ lineNum += 1
+ try {
+ interpretMultiple(it, threadLocal)
+ } catch (e: Throwable) {
+ // If fail() or some exception was called, the thread will come here ; if
+ // the exception isn't caught the process will crash, which is not nice for
+ // testing. Instead, catch the exception, cancel other threads, and report
+ // nicely. Catch throwable because fail() is AssertionError, which inherits
+ // from Error.
+ crash = InterpretException(threadIndex, it,
+ callSite.lineNumber + lineNum + lineShift,
+ callSite.className, callSite.methodName, callSite.fileName, e)
+ }
+ barrier.await()
+ }
+ }.also { it.start() }
+ }.forEach { it.join() }
+ // If the test failed, crash with line number
+ crash?.let { throw it }
+ }
+
+ // Helper to get the stack trace for a calling method
+ private fun getCallingStackTrace(): Array<StackTraceElement> {
+ try {
+ throw RuntimeException()
+ } catch (e: RuntimeException) {
+ return e.stackTrace
+ }
+ }
+
+ // Find the calling method. This is the first method in the stack trace that is annotated
+ // with @Test.
+ fun getCallingMethod(): StackTraceElement {
+ val stackTrace = getCallingStackTrace()
+ return stackTrace.find { element ->
+ val clazz = Class.forName(element.className)
+ // Because the stack trace doesn't list the formal arguments, find all methods with
+ // this name and return this name if any of them is annotated with @Test.
+ clazz.declaredMethods
+ .filter { method -> method.name == element.methodName }
+ .any { method -> method.getAnnotation(org.junit.Test::class.java) != null }
+ } ?: stackTrace[3]
+ // If no method is annotated return the 4th one, because that's what it usually is :
+ // 0 is getCallingStackTrace, 1 is this method, 2 is ConcurrentInterpreter#interpretTestSpec
+ }
+}
+
+private fun <T> getDefaultInstructions() = listOf<InterpretMatcher<T>>(
+    // Interpret an empty line as doing nothing.
+    Regex("") to { _, _, _ -> null },
+    // Ignore comments.
+    Regex("(.*)//.*") to { i, t, r -> i.interpret(r.strArg(1), t) },
+    // Interpret "XXX time x..y" : run XXX, check it took x to y of the interpreter's time units
+    Regex("""(.*)\s*time\s*(\d+)\.\.(\d+)""") to { i, t, r ->
+        val range = i.interpretTimeUnit * r.intArg(2)..i.interpretTimeUnit * r.intArg(3)
+        val time = measureTimeMillis { i.interpret(r.strArg(1), t) }
+        assertTrue(time in range, "$time not in $range")
+    },
+    // Interpret "XXX = YYY" : run XXX and assert its return value is equal to YYY. "null" supported
+    Regex("""(.*)\s*=\s*(null|\d+)""") to { i, t, r ->
+        i.interpret(r.strArg(1), t).also {
+            if ("null" == r.strArg(2)) assertNull(it) else assertEquals(r.intArg(2), it)
+        }
+    },
+    // Interpret sleep. Optional argument for the count, in interpretTimeUnit units (default 1).
+    Regex("""sleep(\((\d+)\))?""") to { i, _, r ->
+        SystemClock.sleep(i.interpretTimeUnit * (if (r.strArg(2).isEmpty()) 1 else r.intArg(2)))
+    },
+    // Interpret "XXX fails" : run XXX and assert that it throws.
+    Regex("""(.*)\s*fails""") to { i, t, r -> assertFails { i.interpret(r.strArg(1), t) } }
+)
+
+class SyntaxException(msg: String, cause: Throwable? = null) : RuntimeException(msg, cause)
+class InterpretException(
+ threadIndex: Int,
+ instr: String,
+ lineNum: Int,
+ className: String,
+ methodName: String,
+ fileName: String,
+ cause: Throwable
+) : RuntimeException("Failure: $instr", cause) {
+ init {
+ stackTrace = arrayOf(StackTraceElement(
+ className,
+ "$methodName:thread$threadIndex",
+ fileName,
+ lineNum)) + super.getStackTrace()
+ }
+}
+
+// Some small helpers to avoid to say the large ".groupValues[index].trim()" every time
+fun MatchResult.strArg(index: Int) = this.groupValues[index].trim()
+fun MatchResult.intArg(index: Int) = strArg(index).toInt()
+fun MatchResult.timeArg(index: Int) = INTERPRET_TIME_UNIT * intArg(index)
diff --git a/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRule.kt b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRule.kt
new file mode 100644
index 0000000..4a83f6f
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRule.kt
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.os.Build
+import org.junit.Assume.assumeTrue
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+
+/**
+ * Returns true if the development SDK version of the device is in the provided range.
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher than
+ * [Build.VERSION.SDK_INT].
+ */
+fun isDevSdkInRange(minExclusive: Int?, maxInclusive: Int?): Boolean {
+ // In-development API n+1 will have SDK_INT == n and CODENAME != REL.
+ // Stable API n has SDK_INT == n and CODENAME == REL.
+ val release = "REL" == Build.VERSION.CODENAME
+ val sdkInt = Build.VERSION.SDK_INT
+ val devApiLevel = sdkInt + if (release) 0 else 1
+
+ return (minExclusive == null || devApiLevel > minExclusive) &&
+ (maxInclusive == null || devApiLevel <= maxInclusive)
+}
+
+/**
+ * A test rule to ignore tests based on the development SDK level.
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher than
+ * [Build.VERSION.SDK_INT].
+ *
+ * @param ignoreClassUpTo Skip all tests in the class if the device dev SDK is <= this value.
+ * @param ignoreClassAfter Skip all tests in the class if the device dev SDK is > this value.
+ */
+class DevSdkIgnoreRule @JvmOverloads constructor(
+ private val ignoreClassUpTo: Int? = null,
+ private val ignoreClassAfter: Int? = null
+) : TestRule {
+ override fun apply(base: Statement, description: Description): Statement {
+ return IgnoreBySdkStatement(base, description)
+ }
+
+ /**
+ * Ignore the test for any development SDK that is strictly after [value].
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher
+ * than [Build.VERSION.SDK_INT].
+ */
+ annotation class IgnoreAfter(val value: Int)
+
+ /**
+     * Ignore the test for any development SDK that is lower than or equal to [value].
+ *
+ * If the device is not using a release SDK, the development SDK is considered to be higher
+ * than [Build.VERSION.SDK_INT].
+ */
+ annotation class IgnoreUpTo(val value: Int)
+
+ private inner class IgnoreBySdkStatement(
+ private val base: Statement,
+ private val description: Description
+ ) : Statement() {
+ override fun evaluate() {
+ val ignoreAfter = description.getAnnotation(IgnoreAfter::class.java)
+ val ignoreUpTo = description.getAnnotation(IgnoreUpTo::class.java)
+
+ val message = "Skipping test for build ${Build.VERSION.CODENAME} " +
+ "with SDK ${Build.VERSION.SDK_INT}"
+ assumeTrue(message, isDevSdkInRange(ignoreClassUpTo, ignoreClassAfter))
+ assumeTrue(message, isDevSdkInRange(ignoreUpTo?.value, ignoreAfter?.value))
+ base.evaluate()
+ }
+ }
+}
\ No newline at end of file
diff --git a/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
new file mode 100644
index 0000000..73b2843
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/DevSdkIgnoreRunner.kt
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.android.testutils.DevSdkIgnoreRule.IgnoreAfter
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import org.junit.runner.Description
+import org.junit.runner.Runner
+import org.junit.runner.notification.RunNotifier
+
+/**
+ * A runner that can skip tests based on the development SDK as defined in [DevSdkIgnoreRule].
+ *
+ * Generally [DevSdkIgnoreRule] should be used for that purpose (using rules is preferable over
+ * replacing the test runner), however JUnit runners inspect all methods in the test class before
+ * processing test rules. This may cause issues if the test methods are referencing classes that do
+ * not exist on the SDK of the device the test is run on.
+ *
+ * This runner inspects [IgnoreAfter] and [IgnoreUpTo] annotations on the test class, and will skip
+ * the whole class if they do not match the development SDK as defined in [DevSdkIgnoreRule].
+ * Otherwise, it will delegate to [AndroidJUnit4] to run the test as usual.
+ *
+ * Example usage:
+ *
+ * @RunWith(DevSdkIgnoreRunner::class)
+ * @IgnoreUpTo(Build.VERSION_CODES.Q)
+ * class MyTestClass { ... }
+ */
+class DevSdkIgnoreRunner(private val klass: Class<*>) : Runner() {
+ private val baseRunner = klass.let {
+ val ignoreAfter = it.getAnnotation(IgnoreAfter::class.java)
+ val ignoreUpTo = it.getAnnotation(IgnoreUpTo::class.java)
+
+ if (isDevSdkInRange(ignoreUpTo?.value, ignoreAfter?.value)) AndroidJUnit4(klass) else null
+ }
+
+ override fun run(notifier: RunNotifier) {
+ if (baseRunner != null) {
+ baseRunner.run(notifier)
+ return
+ }
+
+ // Report a single, skipped placeholder test for this class, so that the class is still
+ // visible as skipped in test results.
+ notifier.fireTestIgnored(
+ Description.createTestDescription(klass, "skippedClassForDevSdkMismatch"))
+ }
+
+ override fun getDescription(): Description {
+ return baseRunner?.description ?: Description.createSuiteDescription(klass)
+ }
+
+ override fun testCount(): Int {
+ // When ignoring the tests, a skipped placeholder test is reported, so test count is 1.
+ return baseRunner?.testCount() ?: 1
+ }
+}
\ No newline at end of file
diff --git a/staticlibs/devicetests/com/android/testutils/FakeDns.kt b/staticlibs/devicetests/com/android/testutils/FakeDns.kt
new file mode 100644
index 0000000..825d748
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/FakeDns.kt
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.DnsResolver
+import android.net.InetAddresses
+import android.os.Looper
+import android.os.Handler
+import com.android.internal.annotations.GuardedBy
+import java.net.InetAddress
+import java.util.concurrent.Executor
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
+
+const val TYPE_UNSPECIFIED = -1
+// TODO: Integrate with NetworkMonitorTest.
+class FakeDns(val mockResolver: DnsResolver) {
+    class DnsEntry(val hostname: String, val type: Int, val addresses: List<InetAddress>) {
+        fun match(host: String, type: Int) = hostname == host && this.type == type
+    }
+
+ @GuardedBy("answers")
+ val answers = ArrayList<DnsEntry>()
+
+ fun getAnswer(hostname: String, type: Int): DnsEntry? = synchronized(answers) {
+ return answers.firstOrNull { it.match(hostname, type) }
+ }
+
+ fun setAnswer(hostname: String, answer: Array<String>, type: Int) = synchronized(answers) {
+ val ans = DnsEntry(hostname, type, generateAnswer(answer))
+        // Replace the existing entry if there is one, otherwise add a new one.
+ when (val index = answers.indexOfFirst { it.match(hostname, type) }) {
+ -1 -> answers.add(ans)
+ else -> answers[index] = ans
+ }
+ }
+
+ private fun generateAnswer(answer: Array<String>) =
+ answer.filterNotNull().map { InetAddresses.parseNumericAddress(it) }
+
+ fun startMocking() {
+ // Mock DnsResolver.query() w/o type
+ doAnswer {
+ mockAnswer(it, 1, -1, 3, 5)
+ }.`when`(mockResolver).query(any() /* network */, any() /* domain */, anyInt() /* flags */,
+ any() /* executor */, any() /* cancellationSignal */, any() /*callback*/)
+ // Mock DnsResolver.query() w/ type
+ doAnswer {
+ mockAnswer(it, 1, 2, 4, 6)
+ }.`when`(mockResolver).query(any() /* network */, any() /* domain */, anyInt() /* nsType */,
+ anyInt() /* flags */, any() /* executor */, any() /* cancellationSignal */,
+ any() /*callback*/)
+ }
+
+ private fun mockAnswer(
+ it: InvocationOnMock,
+ posHos: Int,
+ posType: Int,
+ posExecutor: Int,
+ posCallback: Int
+ ) {
+ val hostname = it.arguments[posHos] as String
+ val executor = it.arguments[posExecutor] as Executor
+ val callback = it.arguments[posCallback] as DnsResolver.Callback<List<InetAddress>>
+ var type = if (posType != -1) it.arguments[posType] as Int else TYPE_UNSPECIFIED
+ val answer = getAnswer(hostname, type)
+
+ if (!answer?.addresses.isNullOrEmpty()) {
+ Handler(Looper.getMainLooper()).post({ executor.execute({
+ callback.onAnswer(answer?.addresses, 0); }) })
+ }
+ }
+
+ /** Clears all entries. */
+ fun clearAll() = synchronized(answers) {
+ answers.clear()
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/HandlerUtils.kt b/staticlibs/devicetests/com/android/testutils/HandlerUtils.kt
new file mode 100644
index 0000000..fa36209
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/HandlerUtils.kt
@@ -0,0 +1,49 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.os.ConditionVariable
+import android.os.Handler
+import android.os.HandlerThread
+import java.util.concurrent.Executor
+import kotlin.system.measureTimeMillis
+import kotlin.test.fail
+
+/**
+ * Block until the specified Handler or HandlerThread becomes idle, or until timeoutMs has passed.
+ */
+fun HandlerThread.waitForIdle(timeoutMs: Int) = threadHandler.waitForIdle(timeoutMs.toLong())
+fun HandlerThread.waitForIdle(timeoutMs: Long) = threadHandler.waitForIdle(timeoutMs)
+fun Handler.waitForIdle(timeoutMs: Int) = waitForIdle(timeoutMs.toLong())
+fun Handler.waitForIdle(timeoutMs: Long) {
+ val cv = ConditionVariable(false)
+ post(cv::open)
+ if (!cv.block(timeoutMs)) {
+ fail("Handler did not become idle after ${timeoutMs}ms")
+ }
+}
+
+/**
+ * Block until the given Serial Executor becomes idle, or until timeoutMs has passed.
+ */
+fun waitForIdleSerialExecutor(executor: Executor, timeoutMs: Long) {
+ val cv = ConditionVariable()
+ executor.execute(cv::open)
+ if (!cv.block(timeoutMs)) {
+ fail("Executor did not become idle after ${timeoutMs}ms")
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/NetworkStatsUtils.kt b/staticlibs/devicetests/com/android/testutils/NetworkStatsUtils.kt
new file mode 100644
index 0000000..8324b25
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/NetworkStatsUtils.kt
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.NetworkStats
+import kotlin.test.assertTrue
+
+@JvmOverloads
+fun orderInsensitiveEquals(
+ leftStats: NetworkStats,
+ rightStats: NetworkStats,
+ compareTime: Boolean = false
+): Boolean {
+ if (leftStats == rightStats) return true
+ if (compareTime && leftStats.getElapsedRealtime() != rightStats.getElapsedRealtime()) {
+ return false
+ }
+
+    // Operations such as add/subtract preserve empty entries, which makes the result hard to
+    // verify in tests. Remove them before comparing since they do not really affect
+    // correctness.
+ // TODO (b/152827872): Remove empty entries after addition/subtraction.
+ val leftTrimmedEmpty = leftStats.removeEmptyEntries()
+ val rightTrimmedEmpty = rightStats.removeEmptyEntries()
+
+ if (leftTrimmedEmpty.size() != rightTrimmedEmpty.size()) return false
+ val left = NetworkStats.Entry()
+ val right = NetworkStats.Entry()
+ // Order insensitive compare.
+ for (i in 0 until leftTrimmedEmpty.size()) {
+ leftTrimmedEmpty.getValues(i, left)
+ val j: Int = rightTrimmedEmpty.findIndexHinted(left.iface, left.uid, left.set, left.tag,
+ left.metered, left.roaming, left.defaultNetwork, i)
+ if (j == -1) return false
+ rightTrimmedEmpty.getValues(j, right)
+ if (left != right) return false
+ }
+ return true
+}
+
+/**
+ * Assert that two {@link NetworkStats} are equal, assuming the order of the records is not
+ * necessarily the same.
+ *
+ * @note {@code elapsedRealtime} is not compared by default, given that in test cases that is not
+ * usually used.
+ */
+@JvmOverloads
+fun assertNetworkStatsEquals(
+ expected: NetworkStats,
+ actual: NetworkStats,
+ compareTime: Boolean = false
+) {
+ assertTrue(orderInsensitiveEquals(expected, actual, compareTime),
+ "expected: " + expected + " but was: " + actual)
+}
+
+/**
+ * Assert that after being parceled then unparceled, {@link NetworkStats} is equal to the original
+ * object.
+ */
+fun assertParcelingIsLossless(stats: NetworkStats) {
+ assertParcelingIsLossless(stats, { a, b -> orderInsensitiveEquals(a, b) })
+}
diff --git a/staticlibs/devicetests/com/android/testutils/ParcelUtils.kt b/staticlibs/devicetests/com/android/testutils/ParcelUtils.kt
new file mode 100644
index 0000000..5784f7c
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/ParcelUtils.kt
@@ -0,0 +1,68 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.os.Parcel
+import android.os.Parcelable
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+/**
+ * Return a new instance of `T` after being parceled then unparceled.
+ */
+fun <T : Parcelable> parcelingRoundTrip(source: T): T {
+ val creator: Parcelable.Creator<T>
+ try {
+ creator = source.javaClass.getField("CREATOR").get(null) as Parcelable.Creator<T>
+ } catch (e: IllegalAccessException) {
+ fail("Missing CREATOR field: " + e.message)
+ } catch (e: NoSuchFieldException) {
+ fail("Missing CREATOR field: " + e.message)
+ }
+
+ var p = Parcel.obtain()
+ source.writeToParcel(p, /* flags */ 0)
+ p.setDataPosition(0)
+ val marshalled = p.marshall()
+ p = Parcel.obtain()
+ p.unmarshall(marshalled, 0, marshalled.size)
+ p.setDataPosition(0)
+ return creator.createFromParcel(p)
+}
+
+/**
+ * Assert that after being parceled then unparceled, `source` is equal to the original
+ * object. If a customized equals function is provided, uses the provided one.
+ */
+@JvmOverloads
+fun <T : Parcelable> assertParcelingIsLossless(
+ source: T,
+ equals: (T, T) -> Boolean = { a, b -> a == b }
+) {
+ val actual = parcelingRoundTrip(source)
+ assertTrue(equals(source, actual), "Expected $source, but was $actual")
+}
+
+@JvmOverloads
+fun <T : Parcelable> assertParcelSane(
+ obj: T,
+ fieldCount: Int,
+ equals: (T, T) -> Boolean = { a, b -> a == b }
+) {
+ assertFieldCountEquals(fieldCount, obj::class.java)
+ assertParcelingIsLossless(obj, equals)
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TapPacketReader.java b/staticlibs/devicetests/com/android/testutils/TapPacketReader.java
new file mode 100644
index 0000000..fea1cd9
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TapPacketReader.java
@@ -0,0 +1,93 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils;
+
+import android.net.util.PacketReader;
+import android.os.Handler;
+
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+
+import com.android.net.module.util.ArrayTrackRecord;
+
+import java.io.FileDescriptor;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.function.Predicate;
+
+import kotlin.Lazy;
+import kotlin.LazyKt;
+
+public class TapPacketReader extends PacketReader {
+ private final FileDescriptor mTapFd;
+ private final ArrayTrackRecord<byte[]> mReceivedPackets = new ArrayTrackRecord<>();
+ private final Lazy<ArrayTrackRecord<byte[]>.ReadHead> mReadHead =
+ LazyKt.lazy(mReceivedPackets::newReadHead);
+
+ public TapPacketReader(Handler h, FileDescriptor tapFd, int maxPacketSize) {
+ super(h, maxPacketSize);
+ mTapFd = tapFd;
+ }
+
+ @Override
+ protected FileDescriptor createFd() {
+ return mTapFd;
+ }
+
+ @Override
+ protected void handlePacket(byte[] recvbuf, int length) {
+ final byte[] newPacket = Arrays.copyOf(recvbuf, length);
+ if (!mReceivedPackets.add(newPacket)) {
+ throw new AssertionError("More than " + Integer.MAX_VALUE + " packets outstanding!");
+ }
+ }
+
+ /**
+ * Get the next packet that was received on the interface.
+ */
+ @Nullable
+ public byte[] popPacket(long timeoutMs) {
+ return mReadHead.getValue().poll(timeoutMs, packet -> true);
+ }
+
+ /**
+ * Get the next packet that was received on the interface and matches the specified filter.
+ */
+ @Nullable
+ public byte[] popPacket(long timeoutMs, @NonNull Predicate<byte[]> filter) {
+ return mReadHead.getValue().poll(timeoutMs, filter::test);
+ }
+
+ /**
+ * Get the {@link ArrayTrackRecord} that records all packets received by the reader since its
+ * creation.
+ */
+ public ArrayTrackRecord<byte[]> getReceivedPackets() {
+ return mReceivedPackets;
+ }
+
+ public void sendResponse(final ByteBuffer packet) throws IOException {
+ try (FileOutputStream out = new FileOutputStream(mTapFd)) {
+ byte[] packetBytes = new byte[packet.limit()];
+ packet.get(packetBytes);
+ packet.flip(); // So we can reuse it in the future.
+ out.write(packetBytes);
+ }
+ }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestNetworkTracker.kt b/staticlibs/devicetests/com/android/testutils/TestNetworkTracker.kt
new file mode 100644
index 0000000..4bd9ae8
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestNetworkTracker.kt
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.content.Context
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.LinkAddress
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.net.StringNetworkSpecifier
+import android.net.TestNetworkInterface
+import android.net.TestNetworkManager
+import android.os.Binder
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.TimeUnit
+
+/**
+ * Builds a TUN interface carrying [interfaceAddr] and wraps it in a [TestNetworkTracker].
+ *
+ * This call blocks until the test network is available. The caller must hold
+ * [android.Manifest.permission.CHANGE_NETWORK_STATE] and
+ * [android.Manifest.permission.MANAGE_TEST_NETWORKS].
+ *
+ * @param context used to look up [TestNetworkManager] and handed on to the tracker.
+ * @param interfaceAddr the single address given to the new TUN interface.
+ * @param setupTimeoutMs handed on to the tracker as its setup timeout, in milliseconds.
+ */
+fun initTestNetwork(
+    context: Context,
+    interfaceAddr: LinkAddress,
+    setupTimeoutMs: Long = 10_000L
+): TestNetworkTracker = context.getSystemService(TestNetworkManager::class.java).let { manager ->
+    val tunInterface = manager.createTunInterface(arrayOf(interfaceAddr))
+    TestNetworkTracker(context, tunInterface, manager, setupTimeoutMs)
+}
+
+/**
+ * Utility class to create and track test networks.
+ *
+ * Constructing an instance files a request for a test-transport network matching [iface], sets
+ * the test network up on that interface, and blocks until `onAvailable` fires or the setup
+ * timeout elapses.
+ *
+ * This class is not thread-safe.
+ *
+ * @property context used to obtain [ConnectivityManager].
+ * @property iface the test interface the network is set up on.
+ * @property network the network delivered by `onAvailable` during construction.
+ * @param tnm used to set up the test network.
+ * @param setupTimeoutMs how long to wait for the network, in milliseconds.
+ */
+class TestNetworkTracker internal constructor(
+    val context: Context,
+    val iface: TestNetworkInterface,
+    tnm: TestNetworkManager,
+    setupTimeoutMs: Long
+) {
+    private val cm = context.getSystemService(ConnectivityManager::class.java)
+    private val binder = Binder()
+
+    private val networkCallback: NetworkCallback
+    val network: Network
+
+    init {
+        // Completed by onAvailable; the constructor waits on it below.
+        val available = CompletableFuture<Network>()
+        networkCallback = object : NetworkCallback() {
+            override fun onAvailable(network: Network) {
+                available.complete(network)
+            }
+        }
+
+        val request = NetworkRequest.Builder().apply {
+            addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+            // Test networks do not have NOT_VPN or TRUSTED capabilities by default
+            removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)
+            removeCapability(NetworkCapabilities.NET_CAPABILITY_TRUSTED)
+            setNetworkSpecifier(StringNetworkSpecifier(iface.interfaceName))
+        }.build()
+        cm.requestNetwork(request, networkCallback)
+
+        network = try {
+            tnm.setupTestNetwork(iface.interfaceName, binder)
+            available.get(setupTimeoutMs, TimeUnit.MILLISECONDS)
+        } catch (e: Throwable) {
+            // Setup failed or timed out: undo the registration above, then propagate.
+            teardown()
+            throw e
+        }
+    }
+
+    /**
+     * Unregisters the callback registered during construction.
+     *
+     * That is all this method does; it makes no call on the [TestNetworkManager].
+     */
+    fun teardown() {
+        cm.unregisterNetworkCallback(networkCallback)
+    }
+}
\ No newline at end of file
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkCallback.kt
new file mode 100644
index 0000000..959a837
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -0,0 +1,393 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkCapabilities.NET_CAPABILITY_VALIDATED
+import com.android.net.module.util.ArrayTrackRecord
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.RecorderCallback.CallbackEntry.Losing
+import com.android.testutils.RecorderCallback.CallbackEntry.Lost
+import com.android.testutils.RecorderCallback.CallbackEntry.Resumed
+import com.android.testutils.RecorderCallback.CallbackEntry.Suspended
+import com.android.testutils.RecorderCallback.CallbackEntry.Unavailable
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+// Placeholder for "no network": carried by Unavailable entries, and substituted for a null
+// network argument in TestableNetworkCallback.expectCallback(KClass, Network?, Long).
+object NULL_NETWORK : Network(-1)
+// Wildcard: the default `network` argument of the reified expectCallback, which compares it by
+// identity and skips the network check when it is passed.
+object ANY_NETWORK : Network(-2)
+
+// Shorthand for NetworkCapabilities.capabilityNameOf applied to the receiver Int.
+private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
+
+// A NetworkCallback that stores every callback it receives as a CallbackEntry in `history`.
+open class RecorderCallback private constructor(
+    private val backingRecord: ArrayTrackRecord<CallbackEntry>
+) : NetworkCallback() {
+    // Starts out with a fresh, empty record.
+    constructor() : this(ArrayTrackRecord())
+    // Reuses the record of `src` when one is given, otherwise starts out with a fresh one.
+    protected constructor(src: RecorderCallback?) : this(src?.backingRecord ?: ArrayTrackRecord())
+
+    sealed class CallbackEntry {
+        // The subclasses are data classes so that equals(), hashCode(), componentN() and friends
+        // come for free. A data class may extend another class, but its constructor can only
+        // hold visible members, so it has no way to take a plain (non-val) argument and forward
+        // it to CallbackEntry. Declaring `network' abstract here lets each subclass provide it
+        // with an `override val` in its own constructor instead.
+        abstract val network: Network
+
+        data class Available(override val network: Network) : CallbackEntry()
+        data class CapabilitiesChanged(
+            override val network: Network,
+            val caps: NetworkCapabilities
+        ) : CallbackEntry()
+        data class LinkPropertiesChanged(
+            override val network: Network,
+            val lp: LinkProperties
+        ) : CallbackEntry()
+        data class Suspended(override val network: Network) : CallbackEntry()
+        data class Resumed(override val network: Network) : CallbackEntry()
+        data class Losing(override val network: Network, val maxMsToLive: Int) : CallbackEntry()
+        data class Lost(override val network: Network) : CallbackEntry()
+        // onUnavailable carries no network, so the only public constructor pins NULL_NETWORK.
+        data class Unavailable private constructor(
+            override val network: Network
+        ) : CallbackEntry() {
+            constructor() : this(NULL_NETWORK)
+        }
+        data class BlockedStatus(
+            override val network: Network,
+            val blocked: Boolean
+        ) : CallbackEntry()
+
+        // Class tokens for each entry type, exposed as fields for convenience when expecting
+        // a type.
+        companion object {
+            @JvmField val AVAILABLE = Available::class
+            @JvmField val NETWORK_CAPS_UPDATED = CapabilitiesChanged::class
+            @JvmField val LINK_PROPERTIES_CHANGED = LinkPropertiesChanged::class
+            @JvmField val SUSPENDED = Suspended::class
+            @JvmField val RESUMED = Resumed::class
+            @JvmField val LOSING = Losing::class
+            @JvmField val LOST = Lost::class
+            @JvmField val UNAVAILABLE = Unavailable::class
+            @JvmField val BLOCKED_STATUS = BlockedStatus::class
+        }
+    }
+
+    val history = backingRecord.newReadHead()
+    val mark get() = history.mark
+
+    // Single funnel through which every override below stores its entry.
+    private fun record(entry: CallbackEntry) {
+        history.add(entry)
+    }
+
+    override fun onAvailable(network: Network) = record(Available(network))
+
+    // PreCheck is not used in the tests today. For backward compatibility with existing tests that
+    // expect the callbacks not to record this, do not listen to PreCheck here.
+
+    override fun onCapabilitiesChanged(network: Network, caps: NetworkCapabilities) =
+        record(CapabilitiesChanged(network, caps))
+
+    override fun onLinkPropertiesChanged(network: Network, lp: LinkProperties) =
+        record(LinkPropertiesChanged(network, lp))
+
+    override fun onBlockedStatusChanged(network: Network, blocked: Boolean) =
+        record(BlockedStatus(network, blocked))
+
+    override fun onNetworkSuspended(network: Network) = record(Suspended(network))
+
+    override fun onNetworkResumed(network: Network) = record(Resumed(network))
+
+    override fun onLosing(network: Network, maxMsToLive: Int) =
+        record(Losing(network, maxMsToLive))
+
+    override fun onLost(network: Network) = record(Lost(network))
+
+    override fun onUnavailable() = record(Unavailable())
+}
+
+// Timeout used by TestableNetworkCallback when the caller does not supply one.
+private const val DEFAULT_TIMEOUT = 200L // ms
+
+// A RecorderCallback with helpers that consume the recorded history and fail the test when the
+// callbacks received are not the expected ones. Timeouts are in milliseconds and default to
+// defaultTimeoutMs wherever a default is declared.
+open class TestableNetworkCallback private constructor(
+    src: TestableNetworkCallback?,
+    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+) : RecorderCallback(src) {
+    @JvmOverloads
+    constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+
+    // Creates a callback that uses this one as its source (so RecorderCallback hands it the same
+    // backing record) and has the same default timeout.
+    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
+
+    // The last available network, or null if any network was lost since the last call to
+    // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
+    val lastAvailableNetwork: Network?
+        get() = when (val it = history.lastOrNull { it is Available || it is Lost }) {
+            is Available -> it.network
+            else -> null
+        }
+
+    // Returns the next entry polled from history, failing the test if the poll yields nothing
+    // for the given timeout.
+    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
+        return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
+    }
+
+    // Fails the test if polling history for timeoutMs yields any entry.
+    // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
+    @JvmOverloads
+    open fun assertNoCallback(timeoutMs: Long = defaultTimeoutMs) {
+        val cb = history.poll(timeoutMs)
+        if (null != cb) fail("Expected no callback but got $cb")
+    }
+
+    // Expects a callback of the specified type on the specified network within the timeout.
+    // If no callback arrives, or a different callback arrives, fail. Returns the callback.
+    // Passing ANY_NETWORK (the default) skips the check on the network.
+    inline fun <reified T : CallbackEntry> expectCallback(
+        network: Network = ANY_NETWORK,
+        timeoutMs: Long = defaultTimeoutMs
+    ): T = pollForNextCallback(timeoutMs).let {
+        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
+            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
+        } else {
+            it
+        }
+    }
+
+    // Expects a callback of the specified type matching the predicate within the timeout.
+    // Any callback that doesn't match the predicate will be skipped. Fails only if
+    // no matching callback is received within the timeout.
+    inline fun <reified T : CallbackEntry> eventuallyExpect(
+        timeoutMs: Long = defaultTimeoutMs,
+        from: Int = mark,
+        crossinline predicate: (T) -> Boolean = { true }
+    ): T = eventuallyExpectOrNull(timeoutMs, from, predicate).also {
+        assertNotNull(it, "Callback ${T::class} not received within ${timeoutMs}ms")
+    } as T
+
+    // Same search as eventuallyExpect, but hands back null instead of failing when the poll
+    // finds no matching callback.
+    // TODO (b/157405399) straighten and unify the method names
+    inline fun <reified T : CallbackEntry> eventuallyExpectOrNull(
+        timeoutMs: Long = defaultTimeoutMs,
+        from: Int = mark,
+        crossinline predicate: (T) -> Boolean = { true }
+    ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
+
+    // Takes the next callback, whatever its type, and fails unless `valid` accepts it.
+    // Returns the callback.
+    fun expectCallbackThat(
+        timeoutMs: Long = defaultTimeoutMs,
+        valid: (CallbackEntry) -> Boolean
+    ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
+
+    // Expects the next callback to be a CapabilitiesChanged on `net` whose capabilities are
+    // accepted by `valid`. Returns the callback.
+    fun expectCapabilitiesThat(
+        net: Network,
+        tmt: Long = defaultTimeoutMs,
+        valid: (NetworkCapabilities) -> Boolean
+    ): CapabilitiesChanged {
+        return expectCallback<CapabilitiesChanged>(net, tmt).also {
+            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
+        }
+    }
+
+    // Expects the next callback to be a LinkPropertiesChanged on `net` whose link properties are
+    // accepted by `valid`. Returns the callback.
+    fun expectLinkPropertiesThat(
+        net: Network,
+        tmt: Long = defaultTimeoutMs,
+        valid: (LinkProperties) -> Boolean
+    ): LinkPropertiesChanged {
+        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
+            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
+        }
+    }
+
+    // Expects onAvailable and the callbacks that follow it. These are:
+    // - onSuspended, iff the network was suspended when the callbacks fire.
+    // - onCapabilitiesChanged.
+    // - onLinkPropertiesChanged.
+    // - onBlockedStatusChanged.
+    //
+    // @param net the network to expect the callbacks on.
+    // @param suspended whether to expect a SUSPENDED callback.
+    // @param validated the expected value of the VALIDATED capability in the
+    //                  onCapabilitiesChanged callback.
+    // @param blocked the expected value carried by the onBlockedStatusChanged callback.
+    // @param tmt how long to wait for each of the callbacks.
+    fun expectAvailableCallbacks(
+        net: Network,
+        suspended: Boolean = false,
+        validated: Boolean = true,
+        blocked: Boolean = false,
+        tmt: Long = defaultTimeoutMs
+    ) {
+        expectCallback<Available>(net, tmt)
+        if (suspended) {
+            expectCallback<Suspended>(net, tmt)
+        }
+        expectCapabilitiesThat(net, tmt) { validated == it.hasCapability(NET_CAPABILITY_VALIDATED) }
+        expectCallback<LinkPropertiesChanged>(net, tmt)
+        // Forward tmt here as well, so the blocked-status wait honours the caller's timeout
+        // like every other step of the sequence.
+        expectBlockedStatusCallback(blocked, net, tmt)
+    }
+
+    // Backward compatibility for existing Java code. Use named arguments instead and remove all
+    // these when there is no user left.
+    fun expectAvailableAndSuspendedCallbacks(
+        net: Network,
+        validated: Boolean,
+        tmt: Long = defaultTimeoutMs
+    ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)
+
+    // Expects the next callback to be a BlockedStatus on `net` carrying the given `blocked`
+    // value.
+    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
+        expectCallback<BlockedStatus>(net, tmt).also {
+            // Expected value first, so that a failure reports expected/actual the right way round.
+            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
+        }
+    }
+
+    // Expects the available callbacks (where the onCapabilitiesChanged must contain the
+    // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
+    // one we just sent.
+    // TODO: this is likely a bug. Fix it and remove this method.
+    fun expectAvailableDoubleValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
+        // Remember where the sequence starts so the first CapabilitiesChanged can be looked up
+        // again once expectAvailableCallbacks has consumed it.
+        val mark = history.mark
+        expectAvailableCallbacks(net, tmt = tmt)
+        val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
+        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
+    }
+
+    // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
+    // then expects another onCapabilitiesChanged that has the validated bit set. This is used
+    // when a network connects and satisfies a callback, and then immediately validates.
+    fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
+        expectAvailableCallbacks(net, validated = false, tmt = tmt)
+        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+    }
+
+    // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
+    // calls with networkAgent can be routed through here without moving MockNetworkAgent.
+    // TODO: clean this up, remove this method.
+    interface HasNetwork {
+        val network: Network
+    }
+
+    // Class-token flavour of expectCallback. A null `n` stands for NULL_NETWORK, and the
+    // callback's network is always compared against it.
+    @JvmOverloads
+    open fun <T : CallbackEntry> expectCallback(
+        type: KClass<T>,
+        n: Network?,
+        timeoutMs: Long = defaultTimeoutMs
+    ) = pollForNextCallback(timeoutMs).also {
+        val network = n ?: NULL_NETWORK
+        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
+        // this writing this would be the only use of this library in the tests.
+        assertTrue(type.java.isInstance(it) && it.network == network,
+                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
+    } as T
+
+    // The methods below take a HasNetwork and delegate to their Network counterparts.
+    @JvmOverloads
+    open fun <T : CallbackEntry> expectCallback(
+        type: KClass<T>,
+        n: HasNetwork?,
+        timeoutMs: Long = defaultTimeoutMs
+    ) = expectCallback(type, n?.network, timeoutMs)
+
+    fun expectAvailableCallbacks(
+        n: HasNetwork,
+        suspended: Boolean,
+        validated: Boolean,
+        blocked: Boolean,
+        timeoutMs: Long
+    ) = expectAvailableCallbacks(n.network, suspended, validated, blocked, timeoutMs)
+
+    fun expectAvailableAndSuspendedCallbacks(n: HasNetwork, expectValidated: Boolean) {
+        expectAvailableAndSuspendedCallbacks(n.network, expectValidated)
+    }
+
+    fun expectAvailableCallbacksValidated(n: HasNetwork) {
+        expectAvailableCallbacks(n.network)
+    }
+
+    fun expectAvailableCallbacksValidatedAndBlocked(n: HasNetwork) {
+        expectAvailableCallbacks(n.network, blocked = true)
+    }
+
+    fun expectAvailableCallbacksUnvalidated(n: HasNetwork) {
+        expectAvailableCallbacks(n.network, validated = false)
+    }
+
+    fun expectAvailableCallbacksUnvalidatedAndBlocked(n: HasNetwork) {
+        expectAvailableCallbacks(n.network, validated = false, blocked = true)
+    }
+
+    fun expectAvailableDoubleValidatedCallbacks(n: HasNetwork) {
+        expectAvailableDoubleValidatedCallbacks(n.network, defaultTimeoutMs)
+    }
+
+    fun expectAvailableThenValidatedCallbacks(n: HasNetwork) {
+        expectAvailableThenValidatedCallbacks(n.network, defaultTimeoutMs)
+    }
+
+    @JvmOverloads
+    fun expectLinkPropertiesThat(
+        n: HasNetwork,
+        tmt: Long = defaultTimeoutMs,
+        valid: (LinkProperties) -> Boolean
+    ) = expectLinkPropertiesThat(n.network, tmt, valid)
+
+    @JvmOverloads
+    fun expectCapabilitiesThat(
+        n: HasNetwork,
+        tmt: Long = defaultTimeoutMs,
+        valid: (NetworkCapabilities) -> Boolean
+    ) = expectCapabilitiesThat(n.network, tmt, valid)
+
+    // Expects a CapabilitiesChanged on `n` that has `capability`; returns the capabilities.
+    @JvmOverloads
+    fun expectCapabilitiesWith(
+        capability: Int,
+        n: HasNetwork,
+        timeoutMs: Long = defaultTimeoutMs
+    ): NetworkCapabilities {
+        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
+    }
+
+    // Expects a CapabilitiesChanged on `n` that lacks `capability`; returns the capabilities.
+    @JvmOverloads
+    fun expectCapabilitiesWithout(
+        capability: Int,
+        n: HasNetwork,
+        timeoutMs: Long = defaultTimeoutMs
+    ): NetworkCapabilities {
+        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
+    }
+
+    fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
+        expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
+    }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt
new file mode 100644
index 0000000..d5c3a2a
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProvider.kt
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.netstats.provider.NetworkStatsProvider
+import com.android.net.module.util.ArrayTrackRecord
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+// Timeout, in milliseconds, used when a TestableNetworkStatsProvider is built without one.
+private const val DEFAULT_TIMEOUT_MS = 200L
+
+// A NetworkStatsProvider that records the callbacks it receives in `history` and offers helpers
+// to assert on them. Timeouts are in milliseconds.
+open class TestableNetworkStatsProvider(
+    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT_MS
+) : NetworkStatsProvider() {
+    // One subclass per recorded callback, holding the arguments it was invoked with.
+    sealed class CallbackType {
+        data class OnRequestStatsUpdate(val token: Int) : CallbackType()
+        data class OnSetLimit(val iface: String?, val quotaBytes: Long) : CallbackType()
+        data class OnSetAlert(val quotaBytes: Long) : CallbackType()
+    }
+
+    val history = ArrayTrackRecord<CallbackType>().newReadHead()
+    // See ReadHead#mark
+    val mark get() = history.mark
+
+    override fun onRequestStatsUpdate(token: Int) {
+        history.add(CallbackType.OnRequestStatsUpdate(token))
+    }
+
+    override fun onSetLimit(iface: String, quotaBytes: Long) {
+        history.add(CallbackType.OnSetLimit(iface, quotaBytes))
+    }
+
+    override fun onSetAlert(quotaBytes: Long) {
+        history.add(CallbackType.OnSetAlert(quotaBytes))
+    }
+
+    // Each of the three helpers below asserts that the next polled callback equals the given
+    // one; a poll that yields nothing fails the equality check.
+    fun expectOnRequestStatsUpdate(token: Int, timeout: Long = defaultTimeoutMs) {
+        assertEquals(CallbackType.OnRequestStatsUpdate(token), history.poll(timeout))
+    }
+
+    fun expectOnSetLimit(iface: String?, quotaBytes: Long, timeout: Long = defaultTimeoutMs) {
+        assertEquals(CallbackType.OnSetLimit(iface, quotaBytes), history.poll(timeout))
+    }
+
+    fun expectOnSetAlert(quotaBytes: Long, timeout: Long = defaultTimeoutMs) {
+        assertEquals(CallbackType.OnSetAlert(quotaBytes), history.poll(timeout))
+    }
+
+    // Returns the next polled callback, failing the test if the poll yields nothing.
+    fun pollForNextCallback(timeout: Long = defaultTimeoutMs) =
+        history.poll(timeout) ?: fail("Did not receive callback after ${timeout}ms")
+
+    // Expects the very next callback to be a T accepted by the predicate, and returns it.
+    inline fun <reified T : CallbackType> expectCallback(
+        timeout: Long = defaultTimeoutMs,
+        predicate: (T) -> Boolean = { true }
+    ): T {
+        return pollForNextCallback(timeout).also {
+            assertTrue(it is T && predicate(it), "Unexpected callback : $it")
+        } as T
+    }
+
+    // Expects a callback of the specified type matching the predicate within the timeout.
+    // Any callback that doesn't match the predicate will be skipped. Fails only if
+    // no matching callback is received within the timeout.
+    // TODO : factorize the code for this with the identical call in TestableNetworkCallback.
+    // There should be a common superclass doing this generically.
+    // TODO : in fact, completely removing this method and have clients use
+    // history.poll(timeout, index, predicate) directly might be simpler.
+    inline fun <reified T : CallbackType> eventuallyExpect(
+        timeoutMs: Long = defaultTimeoutMs,
+        from: Int = mark,
+        crossinline predicate: (T) -> Boolean = { true }
+    ): T = (history.poll(timeoutMs, from) { it is T && predicate(it) } as T?)
+        // Fail with an explicit message rather than letting a cast of null to T blow up.
+        ?: fail("Callback ${T::class} not received within ${timeoutMs}ms")
+
+    // Moves the read mark to the current size of the history, skipping everything recorded so far.
+    fun drainCallbacks() {
+        history.mark = history.size
+    }
+
+    // Fails the test if polling for `timeout` yields any callback.
+    @JvmOverloads
+    fun assertNoCallback(timeout: Long = defaultTimeoutMs) {
+        val cb = history.poll(timeout)
+        cb?.let { fail("Expected no callback but got $cb") }
+    }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderBinder.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderBinder.kt
new file mode 100644
index 0000000..02922d8
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderBinder.kt
@@ -0,0 +1,64 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.netstats.provider.INetworkStatsProvider
+import com.android.net.module.util.ArrayTrackRecord
+import kotlin.test.assertEquals
+import kotlin.test.fail
+
+// Timeout, in milliseconds, used by the expect* helpers and as the assertNoCallback default.
+private const val DEFAULT_TIMEOUT_MS = 200L
+
+// An INetworkStatsProvider stub that records every call it receives so tests can assert on them.
+open class TestableNetworkStatsProviderBinder : INetworkStatsProvider.Stub() {
+    // One subclass per recorded call, holding the arguments it was invoked with.
+    sealed class CallbackType {
+        data class OnRequestStatsUpdate(val token: Int) : CallbackType()
+        data class OnSetLimit(val iface: String?, val quotaBytes: Long) : CallbackType()
+        data class OnSetAlert(val quotaBytes: Long) : CallbackType()
+    }
+
+    private val history = ArrayTrackRecord<CallbackType>().ReadHead()
+
+    // Stores one received call in the history.
+    private fun record(event: CallbackType) {
+        history.add(event)
+    }
+
+    // Asserts that the next call polled within DEFAULT_TIMEOUT_MS equals `expected`.
+    private fun expectNext(expected: CallbackType) =
+        assertEquals(expected, history.poll(DEFAULT_TIMEOUT_MS))
+
+    override fun onRequestStatsUpdate(token: Int) =
+        record(CallbackType.OnRequestStatsUpdate(token))
+
+    override fun onSetLimit(iface: String?, quotaBytes: Long) =
+        record(CallbackType.OnSetLimit(iface, quotaBytes))
+
+    override fun onSetAlert(quotaBytes: Long) =
+        record(CallbackType.OnSetAlert(quotaBytes))
+
+    fun expectOnRequestStatsUpdate(token: Int) =
+        expectNext(CallbackType.OnRequestStatsUpdate(token))
+
+    fun expectOnSetLimit(iface: String?, quotaBytes: Long) =
+        expectNext(CallbackType.OnSetLimit(iface, quotaBytes))
+
+    fun expectOnSetAlert(quotaBytes: Long) =
+        expectNext(CallbackType.OnSetAlert(quotaBytes))
+
+    // Fails the test if polling for `timeout` yields any call.
+    @JvmOverloads
+    fun assertNoCallback(timeout: Long = DEFAULT_TIMEOUT_MS) {
+        val unexpected = history.poll(timeout) ?: return
+        fail("Expected no callback but got $unexpected")
+    }
+}
diff --git a/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderCbBinder.kt b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderCbBinder.kt
new file mode 100644
index 0000000..163473a
--- /dev/null
+++ b/staticlibs/devicetests/com/android/testutils/TestableNetworkStatsProviderCbBinder.kt
@@ -0,0 +1,84 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.net.NetworkStats
+import android.net.netstats.provider.INetworkStatsProviderCallback
+import com.android.net.module.util.ArrayTrackRecord
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+// Timeout, in milliseconds, used by every expect* helper of this class.
+private const val DEFAULT_TIMEOUT_MS = 3000L
+
+// An INetworkStatsProviderCallback stub that records every call it receives so tests can assert
+// on them.
+open class TestableNetworkStatsProviderCbBinder : INetworkStatsProviderCallback.Stub() {
+    // One subclass or object per recorded call.
+    sealed class CallbackType {
+        data class NotifyStatsUpdated(
+            val token: Int,
+            val ifaceStats: NetworkStats,
+            val uidStats: NetworkStats
+        ) : CallbackType()
+        object NotifyLimitReached : CallbackType()
+        object NotifyAlertReached : CallbackType()
+        object Unregister : CallbackType()
+    }
+
+    private val history = ArrayTrackRecord<CallbackType>().ReadHead()
+
+    override fun notifyStatsUpdated(token: Int, ifaceStats: NetworkStats, uidStats: NetworkStats) {
+        history.add(CallbackType.NotifyStatsUpdated(token, ifaceStats, uidStats))
+    }
+
+    override fun notifyLimitReached() {
+        history.add(CallbackType.NotifyLimitReached)
+    }
+
+    override fun notifyAlertReached() {
+        history.add(CallbackType.NotifyAlertReached)
+    }
+
+    override fun unregister() {
+        history.add(CallbackType.Unregister)
+    }
+
+    // Expects the next call to be a NotifyStatsUpdated, whatever its contents.
+    fun expectNotifyStatsUpdated() {
+        val event = history.poll(DEFAULT_TIMEOUT_MS)
+        assertTrue(event is CallbackType.NotifyStatsUpdated,
+                "Expected NotifyStatsUpdated callback, but got $event")
+    }
+
+    // Expects the next call to be a NotifyStatsUpdated carrying stats equal to the given ones.
+    fun expectNotifyStatsUpdated(ifaceStats: NetworkStats, uidStats: NetworkStats) {
+        // Report a missing or mismatched callback as a test failure with a message, rather than
+        // as a bare NullPointerException or a generic Exception.
+        val event = history.poll(DEFAULT_TIMEOUT_MS)
+            ?: fail("Did not receive callback after ${DEFAULT_TIMEOUT_MS}ms")
+        if (event !is CallbackType.NotifyStatsUpdated) {
+            fail("Expected NotifyStatsUpdated callback, but got ${event::class}")
+        }
+        // TODO: verify token.
+        assertNetworkStatsEquals(ifaceStats, event.ifaceStats)
+        assertNetworkStatsEquals(uidStats, event.uidStats)
+    }
+
+    // Expects the next call to be notifyLimitReached.
+    fun expectNotifyLimitReached() =
+        assertEquals(CallbackType.NotifyLimitReached, history.poll(DEFAULT_TIMEOUT_MS))
+
+    // Expects the next call to be notifyAlertReached.
+    fun expectNotifyAlertReached() =
+        assertEquals(CallbackType.NotifyAlertReached, history.poll(DEFAULT_TIMEOUT_MS))
+
+    // Assert there is no callback in current queue.
+    fun assertNoCallback() {
+        val cb = history.poll(0)
+        cb?.let { fail("Expected no callback but got $cb") }
+    }
+}