Fix replying to queries via unicast

Instead of (wrongly) looking at the DNS message flags for the Query
Unicast bit, look at the flag in the question record.

Because of the previous bug, all queries were answered via multicast,
even when a unicast reply was requested. The fix is flagged off, as it
is a behavior change that may affect performance and latency.

Bug: 289482497
Test: atest
Change-Id: I08e30c4ffa747c9073d631e8addca1278ea40648
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 6c25d76..9b2f80b 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -176,7 +176,7 @@
             "mdns_advertiser_allowlist_";
     private static final String MDNS_ALLOWLIST_FLAG_SUFFIX = "_version";
 
-
+    private static final String FORCE_ENABLE_FLAG_FOR_TEST_PREFIX = "test_";
 
     @VisibleForTesting
     static final String MDNS_CONFIG_RUNNING_APP_ACTIVE_IMPORTANCE_CUTOFF =
@@ -1739,6 +1739,10 @@
                         mContext, MdnsFeatureFlags.NSD_LIMIT_LABEL_COUNT))
                 .setIsKnownAnswerSuppressionEnabled(mDeps.isFeatureEnabled(
                         mContext, MdnsFeatureFlags.NSD_KNOWN_ANSWER_SUPPRESSION))
+                .setIsUnicastReplyEnabled(mDeps.isFeatureEnabled(
+                        mContext, MdnsFeatureFlags.NSD_UNICAST_REPLY_ENABLED))
+                .setOverrideProvider(flag -> mDeps.isFeatureEnabled(
+                        mContext, FORCE_ENABLE_FLAG_FOR_TEST_PREFIX + flag))
                 .build();
         mMdnsSocketClient =
                 new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider,
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
index 1ad47a3..9466162 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsFeatureFlags.java
@@ -15,6 +15,9 @@
  */
 package com.android.server.connectivity.mdns;
 
+import android.annotation.NonNull;
+import android.annotation.Nullable;
+
 /**
  * The class that contains mDNS feature flags;
  */
@@ -46,6 +49,14 @@
      */
     public static final String NSD_KNOWN_ANSWER_SUPPRESSION = "nsd_known_answer_suppression";
 
+    /**
+     * A feature flag to control whether unicast replies should be enabled.
+     *
+     * <p>Enabling this feature causes replies to queries with the Query Unicast (QU) flag set to be
+     * sent unicast instead of multicast, as per RFC6762 5.4.
+     */
+    public static final String NSD_UNICAST_REPLY_ENABLED = "nsd_unicast_reply_enabled";
+
     // Flag for offload feature
     public final boolean mIsMdnsOffloadFeatureEnabled;
 
@@ -61,6 +72,36 @@
     // Flag for known-answer suppression
     public final boolean mIsKnownAnswerSuppressionEnabled;
 
+    // Flag to enable replying unicast to queries requesting unicast replies
+    public final boolean mIsUnicastReplyEnabled;
+
+    @Nullable
+    private final FlagOverrideProvider mOverrideProvider;
+
+    /**
+     * Callback queried with a flag name to decide whether that flag is force-enabled for tests.
+     */
+    public interface FlagOverrideProvider {
+        /**
+         * Returns true if the flag with the given name should be force-enabled for testing.
+         */
+        boolean isForceEnabledForTest(@NonNull String flag);
+    }
+
+    /**
+     * Returns whether the override provider, if one is set, force-enables the flag for testing.
+     */
+    private boolean isForceEnabledForTest(@NonNull String flag) {
+        return mOverrideProvider != null && mOverrideProvider.isForceEnabledForTest(flag);
+    }
+
+    /**
+     * Returns true if {@link #NSD_UNICAST_REPLY_ENABLED} is set or is force-enabled for testing.
+     */
+    public boolean isUnicastReplyEnabled() {
+        return mIsUnicastReplyEnabled || isForceEnabledForTest(NSD_UNICAST_REPLY_ENABLED);
+    }
+
     /**
      * The constructor for {@link MdnsFeatureFlags}.
      */
