Merge "QOS filter matching support based on remote address and port number for connected sockets" am: 6c5553aaaa

Original change: https://android-review.googlesource.com/c/platform/frameworks/base/+/1687813

Change-Id: Ibd70a86d82d4810425288694f2e3904d471a2d24
diff --git a/framework/api/system-current.txt b/framework/api/system-current.txt
index 5750845..730555b 100644
--- a/framework/api/system-current.txt
+++ b/framework/api/system-current.txt
@@ -381,6 +381,7 @@
   public abstract class QosFilter {
     method @NonNull public abstract android.net.Network getNetwork();
     method public abstract boolean matchesLocalAddress(@NonNull java.net.InetAddress, int, int);
+    method public abstract boolean matchesRemoteAddress(@NonNull java.net.InetAddress, int, int);
   }
 
   public final class QosSession implements android.os.Parcelable {
@@ -403,6 +404,7 @@
     method public int describeContents();
     method @NonNull public java.net.InetSocketAddress getLocalSocketAddress();
     method @NonNull public android.net.Network getNetwork();
+    method @Nullable public java.net.InetSocketAddress getRemoteSocketAddress();
     method public void writeToParcel(@NonNull android.os.Parcel, int);
     field @NonNull public static final android.os.Parcelable.Creator<android.net.QosSocketInfo> CREATOR;
   }
diff --git a/framework/src/android/net/QosFilter.java b/framework/src/android/net/QosFilter.java
index ab55002..957c867 100644
--- a/framework/src/android/net/QosFilter.java
+++ b/framework/src/android/net/QosFilter.java
@@ -71,5 +71,16 @@
      */
     public abstract boolean matchesLocalAddress(@NonNull InetAddress address,
             int startPort, int endPort);
+
+    /**
+     * Determines whether or not the parameters are a match for the filter.
+     *
+     * @param address the remote address
+     * @param startPort the start of the port range
+     * @param endPort the end of the port range
+     * @return whether the parameters match the remote address of the filter
+     */
+    public abstract boolean matchesRemoteAddress(@NonNull InetAddress address,
+            int startPort, int endPort);
 }
 
diff --git a/framework/src/android/net/QosSocketFilter.java b/framework/src/android/net/QosSocketFilter.java
index 2080e68..69da7f4 100644
--- a/framework/src/android/net/QosSocketFilter.java
+++ b/framework/src/android/net/QosSocketFilter.java
@@ -138,13 +138,26 @@
         if (mQosSocketInfo.getLocalSocketAddress() == null) {
             return false;
         }
-
-        return matchesLocalAddress(mQosSocketInfo.getLocalSocketAddress(), address, startPort,
+        return matchesAddress(mQosSocketInfo.getLocalSocketAddress(), address, startPort,
                 endPort);
     }
 
     /**
-     * Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)} with the
+     * @inheritDoc
+     */
+    @Override
+    public boolean matchesRemoteAddress(@NonNull final InetAddress address, final int startPort,
+            final int endPort) {
+        // getRemoteSocketAddress() is null when the socket was not connected at the time the
+        // QosSocketInfo was built, in which case nothing can match on the remote side.
+        final InetSocketAddress remoteAddress = mQosSocketInfo.getRemoteSocketAddress();
+        return remoteAddress != null
+                && matchesAddress(remoteAddress, address, startPort, endPort);
+    }
+
+    /**
+     * Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)}
+     * and {@link QosSocketFilter#matchesRemoteAddress(InetAddress, int, int)} with the
-     * filterSocketAddress coming from {@link QosSocketInfo#getLocalSocketAddress()}.
+     * filterSocketAddress being the respective local or remote address of {@link QosSocketInfo}.
      * <p>
      * This method exists for testing purposes since {@link QosSocketInfo} couldn't be mocked
@@ -156,7 +169,7 @@
      * @param endPort the end of the port range to check
      */
     @VisibleForTesting
