Merge "Remove PacketRepeater destinationsSupplier logic"
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
index 91e08a8..c056e69 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsAnnouncer.java
@@ -22,10 +22,8 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 
-import java.net.SocketAddress;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Supplier;
 
 /**
  * Sends mDns announcements when a service registration changes and at regular intervals.
@@ -43,11 +41,8 @@
     static class AnnouncementInfo implements MdnsPacketRepeater.Request {
         @NonNull
         private final MdnsPacket mPacket;
-        @NonNull
-        private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
 
-        AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords,
-                Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+        AnnouncementInfo(List<MdnsRecord> announcedRecords, List<MdnsRecord> additionalRecords) {
             // Records to announce (as answers)
             // Records to place in the "Additional records", with NSEC negative responses
             // to mark records that have been verified unique
@@ -57,7 +52,6 @@
                     announcedRecords,
                     Collections.emptyList() /* authorityRecords */,
                     additionalRecords);
-            mDestinationsSupplier = destinationsSupplier;
         }
 
         @Override
@@ -66,11 +60,6 @@
         }
 
         @Override
-        public Iterable<SocketAddress> getDestinations(int index) {
-            return mDestinationsSupplier.get();
-        }
-
-        @Override
         public long getDelayMs(int nextIndex) {
             // Delay is doubled for each announcement
             return ANNOUNCEMENT_INITIAL_DELAY_MS << (nextIndex - 1);
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
index 015dbd8..ae54e70 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsPacketRepeater.java
@@ -24,7 +24,7 @@
 import android.util.Log;
 
 import java.io.IOException;
-import java.net.SocketAddress;
+import java.net.InetSocketAddress;
 
 /**
  * A class used to send several packets at given time intervals.
@@ -32,6 +32,14 @@
  */
 public abstract class MdnsPacketRepeater<T extends MdnsPacketRepeater.Request> {
     private static final boolean DBG = MdnsAdvertiser.DBG;
+    private static final InetSocketAddress IPV4_ADDR = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv4Address(), MdnsConstants.MDNS_PORT);
+    private static final InetSocketAddress IPV6_ADDR = new InetSocketAddress(
+            MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT);
+    private static final InetSocketAddress[] ALL_ADDRS = new InetSocketAddress[] {
+            IPV4_ADDR, IPV6_ADDR
+    };
+
     @NonNull
     private final MdnsReplySender mReplySender;
     @NonNull
@@ -70,12 +78,6 @@
         MdnsPacket getPacket(int index);
 
         /**
-         * Get a set of destinations for the packet for one iteration.
-         */
-        @NonNull
-        Iterable<SocketAddress> getDestinations(int index);
-
-        /**
          * Get the delay in milliseconds until the next packet transmission.
          */
         long getDelayMs(int nextIndex);
@@ -110,12 +112,13 @@
             }
 
             final MdnsPacket packet = request.getPacket(index);
-            final Iterable<SocketAddress> destinations = request.getDestinations(index);
             if (DBG) {
-                Log.v(getTag(), "Sending packets to " + destinations + " for iteration "
-                        + index + " out of " + request.getNumSends());
+                Log.v(getTag(), "Sending packets for iteration " + index + " out of "
+                        + request.getNumSends());
             }
