blast: fix registering callbacks

This is a better fix for b/134194071 than ag/7998544. The previous
patch only protected the call to notify the binder thread to send a
callback.

This patch has both a start registration and end registration call.
During that time, the transaction callback cannot be sent. This is
closer to a long term fix for the bug.

Bug: 134194071
Test: Switch between front and back cameras to make sure the app
    doesn't crash.

Change-Id: I2d20c13cc1c8d13e5a1340dfaa8cbbaa4d3a30ab
diff --git a/libs/gui/include/gui/ITransactionCompletedListener.h b/libs/gui/include/gui/ITransactionCompletedListener.h
index 774ad46..cbfd365 100644
--- a/libs/gui/include/gui/ITransactionCompletedListener.h
+++ b/libs/gui/include/gui/ITransactionCompletedListener.h
@@ -106,6 +106,16 @@
                       const std::vector<CallbackId>& ids)
           : transactionCompletedListener(listener), callbackIds(ids) {}
 
+    bool operator==(const ListenerCallbacks& rhs) const {
+        if (transactionCompletedListener != rhs.transactionCompletedListener) {
+            return false;
+        }
+        if (callbackIds.empty()) {
+            return rhs.callbackIds.empty();
+        }
+        return callbackIds.front() == rhs.callbackIds.front();
+    }
+
     sp<ITransactionCompletedListener> transactionCompletedListener;
     std::vector<CallbackId> callbackIds;
 };
diff --git a/services/surfaceflinger/BufferStateLayer.cpp b/services/surfaceflinger/BufferStateLayer.cpp
index 2abc1a7..4b01301 100644
--- a/services/surfaceflinger/BufferStateLayer.cpp
+++ b/services/surfaceflinger/BufferStateLayer.cpp
@@ -85,7 +85,7 @@
 }
 
 void BufferStateLayer::releasePendingBuffer(nsecs_t /*dequeueReadyTime*/) {
-    mFlinger->getTransactionCompletedThread().addPresentedCallbackHandles(
+    mFlinger->getTransactionCompletedThread().finalizePendingCallbackHandles(
             mDrawingState.callbackHandles);
 
     mDrawingState.callbackHandles = {};
@@ -310,7 +310,7 @@
 
         } else { // If this layer will NOT need to be relatched and presented this frame
             // Notify the transaction completed thread this handle is done
-            mFlinger->getTransactionCompletedThread().addUnpresentedCallbackHandle(handle);
+            mFlinger->getTransactionCompletedThread().registerUnpresentedCallbackHandle(handle);
         }
     }
 
diff --git a/services/surfaceflinger/SurfaceFlinger.cpp b/services/surfaceflinger/SurfaceFlinger.cpp
index f53d3fa..7d8630f 100644
--- a/services/surfaceflinger/SurfaceFlinger.cpp
+++ b/services/surfaceflinger/SurfaceFlinger.cpp
@@ -2183,14 +2183,7 @@
     }
 
     mTransactionCompletedThread.addPresentFence(mPreviousPresentFences[0]);
-
-    // Lock the mStateLock in case SurfaceFlinger is in the middle of applying a transaction.
-    // If we do not lock here, a callback could be sent without all of its SurfaceControls and
-    // metrics.
-    {
-        Mutex::Autolock _l(mStateLock);
-        mTransactionCompletedThread.sendCallbacks();
-    }
+    mTransactionCompletedThread.sendCallbacks();
 
     if (mLumaSampling && mRegionSamplingThread) {
         mRegionSamplingThread->notifyNewContent();
@@ -3790,8 +3783,8 @@
     if (!listenerCallbacks.empty()) {
         mTransactionCompletedThread.run();
     }
-    for (const auto& [listener, callbackIds] : listenerCallbacks) {
-        mTransactionCompletedThread.addCallback(listener, callbackIds);
+    for (const auto& listenerCallback : listenerCallbacks) {
+        mTransactionCompletedThread.startRegistration(listenerCallback);
     }
 
     uint32_t clientStateFlags = 0;
@@ -3800,6 +3793,10 @@
                                                  postTime, privileged);
     }
 
