Merge "Attempt to fix NetworkTracePollerTest flake"
diff --git a/service-t/native/libs/libnetworkstats/NetworkTracePollerTest.cpp b/service-t/native/libs/libnetworkstats/NetworkTracePollerTest.cpp
index 725cec1..df07bbe 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTracePollerTest.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTracePollerTest.cpp
@@ -26,6 +26,8 @@
 #include <sys/types.h>
 #include <unistd.h>
 
+#include <chrono>
+#include <thread>
 #include <vector>
 
 #include "netdbpf/NetworkTracePoller.h"
@@ -130,16 +132,21 @@
 
 TEST_F(NetworkTracePollerTest, TraceTcpSession) {
   __be16 server_port = 0;
-  std::vector<PacketTrace> packets;
+  std::vector<PacketTrace> packets, unmatched;
 
   // Record all packets with the bound address and current uid. This callback is
   // involked only within ConsumeAll, at which point the port should have
   // already been filled in and all packets have been processed.
   NetworkTracePoller handler([&](const std::vector<PacketTrace>& pkts) {
     for (const PacketTrace& pkt : pkts) {
-      if (pkt.sport != server_port && pkt.dport != server_port) return;
-      if (pkt.uid != getuid()) return;
-      packets.push_back(pkt);
+      if ((pkt.sport == server_port || pkt.dport == server_port) &&
+          pkt.uid == getuid()) {
+        packets.push_back(pkt);
+      } else {
+        // There may be spurious packets not caused by the test. These are only
+        // captured so that we can report them to help debug certain errors.
+        unmatched.push_back(pkt);
+      }
     }
   });
 
@@ -179,7 +186,13 @@
     EXPECT_EQ(std::string(data), std::string(buff));
   }
 
-  ASSERT_TRUE(handler.ConsumeAll());
+  // Poll until we get all the packets (typically we get it first try).
+  for (int attempt = 0; attempt < 10; attempt++) {
+    ASSERT_TRUE(handler.ConsumeAll());
+    if (packets.size() >= 12) break;
+    std::this_thread::sleep_for(std::chrono::milliseconds(5));
+  }
+
   ASSERT_TRUE(handler.Stop());
 
   // There are 12 packets in total (6 messages: each seen by client & server):
@@ -189,7 +202,9 @@
   // 4. Client sends data with psh ack
   // 5. Server acks the data packet
   // 6. Client closes connection with fin ack
-  ASSERT_EQ(packets.size(), 12) << PacketPrinter{packets};
+  ASSERT_EQ(packets.size(), 12)
+      << PacketPrinter{packets}
+      << "\nUnmatched packets: " << PacketPrinter{unmatched};
 
   // All packets should be TCP packets.
   EXPECT_THAT(packets, Each(Field(&PacketTrace::ipProto, Eq(IPPROTO_TCP))));