@@ -68,12 +109,16 @@
             boolean includeInetAddressRecordsInProbing,
             boolean isExpiredServicesRemovalEnabled,
             boolean isLabelCountLimitEnabled,
-            boolean isKnownAnswerSuppressionEnabled) {
+            boolean isKnownAnswerSuppressionEnabled,
+            boolean isUnicastReplyEnabled,
+            @Nullable FlagOverrideProvider overrideProvider) {
         mIsMdnsOffloadFeatureEnabled = isOffloadFeatureEnabled;
         mIncludeInetAddressRecordsInProbing = includeInetAddressRecordsInProbing;
         mIsExpiredServicesRemovalEnabled = isExpiredServicesRemovalEnabled;
         mIsLabelCountLimitEnabled = isLabelCountLimitEnabled;
         mIsKnownAnswerSuppressionEnabled = isKnownAnswerSuppressionEnabled;
+        mIsUnicastReplyEnabled = isUnicastReplyEnabled;
+        mOverrideProvider = overrideProvider;
     }
 
 
@@ -90,6 +135,8 @@
         private boolean mIsExpiredServicesRemovalEnabled;
         private boolean mIsLabelCountLimitEnabled;
         private boolean mIsKnownAnswerSuppressionEnabled;
+        private boolean mIsUnicastReplyEnabled;
+        private FlagOverrideProvider mOverrideProvider;
 
         /**
          * The constructor for {@link Builder}.
@@ -100,6 +147,8 @@
             mIsExpiredServicesRemovalEnabled = false;
             mIsLabelCountLimitEnabled = true; // Default enabled.
             mIsKnownAnswerSuppressionEnabled = false;
+            mIsUnicastReplyEnabled = true;
+            mOverrideProvider = null;
         }
 
         /**
@@ -154,6 +203,27 @@
         }
 
         /**
+         * Sets whether unicast replies are enabled; returns this builder for call chaining.
+         *
+         * @see #NSD_UNICAST_REPLY_ENABLED
+         */
+        public Builder setIsUnicastReplyEnabled(boolean isUnicastReplyEnabled) {
+            mIsUnicastReplyEnabled = isUnicastReplyEnabled;
+            return this;
+        }
+
+        /**
+         * Set a {@link FlagOverrideProvider} to be used by {@link #isForceEnabledForTest(String)}.
+         *
+         * <p>If non-null, features that use {@link #isForceEnabledForTest(String)} will use that
+         * provider to query whether the flag should be force-enabled; null disables overrides.
+         */
+        public Builder setOverrideProvider(@Nullable FlagOverrideProvider overrideProvider) {
+            mOverrideProvider = overrideProvider;
+            return this;
+        }
+
+        /**
          * Builds a {@link MdnsFeatureFlags} with the arguments supplied to this builder.
          */
         public MdnsFeatureFlags build() {
@@ -161,7 +231,9 @@
                     mIncludeInetAddressRecordsInProbing,
                     mIsExpiredServicesRemovalEnabled,
                     mIsLabelCountLimitEnabled,
-                    mIsKnownAnswerSuppressionEnabled);
+                    mIsKnownAnswerSuppressionEnabled,
+                    mIsUnicastReplyEnabled,
+                    mOverrideProvider);
         }
     }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
