vulkan: refactor DebugReportCallbackList

Simplify DebugReportCallbackList to be a thread-safe list with three
methods:

 - AddCallback adds a node to the list
 - RemoveCallback removes a node from the list
 - Message invokes each of the nodes on the list

Add some static methods for Node* and VkDebugReportCallbackEXT
conversions.

Bug: 28120066
Change-Id: I109c6eff368cacb37508e2549dbd0b5dfa23bcb3
diff --git a/vulkan/libvulkan/debug_report.cpp b/vulkan/libvulkan/debug_report.cpp
index c4a1174..0388fdc 100644
--- a/vulkan/libvulkan/debug_report.cpp
+++ b/vulkan/libvulkan/debug_report.cpp
@@ -19,62 +19,37 @@
 namespace vulkan {
 namespace driver {
 
-VkResult DebugReportCallbackList::CreateCallback(
-    VkInstance instance,
-    const VkDebugReportCallbackCreateInfoEXT* create_info,
-    const VkAllocationCallbacks* allocator,
-    VkDebugReportCallbackEXT* callback) {
-    VkDebugReportCallbackEXT driver_callback = VK_NULL_HANDLE;
+DebugReportCallbackList::Node* DebugReportCallbackList::AddCallback(
+    const VkDebugReportCallbackCreateInfoEXT& info,
+    VkDebugReportCallbackEXT driver_handle,
+    const VkAllocationCallbacks& allocator) {
+    void* mem = allocator.pfnAllocation(allocator.pUserData, sizeof(Node),
+                                        alignof(Node),
+                                        VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
+    if (!mem)
+        return nullptr;
 
-    if (GetData(instance).driver.CreateDebugReportCallbackEXT) {
-        VkResult result = GetData(instance).driver.CreateDebugReportCallbackEXT(
-            instance, create_info, allocator, &driver_callback);
-        if (result != VK_SUCCESS)
-            return result;
-    }
-
-    const VkAllocationCallbacks* alloc =
-        allocator ? allocator : &GetData(instance).allocator;
-    void* mem =
-        alloc->pfnAllocation(alloc->pUserData, sizeof(Node), alignof(Node),
-                             VK_SYSTEM_ALLOCATION_SCOPE_OBJECT);
-    if (!mem) {
-        if (GetData(instance).driver.DestroyDebugReportCallbackEXT) {
-            GetData(instance).driver.DestroyDebugReportCallbackEXT(
-                instance, driver_callback, allocator);
-        }
-        return VK_ERROR_OUT_OF_HOST_MEMORY;
-    }
-
+    // initialize and prepend node to the list
     std::lock_guard<decltype(rwmutex_)> lock(rwmutex_);
-    head_.next =
-        new (mem) Node{head_.next, create_info->flags, create_info->pfnCallback,
-                       create_info->pUserData, driver_callback};
-    *callback =
-        VkDebugReportCallbackEXT(reinterpret_cast<uintptr_t>(head_.next));
-    return VK_SUCCESS;
+    head_.next = new (mem) Node{head_.next, info.flags, info.pfnCallback,
+                                info.pUserData, driver_handle};
+
+    return head_.next;
 }
 
-void DebugReportCallbackList::DestroyCallback(
-    VkInstance instance,
-    VkDebugReportCallbackEXT callback,
-    const VkAllocationCallbacks* allocator) {
-    Node* node = reinterpret_cast<Node*>(uintptr_t(callback));
-    std::unique_lock<decltype(rwmutex_)> lock(rwmutex_);
-    Node* prev = &head_;
-    while (prev && prev->next != node)
-        prev = prev->next;
-    prev->next = node->next;
-    lock.unlock();
-
-    if (GetData(instance).driver.DestroyDebugReportCallbackEXT) {
-        GetData(instance).driver.DestroyDebugReportCallbackEXT(
-            instance, node->driver_callback, allocator);
+void DebugReportCallbackList::RemoveCallback(
+    Node* node,
+    const VkAllocationCallbacks& allocator) {
+    // remove node from the list
+    {
+        std::lock_guard<decltype(rwmutex_)> lock(rwmutex_);
+        Node* prev = &head_;
+        while (prev && prev->next != node)
+            prev = prev->next;
+        prev->next = node->next;
     }
 
-    const VkAllocationCallbacks* alloc =
-        allocator ? allocator : &GetData(instance).allocator;
-    alloc->pfnFree(alloc->pUserData, node);
+    allocator.pfnFree(allocator.pUserData, node);
 }
 
 void DebugReportCallbackList::Message(VkDebugReportFlagsEXT flags,
@@ -89,7 +64,7 @@
     while ((node = node->next)) {
         if ((node->flags & flags) != 0) {
             node->callback(flags, object_type, object, location, message_code,
-                           layer_prefix, message, node->data);
+                           layer_prefix, message, node->user_data);
         }
     }
 }
@@ -99,16 +74,50 @@
     const VkDebugReportCallbackCreateInfoEXT* create_info,
     const VkAllocationCallbacks* allocator,
     VkDebugReportCallbackEXT* callback) {
-    return GetData(instance).debug_report_callbacks.CreateCallback(
-        instance, create_info, allocator, callback);
+    const auto& driver = GetData(instance).driver;
+    VkDebugReportCallbackEXT driver_handle = VK_NULL_HANDLE;
+    if (driver.CreateDebugReportCallbackEXT) {
+        VkResult result = driver.CreateDebugReportCallbackEXT(
+            instance, create_info, allocator, &driver_handle);
+        if (result != VK_SUCCESS)
+            return result;
+    }
+
+    auto& callbacks = GetData(instance).debug_report_callbacks;
+    auto node = callbacks.AddCallback(
+        *create_info, driver_handle,
+        (allocator) ? *allocator : GetData(instance).allocator);
+    if (!node) {
+        if (driver_handle != VK_NULL_HANDLE) {
+            driver.DestroyDebugReportCallbackEXT(instance, driver_handle,
+                                                 allocator);
+        }
+
+        return VK_ERROR_OUT_OF_HOST_MEMORY;
+    }
+
+    *callback = callbacks.GetHandle(node);
+
+    return VK_SUCCESS;
 }
 
 void DestroyDebugReportCallbackEXT(VkInstance instance,
                                    VkDebugReportCallbackEXT callback,
                                    const VkAllocationCallbacks* allocator) {
-    if (callback)
-        GetData(instance).debug_report_callbacks.DestroyCallback(
-            instance, callback, allocator);
+    if (callback == VK_NULL_HANDLE)
+        return;
+
+    auto& callbacks = GetData(instance).debug_report_callbacks;
+    auto node = callbacks.FromHandle(callback);
+    auto driver_handle = callbacks.GetDriverHandle(node);
+
+    callbacks.RemoveCallback(
+        node, (allocator) ? *allocator : GetData(instance).allocator);
+
+    if (driver_handle != VK_NULL_HANDLE) {
+        GetData(instance).driver.DestroyDebugReportCallbackEXT(
+            instance, driver_handle, allocator);
+    }
 }
 
 void DebugReportMessageEXT(VkInstance instance,
diff --git a/vulkan/libvulkan/debug_report.h b/vulkan/libvulkan/debug_report.h
index 72b1887..7a03d4a 100644
--- a/vulkan/libvulkan/debug_report.h
+++ b/vulkan/libvulkan/debug_report.h
@@ -30,6 +30,10 @@
 // clang-format on
 
 class DebugReportCallbackList {
+   private:
+    // forward declaration
+    struct Node;
+
    public:
     DebugReportCallbackList()
         : head_{nullptr, 0, nullptr, nullptr, VK_NULL_HANDLE} {}
@@ -37,14 +41,11 @@
     DebugReportCallbackList& operator=(const DebugReportCallbackList&) = delete;
     ~DebugReportCallbackList() = default;
 
-    VkResult CreateCallback(
-        VkInstance instance,
-        const VkDebugReportCallbackCreateInfoEXT* create_info,
-        const VkAllocationCallbacks* allocator,
-        VkDebugReportCallbackEXT* callback);
-    void DestroyCallback(VkInstance instance,
-                         VkDebugReportCallbackEXT callback,
-                         const VkAllocationCallbacks* allocator);
+    Node* AddCallback(const VkDebugReportCallbackCreateInfoEXT& info,
+                      VkDebugReportCallbackEXT driver_handle,
+                      const VkAllocationCallbacks& allocator);
+    void RemoveCallback(Node* node, const VkAllocationCallbacks& allocator);
+
     void Message(VkDebugReportFlagsEXT flags,
                  VkDebugReportObjectTypeEXT object_type,
                  uint64_t object,
@@ -53,13 +54,27 @@
                  const char* layer_prefix,
                  const char* message);
 
+    static Node* FromHandle(VkDebugReportCallbackEXT handle) {
+        return reinterpret_cast<Node*>(uintptr_t(handle));
+    }
+
+    static VkDebugReportCallbackEXT GetHandle(const Node* node) {
+        return VkDebugReportCallbackEXT(reinterpret_cast<uintptr_t>(node));
+    }
+
+    static VkDebugReportCallbackEXT GetDriverHandle(const Node* node) {
+        return node->driver_handle;
+    }
+
    private:
     struct Node {
         Node* next;
+
         VkDebugReportFlagsEXT flags;
         PFN_vkDebugReportCallbackEXT callback;
-        void* data;
-        VkDebugReportCallbackEXT driver_callback;
+        void* user_data;
+
+        VkDebugReportCallbackEXT driver_handle;
     };
 
     // TODO(jessehall): replace with std::shared_mutex when available in libc++