-    public static boolean matchesLocalAddress(@NonNull final InetSocketAddress filterSocketAddress,
+    public static boolean matchesAddress(@NonNull final InetSocketAddress filterSocketAddress,
             @NonNull final InetAddress address,
             final int startPort, final int endPort) {
         return startPort <= filterSocketAddress.getPort()
diff --git a/framework/src/android/net/QosSocketInfo.java b/framework/src/android/net/QosSocketInfo.java
index 53d9669..a45d507 100644
--- a/framework/src/android/net/QosSocketInfo.java
+++ b/framework/src/android/net/QosSocketInfo.java
@@ -17,6 +17,7 @@
 package android.net;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 import android.annotation.SystemApi;
 import android.os.Parcel;
 import android.os.ParcelFileDescriptor;
@@ -32,7 +33,8 @@
 /**
  * Used in conjunction with
  * {@link ConnectivityManager#registerQosCallback}
- * in order to receive Qos Sessions related to the local address and port of a bound {@link Socket}.
+ * in order to receive Qos Sessions related to the local address and port of a bound {@link Socket}
+ * and/or remote address and port of a connected {@link Socket}.
  *
  * @hide
  */
@@ -48,6 +50,9 @@
     @NonNull
     private final InetSocketAddress mLocalSocketAddress;
 
+    @Nullable
+    private final InetSocketAddress mRemoteSocketAddress;
+
     /**
      * The {@link Network} the socket is on.
      *
@@ -81,6 +86,18 @@
     }
 
     /**
+     * The remote address of the socket passed into {@link QosSocketInfo(Network, Socket)}.
+     * The value does not reflect any changes that occur to the socket after it is first set
+     * in the constructor.
+     *
+     * @return the socket's remote address if it was connected at construction, else {@code null}
+     */
+    @Nullable
+    public InetSocketAddress getRemoteSocketAddress() {
+        return mRemoteSocketAddress;
+    }
+
+    /**
      * Creates a {@link QosSocketInfo} given a {@link Network} and bound {@link Socket}.  The
      * {@link Socket} must remain bound in order to receive {@link QosSession}s.
      *
@@ -95,6 +112,12 @@
         mParcelFileDescriptor = ParcelFileDescriptor.fromSocket(socket);
         mLocalSocketAddress =
                 new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort());
+
+        if (socket.isConnected()) {
+            mRemoteSocketAddress = (InetSocketAddress) socket.getRemoteSocketAddress();
+        } else {
+            mRemoteSocketAddress = null;
+        }
     }
 
     /* Parcelable methods */
@@ -102,11 +125,15 @@
         mNetwork = Objects.requireNonNull(Network.CREATOR.createFromParcel(in));
         mParcelFileDescriptor = ParcelFileDescriptor.CREATOR.createFromParcel(in);
 
-        final int addressLength = in.readInt();
-        mLocalSocketAddress = readSocketAddress(in, addressLength);
+        final int localAddressLength = in.readInt();
+        mLocalSocketAddress = readSocketAddress(in, localAddressLength);
+
+        final int remoteAddressLength = in.readInt();
+        mRemoteSocketAddress = remoteAddressLength == 0 ? null
+                : readSocketAddress(in, remoteAddressLength);
     }
 
-    private InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) {
+    private @NonNull InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) {
         final byte[] address = new byte[addressLength];
         in.readByteArray(address);
         final int port = in.readInt();
@@ -130,10 +157,19 @@
         mNetwork.writeToParcel(dest, 0);
         mParcelFileDescriptor.writeToParcel(dest, 0);
 
-        final byte[] address = mLocalSocketAddress.getAddress().getAddress();
-        dest.writeInt(address.length);
-        dest.writeByteArray(address);
+        final byte[] localAddress = mLocalSocketAddress.getAddress().getAddress();
+        dest.writeInt(localAddress.length);
+        dest.writeByteArray(localAddress);
         dest.writeInt(mLocalSocketAddress.getPort());
+
+        if (mRemoteSocketAddress == null) {
+            dest.writeInt(0);
+        } else {
+            final byte[] remoteAddress = mRemoteSocketAddress.getAddress().getAddress();
+            dest.writeInt(remoteAddress.length);
+            dest.writeByteArray(remoteAddress);
+            dest.writeInt(mRemoteSocketAddress.getPort());
+        }
     }
 
     @NonNull
diff --git a/tests/unit/java/android/net/QosSocketFilterTest.java b/tests/unit/java/android/net/QosSocketFilterTest.java
index ad58960..40f8f1b 100644
--- a/tests/unit/java/android/net/QosSocketFilterTest.java
+++ b/tests/unit/java/android/net/QosSocketFilterTest.java
@@ -35,7 +35,7 @@
     public void testPortExactMatch() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
-        assertTrue(QosSocketFilter.matchesLocalAddress(
+        assertTrue(QosSocketFilter.matchesAddress( // port 10 is in [10, 10]
                 new InetSocketAddress(addressA, 10), addressB, 10, 10));

     }
@@ -44,7 +44,7 @@
     public void testPortLessThanStart() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
-        assertFalse(QosSocketFilter.matchesLocalAddress(
+        assertFalse(QosSocketFilter.matchesAddress( // port 8 is below [10, 10]
                 new InetSocketAddress(addressA, 8), addressB, 10, 10));
     }
 
@@ -52,7 +52,7 @@
     public void testPortGreaterThanEnd() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
-        assertFalse(QosSocketFilter.matchesLocalAddress(
+        assertFalse(QosSocketFilter.matchesAddress( // port 18 is above [10, 10]
                 new InetSocketAddress(addressA, 18), addressB, 10, 10));
     }
 
@@ -60,7 +60,7 @@
     public void testPortBetweenStartAndEnd() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
-        assertTrue(QosSocketFilter.matchesLocalAddress(
+        assertTrue(QosSocketFilter.matchesAddress( // port 10 is in [8, 18]
                 new InetSocketAddress(addressA, 10), addressB, 8, 18));
     }
 
@@ -68,7 +68,7 @@
     public void testAddressesDontMatch() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.5");
-        assertFalse(QosSocketFilter.matchesLocalAddress(
+        assertFalse(QosSocketFilter.matchesAddress( // port matches, address differs
                 new InetSocketAddress(addressA, 10), addressB, 10, 10));
     }
 }