Merge "service: Implement producer and data source unregistration"
diff --git a/src/tracing/core/service_impl.cc b/src/tracing/core/service_impl.cc
index 6f03839..491291c 100644
--- a/src/tracing/core/service_impl.cc
+++ b/src/tracing/core/service_impl.cc
@@ -106,6 +106,16 @@
   PERFETTO_DCHECK_THREAD(thread_checker_);
   PERFETTO_DLOG("Producer %" PRIu64 " disconnected", id);
   PERFETTO_DCHECK(producers_.count(id));
+
+  // Unregister every data source owned by this producer, which also tears
+  // down their active instances. |it| is advanced before the call because
+  // UnregisterDataSource() erases the matching entry from |data_sources_|.
+  for (auto it = data_sources_.begin(); it != data_sources_.end();) {
+    const auto current = it++;
+    if (current->second.producer_id == id)
+      UnregisterDataSource(id, current->second.data_source_id);
+  }
+
   producers_.erase(id);
 }
 
@@ -244,15 +253,10 @@
   for (const TraceConfig::DataSource& cfg_data_source : cfg.data_sources()) {
     // Scan all the registered data sources with a matching name.
     auto range = data_sources_.equal_range(cfg_data_source.config().name());
-    for (auto it = range.first; it != range.second; it++) {
-      const RegisteredDataSource& reg_data_source = it->second;
-      ProducerEndpointImpl* producer = GetProducer(reg_data_source.producer_id);
-      if (!producer) {
-        PERFETTO_DCHECK(false);  // Something in the unregistration is broken.
-        continue;
-      }
-      CreateDataSourceInstanceForProducer(cfg_data_source, producer, &ts);
-    }
+    for (auto it = range.first; it != range.second; ++it) {
+      const RegisteredDataSource& reg_data_source = it->second;
+      CreateDataSourceInstance(cfg_data_source, reg_data_source, &ts);
+    }
   }

   // Trigger delayed task if the trace is time limited.
@@ -288,10 +290,9 @@
 
   for (const auto& data_source_inst : tracing_session->data_source_instances) {
     const ProducerID producer_id = data_source_inst.first;
-    const DataSourceInstanceID ds_inst_id = data_source_inst.second;
+    const DataSourceInstanceID ds_inst_id = data_source_inst.second.instance_id;
     ProducerEndpointImpl* producer = GetProducer(producer_id);
-    if (!producer)
-      continue;  // This could legitimately happen if a Producer disconnects.
+    PERFETTO_DCHECK(producer);  // Disconnected producers have no instances.
     producer->producer_->TearDownDataSourceInstance(ds_inst_id);
   }
   tracing_session->data_source_instances.clear();
@@ -439,8 +440,8 @@
                 producer_id, desc.name().c_str(), ds_id);
 
   PERFETTO_DCHECK(!desc.name().empty());
-  data_sources_.emplace(desc.name(),
-                        RegisteredDataSource{producer_id, ds_id, desc});
+  auto reg_ds = data_sources_.emplace(  // Iterator to the new entry.
+      desc.name(), RegisteredDataSource{producer_id, ds_id, desc});
 
   // If there are existing tracing sessions, we need to check if the new
   // data source is enabled by any of them.
@@ -458,17 +459,56 @@
     for (const TraceConfig::DataSource& cfg_data_source :
          tracing_session.config.data_sources()) {
       if (cfg_data_source.config().name() == desc.name())
-        CreateDataSourceInstanceForProducer(cfg_data_source, producer,
-                                            &tracing_session);
+        CreateDataSourceInstance(cfg_data_source, reg_ds->second,
+                                 &tracing_session);
     }
   }
 }
 
