Support QosCallback with UDP sockets: unit tests

Add unit tests to QosSocketFilterTest and ConnectivityServiceTest, and add a new QosSocketInfoTest.

Bug: 203146631
Test: atest & verified on LTE test equipment
Change-Id: I0cd82dde0067d754dfab01ed0406370d7debb937
diff --git a/tests/unit/java/android/net/QosSocketFilterTest.java b/tests/unit/java/android/net/QosSocketFilterTest.java
index 91f2cdd..6820b40 100644
--- a/tests/unit/java/android/net/QosSocketFilterTest.java
+++ b/tests/unit/java/android/net/QosSocketFilterTest.java
@@ -16,8 +16,17 @@
 
 package android.net;
 
-import static junit.framework.Assert.assertFalse;
-import static junit.framework.Assert.assertTrue;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_NONE;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_BOUND;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_CONNECTED;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED;
+import static android.system.OsConstants.IPPROTO_TCP;
+import static android.system.OsConstants.IPPROTO_UDP;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 import android.os.Build;
 
@@ -29,6 +38,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.net.DatagramSocket;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 
@@ -36,14 +46,14 @@
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
 public class QosSocketFilterTest {
-
+    private static final int TEST_NET_ID = 1777;  // netId used to build mNetwork.
+    private final Network mNetwork = new Network(TEST_NET_ID);  // Used to build QosSocketInfo.
     @Test
     public void testPortExactMatch() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
         assertTrue(QosSocketFilter.matchesAddress(
                 new InetSocketAddress(addressA, 10), addressB, 10, 10));
-
     }
 
     @Test
@@ -77,5 +87,90 @@
         assertFalse(QosSocketFilter.matchesAddress(
                 new InetSocketAddress(addressA, 10), addressB, 10, 10));
     }
+
+    @Test
+    public void testAddressMatchWithAnyLocalAddresses() {
+        final InetAddress wildcard = InetAddresses.parseNumericAddress("0.0.0.0");
+        final InetAddress specific = InetAddresses.parseNumericAddress("1.2.3.4");
+        final InetSocketAddress specificSocket = new InetSocketAddress(specific, 10);
+        final InetSocketAddress wildcardSocket = new InetSocketAddress(wildcard, 10);
+        assertTrue(QosSocketFilter.matchesAddress(specificSocket, wildcard, 10, 10));
+        assertFalse(QosSocketFilter.matchesAddress(wildcardSocket, specific, 10, 10));
+    }
+
+    @Test
+    public void testProtocolMatch() throws Exception {
+        try (DatagramSocket socket = new DatagramSocket(new InetSocketAddress("127.0.0.1", 0));
+                DatagramSocket socketV6 = new DatagramSocket(new InetSocketAddress("::1", 0))) {
+            // connect() on a UDP socket sends nothing, so the peer port need not be listening.
+            socket.connect(new InetSocketAddress("127.0.0.1", socket.getLocalPort() + 10));
+            socketV6.connect(new InetSocketAddress("::1", socketV6.getLocalPort() + 10));
+            final QosSocketFilter socketFilter =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socket));
+            final QosSocketFilter socketFilter6 =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socketV6));
+            assertTrue(socketFilter.matchesProtocol(IPPROTO_UDP));
+            assertTrue(socketFilter6.matchesProtocol(IPPROTO_UDP));
+            assertFalse(socketFilter.matchesProtocol(IPPROTO_TCP));
+            assertFalse(socketFilter6.matchesProtocol(IPPROTO_TCP));
+        }
+    }
+
+    @Test
+    public void testValidate() throws Exception {
+        // The IPv4 socket is bound and connected; the IPv6 socket is only bound. Both are
+        // expected to validate cleanly.
+        try (DatagramSocket socket = new DatagramSocket(new InetSocketAddress("127.0.0.1", 0));
+                DatagramSocket socketV6 = new DatagramSocket(new InetSocketAddress("::1", 0))) {
+            socket.connect(new InetSocketAddress("127.0.0.1", socket.getLocalPort() + 7));
+            final QosSocketFilter socketFilter =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socket));
+            final QosSocketFilter socketFilter6 =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socketV6));
+            assertEquals(EX_TYPE_FILTER_NONE, socketFilter.validate());
+            assertEquals(EX_TYPE_FILTER_NONE, socketFilter6.validate());
+        }
+    }
+
+    @Test
+    public void testValidateUnbind() throws Exception {
+        // A DatagramSocket created with a null bind address is unbound, so validate() flags it.
+        try (DatagramSocket socket = new DatagramSocket(null)) {
+            final QosSocketFilter socketFilter =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socket));
+            assertEquals(EX_TYPE_FILTER_SOCKET_NOT_BOUND, socketFilter.validate());
+        }
+    }
+
+    @Test
+    public void testValidateLocalAddressChanged() throws Exception {
+        try (DatagramSocket socket = new DatagramSocket(null);
+                DatagramSocket socket6 = new DatagramSocket(null)) {
+            // Filters are built while unbound, so the later bind() must read as a local change.
+            final QosSocketFilter socketFilter =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socket));
+            final QosSocketFilter socketFilter6 =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socket6));
+            socket.bind(new InetSocketAddress("127.0.0.1", 0));
+            socket6.bind(new InetSocketAddress("::1", 0));
+            assertEquals(EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED, socketFilter.validate());
+            assertEquals(EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED, socketFilter6.validate());
+        }
+    }
+
+    @Test
+    public void testValidateRemoteAddressChanged() throws Exception {
+        // Bind to an ephemeral port so the test cannot collide with a port already in use.
+        try (DatagramSocket socket = new DatagramSocket(new InetSocketAddress("127.0.0.1", 0))) {
+            socket.connect(new InetSocketAddress("127.0.0.1", socket.getLocalPort() + 11));
+            final QosSocketFilter socketFilter =
+                    new QosSocketFilter(new QosSocketInfo(mNetwork, socket));
+            assertEquals(EX_TYPE_FILTER_NONE, socketFilter.validate());
+            socket.disconnect();
+            assertEquals(EX_TYPE_FILTER_SOCKET_NOT_CONNECTED, socketFilter.validate());
+            socket.connect(new InetSocketAddress("127.0.0.1", socket.getLocalPort() + 13));
+            assertEquals(EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED, socketFilter.validate());
+        }
+    }
 }
 