-            for (SocketAddress destination : destinations) {
+            // Send to both v4 and v6 addresses; the reply sender will take care of ignoring the
+            // send when the socket has not joined the relevant group.
+            for (InetSocketAddress destination : ALL_ADDRS) {
                 try {
                     mReplySender.sendNow(packet, destination);
                 } catch (IOException e) {
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
index db7049e..9a1e62b 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsProber.java
@@ -22,12 +22,10 @@
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.CollectionUtils;
 
-import java.net.SocketAddress;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Supplier;
 
 /**
  * Sends mDns probe requests to verify service records are unique on the network.
@@ -51,21 +49,15 @@
         private final int mServiceId;
         @NonNull
         private final MdnsPacket mPacket;
-        @NonNull
-        private final Supplier<Iterable<SocketAddress>> mDestinationsSupplier;
 
         /**
          * Create a new ProbingInfo
          * @param serviceId Service to probe for.
          * @param probeRecords Records to be probed for uniqueness.
-         * @param destinationsSupplier Supplier for the probe destinations. Will be called on the
-         *                             probe handler thread for each probe.
          */
-        ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords,
-                @NonNull Supplier<Iterable<SocketAddress>> destinationsSupplier) {
+        ProbingInfo(int serviceId, @NonNull List<MdnsRecord> probeRecords) {
             mServiceId = serviceId;
             mPacket = makePacket(probeRecords);
-            mDestinationsSupplier = destinationsSupplier;
         }
 
         public int getServiceId() {
@@ -78,12 +70,6 @@
             return mPacket;
         }
 
-        @NonNull
-        @Override
-        public Iterable<SocketAddress> getDestinations(int index) {
-            return mDestinationsSupplier.get();
-        }
-
         @Override
         public long getDelayMs(int nextIndex) {
             // As per https://datatracker.ietf.org/doc/html/rfc6762#section-8.1
diff --git a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
index adf6f4d..c6b8f47 100644
--- a/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
+++ b/service/mdns/com/android/server/connectivity/mdns/MdnsReplySender.java
@@ -21,8 +21,10 @@
 
 import java.io.IOException;
 import java.net.DatagramPacket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetSocketAddress;
 import java.net.MulticastSocket;
-import java.net.SocketAddress;
 
 /**
  * A class that handles sending mDNS replies to a {@link MulticastSocket}, possibly queueing them
@@ -50,11 +52,16 @@
      *
      * Must be called on the looper thread used by the {@link MdnsReplySender}.
      */
-    public void sendNow(@NonNull MdnsPacket packet, @NonNull SocketAddress destination)
+    public void sendNow(@NonNull MdnsPacket packet, @NonNull InetSocketAddress destination)
             throws IOException {
         if (Thread.currentThread() != mLooper.getThread()) {
             throw new IllegalStateException("sendNow must be called in the handler thread");
         }
+        if (!((destination.getAddress() instanceof Inet6Address && mSocket.hasJoinedIpv6())
+                || (destination.getAddress() instanceof Inet4Address && mSocket.hasJoinedIpv4()))) {
+            // Skip sending if the socket has not joined the v4/v6 group (there was no address)
+            return;
+        }
 
         // TODO: support packets over size (send in multiple packets with TC bit set)
         final MdnsPacketWriter writer = new MdnsPacketWriter(mPacketCreationBuffer);
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
index 2051e0c..961f0f0 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAnnouncerTest.kt
@@ -27,7 +27,6 @@
 import java.net.DatagramPacket
 import java.net.Inet6Address
 import java.net.InetAddress
-import java.net.InetSocketAddress
 import kotlin.test.assertEquals
 import kotlin.test.assertTrue
 import org.junit.After
@@ -37,6 +36,7 @@
 import org.mockito.ArgumentCaptor
 import org.mockito.Mockito.any
 import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.timeout
 import org.mockito.Mockito.verify
@@ -46,9 +46,6 @@
 private const val NEXT_ANNOUNCES_DELAY = 1L
 private const val TEST_TIMEOUT_MS = 1000L
 
-private val destinationsSupplier = {
-    listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
-
 @RunWith(DevSdkIgnoreRunner::class)
 @IgnoreUpTo(Build.VERSION_CODES.S_V2)
 class MdnsAnnouncerTest {
@@ -59,6 +56,7 @@
 
     @Before
     fun setUp() {
+        doReturn(true).`when`(socket).hasJoinedIpv6()
         thread.start()
     }
 
@@ -70,7 +68,7 @@
     private class TestAnnouncementInfo(
         announcedRecords: List<MdnsRecord>,
         additionalRecords: List<MdnsRecord>
-    ) : AnnouncementInfo(announcedRecords, additionalRecords, destinationsSupplier) {
+    ) : AnnouncementInfo(announcedRecords, additionalRecords) {
         override fun getDelayMs(nextIndex: Int) =
                 if (nextIndex < FIRST_ANNOUNCES_COUNT) {
                     FIRST_ANNOUNCES_DELAY
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
index a98a4b2..3caa97d 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsProberTest.kt
@@ -25,7 +25,6 @@
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.DevSdkIgnoreRunner
 import java.net.DatagramPacket
-import java.net.InetSocketAddress
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.TimeUnit
 import kotlin.test.assertEquals
@@ -37,15 +36,13 @@
 import org.mockito.ArgumentCaptor
 import org.mockito.Mockito.any
 import org.mockito.Mockito.atLeast
+import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
 import org.mockito.Mockito.never
 import org.mockito.Mockito.timeout
 import org.mockito.Mockito.times
 import org.mockito.Mockito.verify
 
-private val destinationsSupplier = {
-    listOf(InetSocketAddress(MdnsConstants.getMdnsIPv6Address(), MdnsConstants.MDNS_PORT)) }
-
 private const val TEST_TIMEOUT_MS = 10_000L
 private const val SHORT_TIMEOUT_MS = 200L
 
@@ -64,6 +61,7 @@
 
     @Before
     fun setUp() {
+        doReturn(true).`when`(socket).hasJoinedIpv6()
         thread.start()
     }
 
@@ -73,7 +71,7 @@
     }
 
     private class TestProbeInfo(probeRecords: List<MdnsRecord>, private val delayMs: Long = 1L) :
-            ProbingInfo(1 /* serviceId */, probeRecords, destinationsSupplier) {
+            ProbingInfo(1 /* serviceId */, probeRecords) {
         // Just send the packets quickly. Timing-related tests for MdnsPacketRepeater are already
         // done in MdnsAnnouncerTest.
         override fun getDelayMs(nextIndex: Int) = delayMs