index 28bd1b4..4b43989 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecord.java
@@ -176,6 +176,16 @@
     }
 
     /**
+     * For questions, returns whether a unicast reply was requested.
+     *
+     * <p>In practice this is identical to {@link #getCacheFlush()}, as the "cache flush" flag in
+     * replies is the same bit ({@code QCLASS_UNICAST}) as "unicast reply requested" in questions.
+     */
+    public final boolean isUnicastReplyRequested() {
+        return (cls & MdnsConstants.QCLASS_UNICAST) != 0;
+    }
+
+    /**
      * Returns the record's remaining TTL.
      *
      * If the record was not sent yet (receipt time {@link #RECEIPT_TIME_NOT_SENT}), this is the
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
index 6b6632c..585b097 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsRecordRepository.java
@@ -499,26 +499,30 @@
     @Nullable
     public MdnsReplyInfo getReply(MdnsPacket packet, InetSocketAddress src) {
         final long now = SystemClock.elapsedRealtime();
-        final boolean replyUnicast = (packet.flags & MdnsConstants.QCLASS_UNICAST) != 0;
 
         // Use LinkedHashSet for preserving the insert order of the RRs, so that RRs of the same
         // service or host are grouped together (which is more developer-friendly).
         final Set<RecordInfo<?>> answerInfo = new LinkedHashSet<>();
         final Set<RecordInfo<?>> additionalAnswerInfo = new LinkedHashSet<>();
-
+        // Reply unicast if the feature is enabled AND all replied questions request unicast
+        final boolean replyUnicastEnabled = mMdnsFeatureFlags.isUnicastReplyEnabled();
+        boolean replyUnicast = replyUnicastEnabled;
         for (MdnsRecord question : packet.questions) {
             // Add answers from general records
-            addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
-                    null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicast, now,
-                    answerInfo, additionalAnswerInfo, Collections.emptyList());
+            if (addReplyFromService(question, mGeneralRecords, null /* servicePtrRecord */,
+                    null /* serviceSrvRecord */, null /* serviceTxtRecord */, replyUnicastEnabled,
+                    now, answerInfo, additionalAnswerInfo, Collections.emptyList())) {
+                replyUnicast &= question.isUnicastReplyRequested();
+            }
 
             // Add answers from each service
             for (int i = 0; i < mServices.size(); i++) {
                 final ServiceRegistration registration = mServices.valueAt(i);
                 if (registration.exiting || registration.isProbing) continue;
                 if (addReplyFromService(question, registration.allRecords, registration.ptrRecords,
-                        registration.srvRecord, registration.txtRecord, replyUnicast, now,
+                        registration.srvRecord, registration.txtRecord, replyUnicastEnabled, now,
                         answerInfo, additionalAnswerInfo, packet.answers)) {
+                    replyUnicast &= question.isUnicastReplyRequested();
                     registration.repliedServiceCount++;
                     registration.sentPacketCount++;
                 }
@@ -570,6 +574,12 @@
         // Determine the send destination
         final InetSocketAddress dest;
         if (replyUnicast) {
+            // As per RFC6762 5.4, "if the responder has not multicast that record recently (within
+            // one quarter of its TTL), then the responder SHOULD instead multicast the response so
+            // as to keep all the peer caches up to date". This SHOULD is deliberately not
+            // implemented, to minimize latency for queriers that have just started and therefore
+            // did not receive previous multicast responses. Unicast replies are faster as they do
+            // not need to wait for the beacon interval on Wi-Fi.
             dest = src;
         } else if (src.getAddress() instanceof Inet4Address) {
             dest = IPV4_SOCKET_ADDR;
@@ -608,7 +618,7 @@
             @Nullable List<RecordInfo<MdnsPointerRecord>> servicePtrRecords,
             @Nullable RecordInfo<MdnsServiceRecord> serviceSrvRecord,
             @Nullable RecordInfo<MdnsTextRecord> serviceTxtRecord,
-            boolean replyUnicast, long now, @NonNull Set<RecordInfo<?>> answerInfo,
+            boolean replyUnicastEnabled, long now, @NonNull Set<RecordInfo<?>> answerInfo,
             @NonNull Set<RecordInfo<?>> additionalAnswerInfo,
             @NonNull List<MdnsRecord> knownAnswerRecords) {
         boolean hasDnsSdPtrRecordAnswer = false;
@@ -659,7 +669,8 @@
 
             // TODO: responses to probe queries should bypass this check and only ensure the
             // reply is sent 250ms after the last sent time (RFC 6762 p.15)
-            if (!replyUnicast && info.lastAdvertisedTimeMs > 0L
+            if (!(replyUnicastEnabled && question.isUnicastReplyRequested())
+                    && info.lastAdvertisedTimeMs > 0L
                     && now - info.lastAdvertisedTimeMs < MIN_MULTICAST_REPLY_INTERVAL_MS) {
                 continue;
             }