diff --git a/tests/unit/java/android/net/QosSocketInfoTest.java b/tests/unit/java/android/net/QosSocketInfoTest.java
new file mode 100644
index 0000000..749c182
--- /dev/null
+++ b/tests/unit/java/android/net/QosSocketInfoTest.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net;
+
+import static android.system.OsConstants.SOCK_DGRAM;
+import static android.system.OsConstants.SOCK_STREAM;
+
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import android.os.Build;
+
+import androidx.test.filters.SmallTest;
+
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+
+import java.net.DatagramSocket;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+
+@RunWith(DevSdkIgnoreRunner.class)
+@SmallTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
+public class QosSocketInfoTest {
+    @Mock
+    private Network mMockNetwork = mock(Network.class);
+
+    @Test
+    public void testConstructWithSock() throws Exception {
+        // Port 0 in a bind address lets the kernel pick a free ephemeral port.
+        final InetSocketAddress clientAddr = new InetSocketAddress("127.0.0.1", 0);
+        final InetSocketAddress serverAddr = new InetSocketAddress("127.0.0.1", 0);
+        final InetSocketAddress clientAddr6 = new InetSocketAddress("::1", 0);
+        final InetSocketAddress serverAddr6 = new InetSocketAddress("::1", 0);
+        // try-with-resources closes every socket even when an assertion fails.
+        try (ServerSocket server = new ServerSocket();
+                ServerSocket server6 = new ServerSocket()) {
+            server.bind(serverAddr);
+            server6.bind(serverAddr6);
+            // connect() succeeds once the server is listening; accept() is not required.
+            try (Socket socket = new Socket(serverAddr.getAddress(), server.getLocalPort(),
+                            clientAddr.getAddress(), clientAddr.getPort());
+                    Socket socket6 = new Socket(serverAddr6.getAddress(), server6.getLocalPort(),
+                            clientAddr6.getAddress(), clientAddr6.getPort())) {
+                final QosSocketInfo sockInfo = new QosSocketInfo(mMockNetwork, socket);
+                final QosSocketInfo sockInfo6 = new QosSocketInfo(mMockNetwork, socket6);
+                assertTrue(sockInfo.getLocalSocketAddress().equals(
+                        new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort())));
+                assertTrue(sockInfo.getRemoteSocketAddress()
+                        .equals(socket.getRemoteSocketAddress()));
+                assertEquals(SOCK_STREAM, sockInfo.getSocketType());
+                assertTrue(sockInfo6.getLocalSocketAddress().equals(
+                        new InetSocketAddress(socket6.getLocalAddress(), socket6.getLocalPort())));
+                assertTrue(sockInfo6.getRemoteSocketAddress()
+                        .equals(socket6.getRemoteSocketAddress()));
+                assertEquals(SOCK_STREAM, sockInfo6.getSocketType());
+            }
+        }
+    }
+
+    @Test
+    public void testConstructWithDatagramSock() throws Exception {
+        // Port 0 in a bind address lets the kernel pick a free ephemeral port.
+        final InetSocketAddress clientAddr = new InetSocketAddress("127.0.0.1", 0);
+        final InetSocketAddress serverAddr = new InetSocketAddress("127.0.0.1", 0);
+        final InetSocketAddress clientAddr6 = new InetSocketAddress("::1", 0);
+        final InetSocketAddress serverAddr6 = new InetSocketAddress("::1", 0);
+        try (DatagramSocket socket = new DatagramSocket(null);
+                DatagramSocket socket6 = new DatagramSocket(null)) {
+            socket.setReuseAddress(true);
+            socket.bind(clientAddr);
+            socket.connect(serverAddr);
+            // The IPv6 socket must use the ::1 addresses, otherwise it only re-tests IPv4.
+            socket6.setReuseAddress(true);
+            socket6.bind(clientAddr6);
+            socket6.connect(serverAddr6);
+            final QosSocketInfo sockInfo = new QosSocketInfo(mMockNetwork, socket);
+            final QosSocketInfo sockInfo6 = new QosSocketInfo(mMockNetwork, socket6);
+            // QosSocketInfo must report the same endpoints the socket itself reports.
+            assertTrue(sockInfo.getLocalSocketAddress().equals(socket.getLocalSocketAddress()));
+            assertTrue(sockInfo.getRemoteSocketAddress().equals(socket.getRemoteSocketAddress()));
+            assertEquals(SOCK_DGRAM, sockInfo.getSocketType());
+            assertTrue(sockInfo6.getLocalSocketAddress().equals(socket6.getLocalSocketAddress()));
+            assertTrue(sockInfo6.getRemoteSocketAddress()
+                    .equals(socket6.getRemoteSocketAddress()));
+            assertEquals(SOCK_DGRAM, sockInfo6.getSocketType());
+        }
+    }
+}
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 025b28c..5e0415a 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -136,6 +136,8 @@
 import static com.android.server.ConnectivityService.PREFERENCE_ORDER_PROFILE;
 import static com.android.server.ConnectivityService.PREFERENCE_ORDER_VPN;
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
+import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
+import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
 import static com.android.testutils.ConcurrentUtils.await;
 import static com.android.testutils.ConcurrentUtils.durationOf;
 import static com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