-void ServiceImpl::CreateDataSourceInstanceForProducer(
+// Removes the data source |ds_id| registered by |producer_id| and tears down
+// every instance of it that is active in a tracing session.
+void ServiceImpl::UnregisterDataSource(ProducerID producer_id,
+                                       DataSourceID ds_id) {
+  PERFETTO_DCHECK_THREAD(thread_checker_);
+  PERFETTO_CHECK(producer_id);
+  PERFETTO_CHECK(ds_id);
+  ProducerEndpointImpl* producer = GetProducer(producer_id);
+  PERFETTO_DCHECK(producer);
+  for (auto& session : tracing_sessions_) {
+    // Instances are keyed by ProducerID: only visit those of |producer_id|.
+    auto& instances = session.second.data_source_instances;
+    auto range = instances.equal_range(producer_id);
+    for (auto it = range.first; it != range.second;) {
+      if (it->second.data_source_id == ds_id) {
+        producer->producer_->TearDownDataSourceInstance(it->second.instance_id);
+        it = instances.erase(it);
+      } else {
+        ++it;
+      }
+    }
+  }
+
+  // |data_sources_| is keyed by name, hence the linear scan.
+  for (auto it = data_sources_.begin(); it != data_sources_.end(); ++it) {
+    if (it->second.producer_id == producer_id &&
+        it->second.data_source_id == ds_id) {
+      data_sources_.erase(it);
+      return;
+    }
+  }
+  PERFETTO_DLOG("Tried to unregister a non-existent data source %" PRIu64
+                " for producer %" PRIu64,
+                ds_id, producer_id);
+  PERFETTO_DCHECK(false);
+}
+
+void ServiceImpl::CreateDataSourceInstance(
     const TraceConfig::DataSource& cfg_data_source,
-    ProducerEndpointImpl* producer,
+    const RegisteredDataSource& data_source,
     TracingSession* tracing_session) {
   PERFETTO_DCHECK_THREAD(thread_checker_);
+  ProducerEndpointImpl* producer = GetProducer(data_source.producer_id);
+  PERFETTO_DCHECK(producer);
   // TODO(primiano): match against |producer_name_filter| and add tests
   // for registration ordering (data sources vs consumers).
 
@@ -494,7 +529,8 @@
   ds_config.set_target_buffer(global_id);
 
   DataSourceInstanceID inst_id = ++last_data_source_instance_id_;
-  tracing_session->data_source_instances.emplace(producer->id_, inst_id);
+  const DataSourceInstance instance{inst_id, data_source.data_source_id};
+  tracing_session->data_source_instances.emplace(producer->id_, instance);
   PERFETTO_DLOG("Starting data source %s with target buffer %" PRIu16,
                 ds_config.name().c_str(), global_id);
   producer->producer_->CreateDataSourceInstance(inst_id, ds_config);
@@ -553,8 +589,8 @@
     : service_(service), consumer_(consumer), weak_ptr_factory_(this) {}
 
 ServiceImpl::ConsumerEndpointImpl::~ConsumerEndpointImpl() {
-  consumer_->OnDisconnect();
   service_->DisconnectConsumer(this);
+  consumer_->OnDisconnect();  // Notify after DisconnectConsumer() returns.
 }
 
 void ServiceImpl::ConsumerEndpointImpl::EnableTracing(const TraceConfig& cfg) {
@@ -621,8 +657,8 @@
 }
 
 ServiceImpl::ProducerEndpointImpl::~ProducerEndpointImpl() {
-  producer_->OnDisconnect();
   service_->DisconnectProducer(id_);
+  producer_->OnDisconnect();  // After its instances have been torn down.
 }
 
 void ServiceImpl::ProducerEndpointImpl::RegisterDataSource(
@@ -640,10 +676,10 @@
 }
 
 void ServiceImpl::ProducerEndpointImpl::UnregisterDataSource(
-    DataSourceID dsid) {
+    DataSourceID ds_id) {
   PERFETTO_DCHECK_THREAD(thread_checker_);
-  PERFETTO_CHECK(dsid);
-  // TODO(primiano): implement the bookkeeping logic.
+  PERFETTO_CHECK(ds_id);
+  service_->UnregisterDataSource(id_, ds_id);  // Also tears down instances.
 }
 
 void ServiceImpl::ProducerEndpointImpl::NotifySharedMemoryUpdate(
diff --git a/src/tracing/core/service_impl.h b/src/tracing/core/service_impl.h
index cb1ab16..d4b4947 100644
--- a/src/tracing/core/service_impl.h
+++ b/src/tracing/core/service_impl.h
@@ -123,6 +123,7 @@
   void RegisterDataSource(ProducerID,
                           DataSourceID,
                           const DataSourceDescriptor&);
+  void UnregisterDataSource(ProducerID, DataSourceID);
   void CopyProducerPageIntoLogBuffer(ProducerID,
                                      BufferID,
                                      const uint8_t*,
@@ -155,6 +156,12 @@
     DataSourceDescriptor descriptor;
   };
 
+  // A data source instance started on a producer for a tracing session.
+  struct DataSourceInstance {
+    DataSourceInstanceID instance_id;  // Sent to the producer.
+    DataSourceID data_source_id;  // ID of the RegisteredDataSource.
+  };
+
   struct TraceBuffer {
     TraceBuffer();
     ~TraceBuffer();
@@ -208,7 +215,7 @@
 
     // List of data source instances that have been enabled on the various
     // producers for this tracing session.
-    std::multimap<ProducerID, DataSourceInstanceID> data_source_instances;
+    std::multimap<ProducerID, DataSourceInstance> data_source_instances;
 
     // Maps a per-trace-session buffer index into the corresponding global
     // BufferID (shared namespace amongst all consumers). This vector has as
@@ -219,10 +226,9 @@
   ServiceImpl(const ServiceImpl&) = delete;
   ServiceImpl& operator=(const ServiceImpl&) = delete;
 
-  void CreateDataSourceInstanceForProducer(
-      const TraceConfig::DataSource& cfg_data_source,
-      ProducerEndpointImpl* producer,
-      TracingSession* tracing_session);
+  void CreateDataSourceInstance(const TraceConfig::DataSource&,
+                                const RegisteredDataSource&,
+                                TracingSession*);
 
   // Returns a pointer to the |tracing_sessions_| entry or nullptr if the
   // session doesn't exists.
diff --git a/src/tracing/core/service_impl_unittest.cc b/src/tracing/core/service_impl_unittest.cc
index 8f13777..3ab0b11 100644
--- a/src/tracing/core/service_impl_unittest.cc
+++ b/src/tracing/core/service_impl_unittest.cc
@@ -20,10 +20,12 @@
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "perfetto/tracing/core/consumer.h"
 #include "perfetto/tracing/core/data_source_config.h"
 #include "perfetto/tracing/core/data_source_descriptor.h"
 #include "perfetto/tracing/core/producer.h"
 #include "perfetto/tracing/core/shared_memory.h"
+#include "perfetto/tracing/core/trace_packet.h"
 #include "src/base/test/test_task_runner.h"
 #include "src/tracing/test/test_shared_memory.h"
 
@@ -46,14 +48,34 @@
   MOCK_METHOD1(TearDownDataSourceInstance, void(DataSourceInstanceID));
 };
 
+class MockConsumer : public Consumer {
+ public:
+  ~MockConsumer() override {}
+
+  // Consumer implementation. Received trace packets are ignored.
+  MOCK_METHOD0(OnConnect, void());
+  MOCK_METHOD0(OnDisconnect, void());
+
+  void OnTraceData(std::vector<TracePacket> packets, bool has_more) override {}
+};
+
 }  // namespace
 
-TEST(ServiceImplTest, RegisterAndUnregister) {
+class ServiceImplTest : public testing::Test {  // Fresh ServiceImpl per test.
+ public:
+  ServiceImplTest() {
+    auto shm_factory =
+        std::unique_ptr<SharedMemory::Factory>(new TestSharedMemory::Factory());
+    svc.reset(static_cast<ServiceImpl*>(
+        Service::CreateInstance(std::move(shm_factory), &task_runner)
+            .release()));
+  }
+
   base::TestTaskRunner task_runner;
-  auto shm_factory =
-      std::unique_ptr<SharedMemory::Factory>(new TestSharedMemory::Factory());
-  std::unique_ptr<ServiceImpl> svc(static_cast<ServiceImpl*>(
-      Service::CreateInstance(std::move(shm_factory), &task_runner).release()));
+  std::unique_ptr<ServiceImpl> svc;  // Destroyed before |task_runner|.
+};
+
+TEST_F(ServiceImplTest, RegisterAndUnregister) {
   MockProducer mock_producer_1;
   MockProducer mock_producer_2;
   std::unique_ptr<Service::ProducerEndpoint> producer_endpoint_1 =
@@ -78,7 +100,7 @@
   DataSourceDescriptor ds_desc1;
   ds_desc1.set_name("foo");
   producer_endpoint_1->RegisterDataSource(
-      ds_desc1, [&task_runner, &producer_endpoint_1](DataSourceID id) {
+      ds_desc1, [this, &producer_endpoint_1](DataSourceID id) {
         EXPECT_EQ(1u, id);
         task_runner.PostTask(
             std::bind(&Service::ProducerEndpoint::UnregisterDataSource,
@@ -88,7 +110,7 @@
   DataSourceDescriptor ds_desc2;
   ds_desc2.set_name("bar");
   producer_endpoint_2->RegisterDataSource(
-      ds_desc2, [&task_runner, &producer_endpoint_2](DataSourceID id) {
+      ds_desc2, [this, &producer_endpoint_2](DataSourceID id) {
         EXPECT_EQ(1u, id);
         task_runner.PostTask(
             std::bind(&Service::ProducerEndpoint::UnregisterDataSource,
@@ -113,4 +135,186 @@
   ASSERT_EQ(0u, svc->num_producers());
 }
 
+TEST_F(ServiceImplTest, EnableAndDisableTracing) {
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});
+
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");
+  ds_config->set_target_buffer(0);
+  consumer_endpoint->EnableTracing(trace_config);
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  consumer_endpoint->DisableTracing();
+  producer_endpoint.reset();
+  consumer_endpoint.reset();
+  task_runner.RunUntilIdle();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
+TEST_F(ServiceImplTest, DisconnectConsumerWhileTracing) {
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});
+  task_runner.RunUntilIdle();
+
+  // Disconnecting the consumer while tracing should trigger data source
+  // teardown.
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");
+  ds_config->set_target_buffer(0);
+  consumer_endpoint->EnableTracing(trace_config);
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  consumer_endpoint.reset();
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  producer_endpoint.reset();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
+TEST_F(ServiceImplTest, ReconnectProducerWhileTracing) {
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});
+  task_runner.RunUntilIdle();
+
+  // Disconnecting the producer while tracing should trigger data source
+  // teardown.
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");
+  ds_config->set_target_buffer(0);
+  consumer_endpoint->EnableTracing(trace_config);
+  producer_endpoint.reset();
+  task_runner.RunUntilIdle();
+
+  // Reconnecting a producer with a matching data source should see that data
+  // source getting enabled.
+  EXPECT_CALL(mock_producer, OnConnect());
+  producer_endpoint = svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  task_runner.RunUntilIdle();
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  producer_endpoint->RegisterDataSource(ds_desc, [](DataSourceID) {});
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  consumer_endpoint->DisableTracing();
+  consumer_endpoint.reset();
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  producer_endpoint.reset();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
+TEST_F(ServiceImplTest, UnregisterDataSourceWhileTracing) {
+  MockProducer mock_producer;
+  std::unique_ptr<Service::ProducerEndpoint> producer_endpoint =
+      svc->ConnectProducer(&mock_producer, 123u /* uid */);
+  MockConsumer mock_consumer;
+  std::unique_ptr<Service::ConsumerEndpoint> consumer_endpoint =
+      svc->ConnectConsumer(&mock_consumer);
+
+  InSequence seq;
+  EXPECT_CALL(mock_producer, OnConnect());
+  EXPECT_CALL(mock_consumer, OnConnect());
+  task_runner.RunUntilIdle();
+
+  DataSourceDescriptor ds_desc;
+  ds_desc.set_name("foo");
+  DataSourceID ds_id = 0;
+  producer_endpoint->RegisterDataSource(
+      ds_desc, [&ds_id](DataSourceID id) { ds_id = id; });
+  task_runner.RunUntilIdle();
+  ASSERT_NE(0u, ds_id);
+
+  // Unregistering a data source while tracing should tear down its instance.
+  EXPECT_CALL(mock_producer, CreateDataSourceInstance(_, _));
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_));
+  TraceConfig trace_config;
+  trace_config.add_buffers()->set_size_kb(4096 * 10);
+  auto* ds_config = trace_config.add_data_sources()->mutable_config();
+  ds_config->set_name("foo");
+  ds_config->set_target_buffer(0);
+  consumer_endpoint->EnableTracing(trace_config);
+  task_runner.RunUntilIdle();
+  producer_endpoint->UnregisterDataSource(ds_id);
+  task_runner.RunUntilIdle();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+
+  // The instance is already gone, so disabling tracing must not tear it down
+  // a second time.
+  EXPECT_CALL(mock_producer, TearDownDataSourceInstance(_)).Times(0);
+  consumer_endpoint->DisableTracing();
+  task_runner.RunUntilIdle();
+
+  EXPECT_CALL(mock_consumer, OnDisconnect());
+  EXPECT_CALL(mock_producer, OnDisconnect());
+  consumer_endpoint.reset();
+  producer_endpoint.reset();
+  task_runner.RunUntilIdle();
+  Mock::VerifyAndClearExpectations(&mock_producer);
+  Mock::VerifyAndClearExpectations(&mock_consumer);
+}
+
 }  // namespace perfetto