+    for (const auto& listenerCallback : listenerCallbacks) {
+        mTransactionCompletedThread.endRegistration(listenerCallback);
+    }
+
     // If the state doesn't require a traversal and there are callbacks, send them now
     if (!(clientStateFlags & eTraversalNeeded) && !listenerCallbacks.empty()) {
         mTransactionCompletedThread.sendCallbacks();
@@ -3937,7 +3934,7 @@
     sp<Layer> layer(client->getLayerUser(s.surface));
     if (layer == nullptr) {
         for (auto& listenerCallback : listenerCallbacks) {
-            mTransactionCompletedThread.addUnpresentedCallbackHandle(
+            mTransactionCompletedThread.registerUnpresentedCallbackHandle(
                     new CallbackHandle(listenerCallback.transactionCompletedListener,
                                        listenerCallback.callbackIds, s.surface));
         }
diff --git a/services/surfaceflinger/TransactionCompletedThread.cpp b/services/surfaceflinger/TransactionCompletedThread.cpp
index 5cf8eb1..0806f2a 100644
--- a/services/surfaceflinger/TransactionCompletedThread.cpp
+++ b/services/surfaceflinger/TransactionCompletedThread.cpp
@@ -75,14 +75,15 @@
     mThread = std::thread(&TransactionCompletedThread::threadMain, this);
 }
 
-status_t TransactionCompletedThread::addCallback(const sp<ITransactionCompletedListener>& listener,
-                                                 const std::vector<CallbackId>& callbackIds) {
+status_t TransactionCompletedThread::startRegistration(const ListenerCallbacks& listenerCallbacks) {
     std::lock_guard lock(mMutex);
     if (!mRunning) {
         ALOGE("cannot add callback because the callback thread isn't running");
         return BAD_VALUE;
     }
 
+    auto& [listener, callbackIds] = listenerCallbacks;
+
     if (mCompletedTransactions.count(listener) == 0) {
         status_t err = IInterface::asBinder(listener)->linkToDeath(mDeathRecipient);
         if (err != NO_ERROR) {
@@ -91,11 +92,41 @@
         }
     }
 
+    mRegisteringTransactions.insert(listenerCallbacks);
+
     auto& transactionStatsDeque = mCompletedTransactions[listener];
     transactionStatsDeque.emplace_back(callbackIds);
+
     return NO_ERROR;
 }
 
+status_t TransactionCompletedThread::endRegistration(const ListenerCallbacks& listenerCallbacks) {
+    std::lock_guard lock(mMutex);
+    if (!mRunning) {
+        ALOGE("cannot add callback because the callback thread isn't running");
+        return BAD_VALUE;
+    }
+
+    auto itr = mRegisteringTransactions.find(listenerCallbacks);
+    if (itr == mRegisteringTransactions.end()) {
+        ALOGE("cannot end a registration that does not exist");
+        return BAD_VALUE;
+    }
+
+    mRegisteringTransactions.erase(itr);
+
+    return NO_ERROR;
+}
+
+bool TransactionCompletedThread::isRegisteringTransaction(
+        const sp<ITransactionCompletedListener>& transactionListener,
+        const std::vector<CallbackId>& callbackIds) {
+    ListenerCallbacks listenerCallbacks(transactionListener, callbackIds);
+
+    auto itr = mRegisteringTransactions.find(listenerCallbacks);
+    return itr != mRegisteringTransactions.end();
+}
+
 status_t TransactionCompletedThread::registerPendingCallbackHandle(
         const sp<CallbackHandle>& handle) {
     std::lock_guard lock(mMutex);
@@ -105,7 +136,7 @@
     }
 
     // If we can't find the transaction stats something has gone wrong. The client should call
-    // addCallback before trying to register a pending callback handle.
+    // startRegistration before trying to register a pending callback handle.
     TransactionStats* transactionStats;
     status_t err = findTransactionStats(handle->listener, handle->callbackIds, &transactionStats);
     if (err != NO_ERROR) {
@@ -117,7 +148,7 @@
     return NO_ERROR;
 }
 
-status_t TransactionCompletedThread::addPresentedCallbackHandles(
+status_t TransactionCompletedThread::finalizePendingCallbackHandles(
         const std::deque<sp<CallbackHandle>>& handles) {
     std::lock_guard lock(mMutex);
     if (!mRunning) {
@@ -158,7 +189,7 @@
     return NO_ERROR;
 }
 
-status_t TransactionCompletedThread::addUnpresentedCallbackHandle(
+status_t TransactionCompletedThread::registerUnpresentedCallbackHandle(
         const sp<CallbackHandle>& handle) {
     std::lock_guard lock(mMutex);
     if (!mRunning) {
@@ -189,7 +220,7 @@
 
 status_t TransactionCompletedThread::addCallbackHandle(const sp<CallbackHandle>& handle) {
     // If we can't find the transaction stats something has gone wrong. The client should call
-    // addCallback before trying to add a presnted callback handle.
+    // startRegistration before trying to add a callback handle.
     TransactionStats* transactionStats;
     status_t err = findTransactionStats(handle->listener, handle->callbackIds, &transactionStats);
     if (err != NO_ERROR) {
@@ -233,6 +264,13 @@
             while (transactionStatsItr != transactionStatsDeque.end()) {
                 auto& transactionStats = *transactionStatsItr;
 
+                // If this transaction is still registering, it is not safe to send a callback
+                // because there could be surface controls that haven't been added to
+                // transaction stats or mPendingTransactions.
+                if (isRegisteringTransaction(listener, transactionStats.callbackIds)) {
+                    break;
+                }
+
                 // If we are still waiting on the callback handles for this transaction, stop
                 // here because all transaction callbacks for the same listener must come in order
                 auto pendingTransactions = mPendingTransactions.find(listener);
diff --git a/services/surfaceflinger/TransactionCompletedThread.h b/services/surfaceflinger/TransactionCompletedThread.h
index 21e2678..b821350 100644
--- a/services/surfaceflinger/TransactionCompletedThread.h
+++ b/services/surfaceflinger/TransactionCompletedThread.h
@@ -21,6 +21,7 @@
 #include <mutex>
 #include <thread>
 #include <unordered_map>
+#include <unordered_set>
 
 #include <android-base/thread_annotations.h>
 
@@ -30,6 +31,12 @@
 
 namespace android {
 
+struct ITransactionCompletedListenerHash {
+    std::size_t operator()(const sp<ITransactionCompletedListener>& listener) const {
+        return std::hash<IBinder*>{}((listener) ? IInterface::asBinder(listener).get() : nullptr);
+    }
+};
+
 struct CallbackIdsHash {
     // CallbackId vectors have several properties that let us get away with this simple hash.
     // 1) CallbackIds are never 0 so if something has gone wrong and our CallbackId vector is
@@ -42,6 +49,22 @@
     }
 };
 
+struct ListenerCallbacksHash {
+    std::size_t HashCombine(size_t value1, size_t value2) const {
+        return value1 ^ (value2 + 0x9e3779b9 + (value1 << 6) + (value1 >> 2));
+    }
+
+    std::size_t operator()(const ListenerCallbacks& listenerCallbacks) const {
+        struct ITransactionCompletedListenerHash listenerHasher;
+        struct CallbackIdsHash callbackIdsHasher;
+
+        std::size_t listenerHash = listenerHasher(listenerCallbacks.transactionCompletedListener);
+        std::size_t callbackIdsHash = callbackIdsHasher(listenerCallbacks.callbackIds);
+
+        return HashCombine(listenerHash, callbackIdsHash);
+    }
+};
+
 class CallbackHandle : public RefBase {
 public:
     CallbackHandle(const sp<ITransactionCompletedListener>& transactionListener,
@@ -64,10 +87,12 @@
     void run();
 
     // Adds listener and callbackIds in case there are no SurfaceControls that are supposed
-    // to be included in the callback. This functions should be call before attempting to add any
-    // callback handles.
-    status_t addCallback(const sp<ITransactionCompletedListener>& transactionListener,
-                         const std::vector<CallbackId>& callbackIds);
+    // to be included in the callback. This functions should be call before attempting to register
+    // any callback handles.
+    status_t startRegistration(const ListenerCallbacks& listenerCallbacks);
+    // Ends the registration. After this is called, no more CallbackHandles will be registered.
+    // It is safe to send a callback if the Transaction doesn't have any Pending callback handles.
+    status_t endRegistration(const ListenerCallbacks& listenerCallbacks);
 
     // Informs the TransactionCompletedThread that there is a Transaction with a CallbackHandle
     // that needs to be latched and presented this frame. This function should be called once the
@@ -76,11 +101,11 @@
     // presented.
     status_t registerPendingCallbackHandle(const sp<CallbackHandle>& handle);
     // Notifies the TransactionCompletedThread that a pending CallbackHandle has been presented.
-    status_t addPresentedCallbackHandles(const std::deque<sp<CallbackHandle>>& handles);
+    status_t finalizePendingCallbackHandles(const std::deque<sp<CallbackHandle>>& handles);
 
     // Adds the Transaction CallbackHandle from a layer that does not need to be relatched and
     // presented this frame.
-    status_t addUnpresentedCallbackHandle(const sp<CallbackHandle>& handle);
+    status_t registerUnpresentedCallbackHandle(const sp<CallbackHandle>& handle);
 
     void addPresentFence(const sp<Fence>& presentFence);
 
@@ -89,6 +114,9 @@
 private:
     void threadMain();
 
+    bool isRegisteringTransaction(const sp<ITransactionCompletedListener>& transactionListener,
+                                  const std::vector<CallbackId>& callbackIds) REQUIRES(mMutex);
+
     status_t findTransactionStats(const sp<ITransactionCompletedListener>& listener,
                                   const std::vector<CallbackId>& callbackIds,
                                   TransactionStats** outTransactionStats) REQUIRES(mMutex);
@@ -106,13 +134,6 @@
     };
     sp<ThreadDeathRecipient> mDeathRecipient;
 
-    struct ITransactionCompletedListenerHash {
-        std::size_t operator()(const sp<ITransactionCompletedListener>& listener) const {
-            return std::hash<IBinder*>{}((listener) ? IInterface::asBinder(listener).get()
-                                                    : nullptr);
-        }
-    };
-
     // Protects the creation and destruction of mThread
     std::mutex mThreadMutex;
 
@@ -121,11 +142,15 @@
     std::mutex mMutex;
     std::condition_variable_any mConditionVariable;
 
+    std::unordered_set<ListenerCallbacks, ListenerCallbacksHash> mRegisteringTransactions
+            GUARDED_BY(mMutex);
+
     std::unordered_map<
             sp<ITransactionCompletedListener>,
             std::unordered_map<std::vector<CallbackId>, uint32_t /*count*/, CallbackIdsHash>,
             ITransactionCompletedListenerHash>
             mPendingTransactions GUARDED_BY(mMutex);
+
     std::unordered_map<sp<ITransactionCompletedListener>, std::deque<TransactionStats>,
                        ITransactionCompletedListenerHash>
             mCompletedTransactions GUARDED_BY(mMutex);