@@ -159,8 +161,8 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.junit.Assume.assumeTrue;
 import static org.junit.Assume.assumeFalse;
+import static org.junit.Assume.assumeTrue;
 import static org.mockito.AdditionalMatchers.aryEq;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyLong;
@@ -11844,16 +11846,14 @@
         mQosCallbackMockHelper.registerQosCallback(
                 mQosCallbackMockHelper.mFilter, mQosCallbackMockHelper.mCallback);
 
-        final NetworkAgentWrapper.CallbackType.OnQosCallbackRegister cbRegister1 =
-                (NetworkAgentWrapper.CallbackType.OnQosCallbackRegister)
-                        wrapper.getCallbackHistory().poll(1000, x -> true);
+        final OnQosCallbackRegister cbRegister1 =
+                (OnQosCallbackRegister) wrapper.getCallbackHistory().poll(1000, x -> true);
         assertNotNull(cbRegister1);
 
         final int registerCallbackId = cbRegister1.mQosCallbackId;
         mService.unregisterQosCallback(mQosCallbackMockHelper.mCallback);
-        final NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister cbUnregister;
-        cbUnregister = (NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister)
-                wrapper.getCallbackHistory().poll(1000, x -> true);
+        final OnQosCallbackUnregister cbUnregister =
+                (OnQosCallbackUnregister) wrapper.getCallbackHistory().poll(1000, x -> true);
         assertNotNull(cbUnregister);
         assertEquals(registerCallbackId, cbUnregister.mQosCallbackId);
         assertNull(wrapper.getCallbackHistory().poll(200, x -> true));
@@ -11932,6 +11932,86 @@
                         && session.getSessionType() == QosSession.TYPE_NR_BEARER));
     }
 
+    @Test @IgnoreUpTo(SC_V2)
+    public void testQosCallbackAvailableOnValidationError() throws Exception {
+        mQosCallbackMockHelper = new QosCallbackMockHelper();
+        final NetworkAgentWrapper wrapper = mQosCallbackMockHelper.mAgentWrapper;
+        final int sessionId = 10;
+        final int qosCallbackId = 1;
+
+        // Register while the filter still validates cleanly.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_NONE)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        mQosCallbackMockHelper.registerQosCallback(
+                mQosCallbackMockHelper.mFilter, mQosCallbackMockHelper.mCallback);
+        final OnQosCallbackRegister cbRegister1 =
+                (OnQosCallbackRegister) wrapper.getCallbackHistory().poll(1000, x -> true);
+        assertNotNull(cbRegister1);
+        final int registerCallbackId = cbRegister1.mQosCallbackId;
+
+        waitForIdle();
+
+        // Make the filter fail validation before the agent reports the session.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        sendQosSessionEvent(qosCallbackId, sessionId, true /* available */);
+        waitForIdle();
+
+        // A failed validation should unregister the callback from the agent and report
+        // the validation error to the app.
+        final OnQosCallbackUnregister cbUnregister =
+                (OnQosCallbackUnregister) wrapper.getCallbackHistory().poll(1000, x -> true);
+        assertNotNull(cbUnregister);
+        assertEquals(registerCallbackId, cbUnregister.mQosCallbackId);
+        waitForIdle();
+        verify(mQosCallbackMockHelper.mCallback)
+                .onError(eq(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED));
+    }
+
+    @Test @IgnoreUpTo(SC_V2)
+    public void testQosCallbackLostOnValidationError() throws Exception {
+        final int qosCallbackId = 1;
+        final int sessionId = 10;
+        mQosCallbackMockHelper = new QosCallbackMockHelper();
+
+        // Register while the filter validates cleanly, then report the session as available.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_NONE)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        mQosCallbackMockHelper.registerQosCallback(
+                mQosCallbackMockHelper.mFilter, mQosCallbackMockHelper.mCallback);
+        waitForIdle();
+        final EpsBearerQosSessionAttributes attributes =
+                sendQosSessionEvent(qosCallbackId, sessionId, true /* available */);
+        waitForIdle();
+        verify(mQosCallbackMockHelper.mCallback).onQosEpsBearerSessionAvailable(argThat(
+                session -> session.getSessionId() == sessionId
+                        && session.getSessionType() == QosSession.TYPE_EPS_BEARER), eq(attributes));
+
+        // Once the filter stops validating, losing the session must surface the error.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        sendQosSessionEvent(qosCallbackId, sessionId, false /* available */);
+        waitForIdle();
+        verify(mQosCallbackMockHelper.mCallback)
+                .onError(eq(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED));
+    }
+
+    // Sends an EPS bearer QoS session event through the test agent. Returns the attributes
+    // sent for an "available" event, or null for a "lost" event.
+    private EpsBearerQosSessionAttributes sendQosSessionEvent(
+            int qosCallbackId, int sessionId, boolean available) {
+        if (!available) {
+            mQosCallbackMockHelper.mAgentWrapper.getNetworkAgent()
+                    .sendQosSessionLost(qosCallbackId, sessionId, QosSession.TYPE_EPS_BEARER);
+            return null;
+        }
+        final EpsBearerQosSessionAttributes attributes =
+                new EpsBearerQosSessionAttributes(1, 2, 3, 4, 5, new ArrayList<>());
+        mQosCallbackMockHelper.mAgentWrapper.getNetworkAgent()
+                .sendQosSessionAvailable(qosCallbackId, sessionId, attributes);
+        return attributes;
+    }
+
     @Test
     public void testQosCallbackTooManyRequests() throws Exception {
         mQosCallbackMockHelper = new QosCallbackMockHelper();