libsnapshot: Implement merge flow.

This implements InitiateMerge() and WaitForMerge(). InitiateMerge() is
meant to be called after an update has been marked successful.
WaitForMerge() is designed to be called either: immediately after
InitiateMerge, or during each subsequent boot where merging has not
completed.

InitiateMerge converts each snapshot device to a snapshot-merge device.

WaitForMerge polls each snapshot-merge device until no device reports a
"merging" state. One of the following states can result from this:
 - MergeFailed. This will happen if any device failed to merge, or we
   were unable to poll, or any other system-level failure occurred.
 - MergeNeedsReboot. This will happen if a snapshot-merge device has
   completed merging, but we were unable to clean it up due to something
   holding a resource open.
 - MergeCompleted. This indicates that all snapshots completed merging
   and were cleaned up.

If WaitForMerge() returns MergeCompleted, then all snapshots have been
removed and a new update can begin. GetUpdateState() will return None.

MergeFailed and MergeNeedsReboot, on the other hand, are "sticky". They
indicate a merge is still pending. When called again, WaitForMerge()
will poll again to attempt to make more progress in the merge. For
NeedsReboot, a single reboot will ensure all resources are released and
the next WaitForMerge() will successfully finish cleanup. In the failure
case, it is unlikely the next WaitForMerge will succeed, but we always
retry anyway (there is no harm in doing so, and if we get lucky, the
device can take more OTAs).

Bug: 136678799
Test: libsnapshot_test gtests
Change-Id: I5e93fcbffee1973da5ff76363df12d6317a7a7c7
diff --git a/fs_mgr/libdm/include/libdm/dm.h b/fs_mgr/libdm/include/libdm/dm.h
index f5783cb..cf306f3 100644
--- a/fs_mgr/libdm/include/libdm/dm.h
+++ b/fs_mgr/libdm/include/libdm/dm.h
@@ -197,6 +197,7 @@
     struct TargetInfo {
         struct dm_target_spec spec;
         std::string data;
+        TargetInfo() {}
         TargetInfo(const struct dm_target_spec& spec, const std::string& data)
             : spec(spec), data(data) {}
     };
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
index 103c128..f7608dc 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
@@ -22,7 +22,8 @@
 #include <vector>
 
 #include <android-base/unique_fd.h>
-#include <libdm/dm_target.h>
+#include <libdm/dm.h>
+#include <libfiemap/image_manager.h>
 
 #ifndef FRIEND_TEST
 #define FRIEND_TEST(test_set_name, individual_test) \
@@ -38,7 +39,7 @@
 
 namespace snapshot {
 
-enum class UpdateState {
+enum class UpdateState : unsigned int {
     // No update or merge is in progress.
     None,
 
@@ -51,6 +52,10 @@
     // The kernel is merging in the background.
     Merging,
 
+    // Post-merge cleanup steps could not be completed due to a transient
+    // error, but the next reboot will finish any pending operations.
+    MergeNeedsReboot,
+
     // Merging is complete, and needs to be acknowledged.
     MergeCompleted,
 
@@ -94,8 +99,23 @@
 
     // Wait for the current merge to finish, then perform cleanup when it
     // completes. It is necessary to call this after InitiateMerge(), or when
-    // a merge is detected for the first time after boot.
-    bool WaitForMerge();
+    // a merge state is detected during boot.
+    //
+    // Note that after calling WaitForMerge(), GetUpdateState() may still return
+    // that a merge is in progress:
+    //   MergeFailed indicates that a fatal error occurred. WaitForMerge() may
+    //   called any number of times again to attempt to make more progress, but
+    //   we do not expect it to succeed if a catastrophic error occurred.
+    //
+    //   MergeNeedsReboot indicates that the merge has completed, but cleanup
+    //   failed. This can happen if for some reason resources were not closed
+    //   properly. In this case another reboot is needed before we can take
+    //   another OTA. However, WaitForMerge() can be called again without
+    //   rebooting, to attempt to finish cleanup anyway.
+    //
+    //   MergeCompleted indicates that the update has fully completed.
+    //   GetUpdateState will return None, and a new update can begin.
+    UpdateState WaitForMerge();
 
     // Find the status of the current update, if any.
     //
@@ -109,9 +129,14 @@
     FRIEND_TEST(SnapshotTest, CreateSnapshot);
     FRIEND_TEST(SnapshotTest, MapSnapshot);
     FRIEND_TEST(SnapshotTest, MapPartialSnapshot);
+    FRIEND_TEST(SnapshotTest, NoMergeBeforeReboot);
+    FRIEND_TEST(SnapshotTest, Merge);
+    FRIEND_TEST(SnapshotTest, MergeCannotRemoveCow);
     friend class SnapshotTest;
 
+    using DmTargetSnapshot = android::dm::DmTargetSnapshot;
     using IImageManager = android::fiemap::IImageManager;
+    using TargetInfo = android::dm::DeviceMapper::TargetInfo;
 
     explicit SnapshotManager(IDeviceInfo* info);
 
@@ -203,6 +228,33 @@
         uint64_t metadata_sectors = 0;
     };
 
+    // Helpers for merging.
+    bool SwitchSnapshotToMerge(LockedFile* lock, const std::string& name);
+    bool RewriteSnapshotDeviceTable(const std::string& dm_name);
+    bool MarkSnapshotMergeCompleted(LockedFile* snapshot_lock, const std::string& snapshot_name);
+    void AcknowledgeMergeSuccess(LockedFile* lock);
+    void AcknowledgeMergeFailure();
+
+    // Note that these require the name of the device containing the snapshot,
+    // which may be the "inner" device. Use GetsnapshotDeviecName().
+    bool QuerySnapshotStatus(const std::string& dm_name, std::string* target_type,
+                             DmTargetSnapshot::Status* status);
+    bool IsSnapshotDevice(const std::string& dm_name, TargetInfo* target = nullptr);
+
+    // Internal callback for when merging is complete.
+    bool OnSnapshotMergeComplete(LockedFile* lock, const std::string& name,
+                                 const SnapshotStatus& status);
+    bool CollapseSnapshotDevice(const std::string& name, const SnapshotStatus& status);
+
+    // Only the following UpdateStates are used here:
+    //   UpdateState::Merging
+    //   UpdateState::MergeCompleted
+    //   UpdateState::MergeFailed
+    //   UpdateState::MergeNeedsReboot
+    UpdateState CheckMergeState();
+    UpdateState CheckMergeState(LockedFile* lock);
+    UpdateState CheckTargetMergeState(LockedFile* lock, const std::string& name);
+
     // Interact with status files under /metadata/ota/snapshots.
     bool WriteSnapshotStatus(LockedFile* lock, const std::string& name,
                              const SnapshotStatus& status);
diff --git a/fs_mgr/libsnapshot/snapshot.cpp b/fs_mgr/libsnapshot/snapshot.cpp
index 2d1eab0..63a01f3 100644
--- a/fs_mgr/libsnapshot/snapshot.cpp
+++ b/fs_mgr/libsnapshot/snapshot.cpp
@@ -19,6 +19,8 @@
 #include <sys/types.h>
 #include <sys/unistd.h>
 
+#include <thread>
+
 #include <android-base/file.h>
 #include <android-base/logging.h>
 #include <android-base/parseint.h>
@@ -169,6 +171,11 @@
     if (!ReadSnapshotStatus(lock, name, &status)) {
         return false;
     }
+    if (status.state == "merge-completed") {
+        LOG(ERROR) << "Should not create a snapshot device for " << name
+                   << " after merging has completed.";
+        return false;
+    }
 
     // Validate the block device size, as well as the requested snapshot size.
     // During this we also compute the linear sector region if any.
@@ -204,19 +211,31 @@
 
     std::string cow_dev;
     if (!images_->MapImageDevice(cow_name, timeout_ms, &cow_dev)) {
+        LOG(ERROR) << "Could not map image device: " << cow_name;
         return false;
     }
 
     auto& dm = DeviceMapper::Instance();
 
-    // Merging is a global state, not per-snapshot. We do however track the
-    // progress of individual snapshots' merges.
+    // Note that merging is a global state. We do track whether individual devices
+    // have completed merging, but the start of the merge process is considered
+    // atomic.
     SnapshotStorageMode mode;
-    UpdateState update_state = ReadUpdateState(lock);
-    if (update_state == UpdateState::Merging || update_state == UpdateState::MergeCompleted) {
-        mode = SnapshotStorageMode::Merge;
-    } else {
-        mode = SnapshotStorageMode::Persistent;
+    switch (ReadUpdateState(lock)) {
+        case UpdateState::MergeCompleted:
+        case UpdateState::MergeNeedsReboot:
+            LOG(ERROR) << "Should not create a snapshot device for " << name
+                       << " after global merging has completed.";
+            return false;
+        case UpdateState::Merging:
+        case UpdateState::MergeFailed:
+            // Note: MergeFailed indicates that a merge is in progress, but
+            // is possibly stalled. We still have to honor the merge.
+            mode = SnapshotStorageMode::Merge;
+            break;
+        default:
+            mode = SnapshotStorageMode::Persistent;
+            break;
     }
 
     // The kernel (tested on 4.19) crashes horribly if a device has both a snapshot
@@ -312,11 +331,504 @@
 }
 
 bool SnapshotManager::InitiateMerge() {
-    return false;
+    auto lock = LockExclusive();
+    if (!lock) return false;
+
+    UpdateState state = ReadUpdateState(lock.get());
+    if (state != UpdateState::Unverified) {
+        LOG(ERROR) << "Cannot begin a merge if an update has not been verified";
+        return false;
+    }
+    if (!device_->IsRunningSnapshot()) {
+        LOG(ERROR) << "Cannot begin a merge if the device is not booted off a snapshot";
+        return false;
+    }
+
+    std::vector<std::string> snapshots;
+    if (!ListSnapshots(lock.get(), &snapshots)) {
+        LOG(ERROR) << "Could not list snapshots";
+        return false;
+    }
+
+    auto& dm = DeviceMapper::Instance();
+    for (const auto& snapshot : snapshots) {
+        // The device has to be mapped, since everything should be merged at
+        // the same time. This is a fairly serious error. We could forcefully
+        // map everything here, but it should have been mapped during first-
+        // stage init.
+        if (dm.GetState(snapshot) == DmDeviceState::INVALID) {
+            LOG(ERROR) << "Cannot begin merge; device " << snapshot << " is not mapped.";
+            return false;
+        }
+    }
+
+    // Point of no return - mark that we're starting a merge. From now on every
+    // snapshot must be a merge target.
+    if (!WriteUpdateState(lock.get(), UpdateState::Merging)) {
+        return false;
+    }
+
+    bool rewrote_all = true;
+    for (const auto& snapshot : snapshots) {
+        // If this fails, we have no choice but to continue. Everything must
+        // be merged. This is not an ideal state to be in, but it is safe,
+        // because we the next boot will try again.
+        if (!SwitchSnapshotToMerge(lock.get(), snapshot)) {
+            LOG(ERROR) << "Failed to switch snapshot to a merge target: " << snapshot;
+            rewrote_all = false;
+        }
+    }
+
+    // If we couldn't switch everything to a merge target, pre-emptively mark
+    // this merge as failed. It will get acknowledged when WaitForMerge() is
+    // called.
+    if (!rewrote_all) {
+        WriteUpdateState(lock.get(), UpdateState::MergeFailed);
+    }
+
+    // Return true no matter what, because a merge was initiated.
+    return true;
 }
 
-bool SnapshotManager::WaitForMerge() {
-    return false;
+bool SnapshotManager::SwitchSnapshotToMerge(LockedFile* lock, const std::string& name) {
+    SnapshotStatus status;
+    if (!ReadSnapshotStatus(lock, name, &status)) {
+        return false;
+    }
+    if (status.state != "created") {
+        LOG(WARNING) << "Snapshot " << name << " has unexpected state: " << status.state;
+    }
+
+    // After this, we return true because we technically did switch to a merge
+    // target. Everything else we do here is just informational.
+    auto dm_name = GetSnapshotDeviceName(name, status);
+    if (!RewriteSnapshotDeviceTable(dm_name)) {
+        return false;
+    }
+
+    status.state = "merging";
+
+    DmTargetSnapshot::Status dm_status;
+    if (!QuerySnapshotStatus(dm_name, nullptr, &dm_status)) {
+        LOG(ERROR) << "Could not query merge status for snapshot: " << dm_name;
+    }
+    status.sectors_allocated = dm_status.sectors_allocated;
+    status.metadata_sectors = dm_status.metadata_sectors;
+    if (!WriteSnapshotStatus(lock, name, status)) {
+        LOG(ERROR) << "Could not update status file for snapshot: " << name;
+    }
+    return true;
+}
+
+bool SnapshotManager::RewriteSnapshotDeviceTable(const std::string& dm_name) {
+    auto& dm = DeviceMapper::Instance();
+
+    std::vector<DeviceMapper::TargetInfo> old_targets;
+    if (!dm.GetTableInfo(dm_name, &old_targets)) {
+        LOG(ERROR) << "Could not read snapshot device table: " << dm_name;
+        return false;
+    }
+    if (old_targets.size() != 1 || DeviceMapper::GetTargetType(old_targets[0].spec) != "snapshot") {
+        LOG(ERROR) << "Unexpected device-mapper table for snapshot: " << dm_name;
+        return false;
+    }
+
+    std::string base_device, cow_device;
+    if (!DmTargetSnapshot::GetDevicesFromParams(old_targets[0].data, &base_device, &cow_device)) {
+        LOG(ERROR) << "Could not derive underlying devices for snapshot: " << dm_name;
+        return false;
+    }
+
+    DmTable table;
+    table.Emplace<DmTargetSnapshot>(0, old_targets[0].spec.length, base_device, cow_device,
+                                    SnapshotStorageMode::Merge, kSnapshotChunkSize);
+    if (!dm.LoadTableAndActivate(dm_name, table)) {
+        LOG(ERROR) << "Could not swap device-mapper tables on snapshot device " << dm_name;
+        return false;
+    }
+    LOG(INFO) << "Successfully switched snapshot device to a merge target: " << dm_name;
+    return true;
+}
+
+enum class TableQuery {
+    Table,
+    Status,
+};
+
+static bool GetSingleTarget(const std::string& dm_name, TableQuery query,
+                            DeviceMapper::TargetInfo* target) {
+    auto& dm = DeviceMapper::Instance();
+    if (dm.GetState(dm_name) == DmDeviceState::INVALID) {
+        return false;
+    }
+
+    std::vector<DeviceMapper::TargetInfo> targets;
+    bool result;
+    if (query == TableQuery::Status) {
+        result = dm.GetTableStatus(dm_name, &targets);
+    } else {
+        result = dm.GetTableInfo(dm_name, &targets);
+    }
+    if (!result) {
+        LOG(ERROR) << "Could not query device: " << dm_name;
+        return false;
+    }
+    if (targets.size() != 1) {
+        return false;
+    }
+
+    *target = std::move(targets[0]);
+    return true;
+}
+
+bool SnapshotManager::IsSnapshotDevice(const std::string& dm_name, TargetInfo* target) {
+    DeviceMapper::TargetInfo snap_target;
+    if (!GetSingleTarget(dm_name, TableQuery::Status, &snap_target)) {
+        return false;
+    }
+    auto type = DeviceMapper::GetTargetType(snap_target.spec);
+    if (type != "snapshot" && type != "snapshot-merge") {
+        return false;
+    }
+    if (target) {
+        *target = std::move(snap_target);
+    }
+    return true;
+}
+
+bool SnapshotManager::QuerySnapshotStatus(const std::string& dm_name, std::string* target_type,
+                                          DmTargetSnapshot::Status* status) {
+    DeviceMapper::TargetInfo target;
+    if (!IsSnapshotDevice(dm_name, &target)) {
+        LOG(ERROR) << "Device " << dm_name << " is not a snapshot or snapshot-merge device";
+        return false;
+    }
+    if (!DmTargetSnapshot::ParseStatusText(target.data, status)) {
+        LOG(ERROR) << "Could not parse snapshot status text: " << dm_name;
+        return false;
+    }
+    if (target_type) {
+        *target_type = DeviceMapper::GetTargetType(target.spec);
+    }
+    return true;
+}
+
+// Note that when a merge fails, we will *always* try again to complete the
+// merge each time the device boots. There is no harm in doing so, and if
+// the problem was transient, we might manage to get a new outcome.
+UpdateState SnapshotManager::WaitForMerge() {
+    while (true) {
+        UpdateState state = CheckMergeState();
+        if (state != UpdateState::Merging) {
+            // Either there is no merge, or the merge was finished, so no need
+            // to keep waiting.
+            return state;
+        }
+
+        // This wait is not super time sensitive, so we have a relatively
+        // low polling frequency.
+        std::this_thread::sleep_for(2s);
+    }
+}
+
+UpdateState SnapshotManager::CheckMergeState() {
+    auto lock = LockExclusive();
+    if (!lock) {
+        AcknowledgeMergeFailure();
+        return UpdateState::MergeFailed;
+    }
+
+    auto state = CheckMergeState(lock.get());
+    if (state == UpdateState::MergeCompleted) {
+        AcknowledgeMergeSuccess(lock.get());
+    } else if (state == UpdateState::MergeFailed) {
+        AcknowledgeMergeFailure();
+    }
+    return state;
+}
+
+UpdateState SnapshotManager::CheckMergeState(LockedFile* lock) {
+    UpdateState state = ReadUpdateState(lock);
+    switch (state) {
+        case UpdateState::None:
+        case UpdateState::MergeCompleted:
+            // Harmless races are allowed between two callers of WaitForMerge,
+            // so in both of these cases we just propagate the state.
+            return state;
+
+        case UpdateState::Merging:
+        case UpdateState::MergeNeedsReboot:
+        case UpdateState::MergeFailed:
+            // We'll poll each snapshot below. Note that for the NeedsReboot
+            // case, we always poll once to give cleanup another opportunity to
+            // run.
+            break;
+
+        default:
+            LOG(ERROR) << "No merge exists, cannot wait. Update state: "
+                       << static_cast<uint32_t>(state);
+            return UpdateState::None;
+    }
+
+    std::vector<std::string> snapshots;
+    if (!ListSnapshots(lock, &snapshots)) {
+        return UpdateState::MergeFailed;
+    }
+
+    bool failed = false;
+    bool merging = false;
+    bool needs_reboot = false;
+    for (const auto& snapshot : snapshots) {
+        UpdateState snapshot_state = CheckTargetMergeState(lock, snapshot);
+        switch (snapshot_state) {
+            case UpdateState::MergeFailed:
+                failed = true;
+                break;
+            case UpdateState::Merging:
+                merging = true;
+                break;
+            case UpdateState::MergeNeedsReboot:
+                needs_reboot = true;
+                break;
+            case UpdateState::MergeCompleted:
+                break;
+            default:
+                LOG(ERROR) << "Unknown merge status: " << static_cast<uint32_t>(snapshot_state);
+                failed = true;
+                break;
+        }
+    }
+
+    if (merging) {
+        // Note that we handle "Merging" before we handle anything else. We
+        // want to poll until *nothing* is merging if we can, so everything has
+        // a chance to get marked as completed or failed.
+        return UpdateState::Merging;
+    }
+    if (failed) {
+        // Note: since there are many drop-out cases for failure, we acknowledge
+        // it in WaitForMerge rather than here and elsewhere.
+        return UpdateState::MergeFailed;
+    }
+    if (needs_reboot) {
+        WriteUpdateState(lock, UpdateState::MergeNeedsReboot);
+        return UpdateState::MergeNeedsReboot;
+    }
+    return UpdateState::MergeCompleted;
+}
+
+UpdateState SnapshotManager::CheckTargetMergeState(LockedFile* lock, const std::string& name) {
+    SnapshotStatus snapshot_status;
+    if (!ReadSnapshotStatus(lock, name, &snapshot_status)) {
+        return UpdateState::MergeFailed;
+    }
+
+    std::string dm_name = GetSnapshotDeviceName(name, snapshot_status);
+
+    // During a check, we decided the merge was complete, but we were unable to
+    // collapse the device-mapper stack and perform COW cleanup. If we haven't
+    // rebooted after this check, the device will still be a snapshot-merge
+    // target. If the have rebooted, the device will now be a linear target,
+    // and we can try cleanup again.
+    if (snapshot_status.state == "merge-complete" && !IsSnapshotDevice(dm_name)) {
+        // NB: It's okay if this fails now, we gave cleanup our best effort.
+        OnSnapshotMergeComplete(lock, name, snapshot_status);
+        return UpdateState::MergeCompleted;
+    }
+
+    std::string target_type;
+    DmTargetSnapshot::Status status;
+    if (!QuerySnapshotStatus(dm_name, &target_type, &status)) {
+        return UpdateState::MergeFailed;
+    }
+    if (target_type != "snapshot-merge") {
+        // We can get here if we failed to rewrite the target type in
+        // InitiateMerge(). If we failed to create the target in first-stage
+        // init, boot would not succeed.
+        LOG(ERROR) << "Snapshot " << name << " has incorrect target type: " << target_type;
+        return UpdateState::MergeFailed;
+    }
+
+    // These two values are equal when merging is complete.
+    if (status.sectors_allocated != status.metadata_sectors) {
+        if (snapshot_status.state == "merge-complete") {
+            LOG(ERROR) << "Snapshot " << name << " is merging after being marked merge-complete.";
+            return UpdateState::MergeFailed;
+        }
+        return UpdateState::Merging;
+    }
+
+    // Merging is done. First, update the status file to indicate the merge
+    // is complete. We do this before calling OnSnapshotMergeComplete, even
+    // though this means the write is potentially wasted work (since in the
+    // ideal case we'll immediately delete the file).
+    //
+    // This makes it simpler to reason about the next reboot: no matter what
+    // part of cleanup failed, first-stage init won't try to create another
+    // snapshot device for this partition.
+    snapshot_status.state = "merge-complete";
+    if (!WriteSnapshotStatus(lock, name, snapshot_status)) {
+        return UpdateState::MergeFailed;
+    }
+    if (!OnSnapshotMergeComplete(lock, name, snapshot_status)) {
+        return UpdateState::MergeNeedsReboot;
+    }
+    return UpdateState::MergeCompleted;
+}
+
+void SnapshotManager::AcknowledgeMergeSuccess(LockedFile* lock) {
+    if (!WriteUpdateState(lock, UpdateState::None)) {
+        // We'll try again next reboot, ad infinitum.
+        return;
+    }
+}
+
+void SnapshotManager::AcknowledgeMergeFailure() {
+    // Log first, so worst case, we always have a record of why the calls below
+    // were being made.
+    LOG(ERROR) << "Merge could not be completed and will be marked as failed.";
+
+    auto lock = LockExclusive();
+    if (!lock) return;
+
+    // Since we released the lock in between WaitForMerge and here, it's
+    // possible (1) the merge successfully completed or (2) was already
+    // marked as a failure. So make sure to check the state again, and
+    // only mark as a failure if appropriate.
+    UpdateState state = ReadUpdateState(lock.get());
+    if (state != UpdateState::Merging && state != UpdateState::MergeNeedsReboot) {
+        return;
+    }
+
+    WriteUpdateState(lock.get(), UpdateState::MergeFailed);
+}
+
+bool SnapshotManager::OnSnapshotMergeComplete(LockedFile* lock, const std::string& name,
+                                              const SnapshotStatus& status) {
+    auto dm_name = GetSnapshotDeviceName(name, status);
+    if (IsSnapshotDevice(dm_name)) {
+        // We are extra-cautious here, to avoid deleting the wrong table.
+        std::string target_type;
+        DmTargetSnapshot::Status dm_status;
+        if (!QuerySnapshotStatus(dm_name, &target_type, &dm_status)) {
+            return false;
+        }
+        if (target_type != "snapshot-merge") {
+            LOG(ERROR) << "Unexpected target type " << target_type
+                       << " for snapshot device: " << dm_name;
+            return false;
+        }
+        if (dm_status.sectors_allocated != dm_status.metadata_sectors) {
+            LOG(ERROR) << "Merge is unexpectedly incomplete for device " << dm_name;
+            return false;
+        }
+        if (!CollapseSnapshotDevice(name, status)) {
+            LOG(ERROR) << "Unable to collapse snapshot: " << name;
+            return false;
+        }
+        // Note that collapsing is implicitly an Unmap, so we don't need to
+        // unmap the snapshot.
+    }
+
+    if (!DeleteSnapshot(lock, name)) {
+        LOG(ERROR) << "Could not delete snapshot: " << name;
+        return false;
+    }
+    return true;
+}
+
+bool SnapshotManager::CollapseSnapshotDevice(const std::string& name,
+                                             const SnapshotStatus& status) {
+    // Ideally, we would complete the following steps to collapse the device:
+    //  (1) Rewrite the snapshot table to be identical to the base device table.
+    //  (2) Rewrite the verity table to use the "snapshot" (now linear) device.
+    //  (3) Delete the base device.
+    //
+    // This should be possible once libsnapshot understands LpMetadata. In the
+    // meantime, we implement a simpler solution: rewriting the snapshot table
+    // to be a single dm-linear segment against the base device. While not as
+    // ideal, it still lets us remove the COW device. We can remove this
+    // implementation once the new method has been tested.
+    auto& dm = DeviceMapper::Instance();
+    auto dm_name = GetSnapshotDeviceName(name, status);
+
+    DeviceMapper::TargetInfo target;
+    if (!GetSingleTarget(dm_name, TableQuery::Table, &target)) {
+        return false;
+    }
+    if (DeviceMapper::GetTargetType(target.spec) != "snapshot-merge") {
+        // This should be impossible, it was checked above.
+        LOG(ERROR) << "Snapshot device has invalid target type: " << dm_name;
+        return false;
+    }
+
+    std::string base_device, cow_device;
+    if (!DmTargetSnapshot::GetDevicesFromParams(target.data, &base_device, &cow_device)) {
+        LOG(ERROR) << "Could not parse snapshot device " << dm_name
+                   << " parameters: " << target.data;
+        return false;
+    }
+
+    uint64_t num_sectors = status.snapshot_size / kSectorSize;
+    if (num_sectors * kSectorSize != status.snapshot_size) {
+        LOG(ERROR) << "Snapshot " << name
+                   << " size is not sector aligned: " << status.snapshot_size;
+        return false;
+    }
+
+    if (dm_name != name) {
+        // We've derived the base device, but we actually need to replace the
+        // table of the outermost device. Do a quick verification that this
+        // device looks like we expect it to.
+        std::vector<DeviceMapper::TargetInfo> outer_table;
+        if (!dm.GetTableInfo(name, &outer_table)) {
+            LOG(ERROR) << "Could not validate outer snapshot table: " << name;
+            return false;
+        }
+        if (outer_table.size() != 2) {
+            LOG(ERROR) << "Expected 2 dm-linear targets for tabble " << name
+                       << ", got: " << outer_table.size();
+            return false;
+        }
+        for (const auto& target : outer_table) {
+            auto target_type = DeviceMapper::GetTargetType(target.spec);
+            if (target_type != "linear") {
+                LOG(ERROR) << "Outer snapshot table may only contain linear targets, but " << name
+                           << " has target: " << target_type;
+                return false;
+            }
+        }
+        uint64_t sectors = outer_table[0].spec.length + outer_table[1].spec.length;
+        if (sectors != num_sectors) {
+            LOG(ERROR) << "Outer snapshot " << name << " should have " << num_sectors
+                       << ", got: " << sectors;
+            return false;
+        }
+    }
+
+    // Note: we are replacing the OUTER table here, so we do not use dm_name.
+    DmTargetLinear new_target(0, num_sectors, base_device, 0);
+    LOG(INFO) << "Replacing snapshot device " << name
+              << " table with: " << new_target.GetParameterString();
+
+    DmTable table;
+    table.Emplace<DmTargetLinear>(new_target);
+    if (!dm.LoadTableAndActivate(name, table)) {
+        return false;
+    }
+
+    if (dm_name != name) {
+        // Attempt to delete the snapshot device. Nothing should be depending on
+        // the device, and device-mapper should have flushed remaining I/O. We
+        // could in theory replace with dm-zero (or re-use the table above), but
+        // for now it's better to know why this would fail.
+        if (!dm.DeleteDevice(dm_name)) {
+            LOG(ERROR) << "Unable to delete snapshot device " << dm_name << ", COW cannot be "
+                       << "reclaimed until after reboot.";
+            return false;
+        }
+    }
+    return true;
 }
 
 bool SnapshotManager::RemoveAllSnapshots(LockedFile* lock) {
@@ -439,6 +951,10 @@
         return UpdateState::Merging;
     } else if (contents == "merge-completed") {
         return UpdateState::MergeCompleted;
+    } else if (contents == "merge-needs-reboot") {
+        return UpdateState::MergeNeedsReboot;
+    } else if (contents == "merge-failed") {
+        return UpdateState::MergeFailed;
     } else {
         LOG(ERROR) << "Unknown merge state in update state file";
         return UpdateState::None;
@@ -463,6 +979,12 @@
         case UpdateState::MergeCompleted:
             contents = "merge-completed";
             break;
+        case UpdateState::MergeNeedsReboot:
+            contents = "merge-needs-reboot";
+            break;
+        case UpdateState::MergeFailed:
+            contents = "merge-failed";
+            break;
         default:
             LOG(ERROR) << "Unknown update state";
             return false;
diff --git a/fs_mgr/libsnapshot/snapshot_test.cpp b/fs_mgr/libsnapshot/snapshot_test.cpp
index aaddec2..4903224 100644
--- a/fs_mgr/libsnapshot/snapshot_test.cpp
+++ b/fs_mgr/libsnapshot/snapshot_test.cpp
@@ -23,6 +23,7 @@
 #include <iostream>
 
 #include <android-base/file.h>
+#include <android-base/properties.h>
 #include <android-base/strings.h>
 #include <android-base/unique_fd.h>
 #include <gtest/gtest.h>
@@ -82,12 +83,7 @@
         // are tests, we don't care, destroy everything that might exist.
         std::vector<std::string> snapshots = {"test-snapshot"};
         for (const auto& snapshot : snapshots) {
-            if (dm_.GetState(snapshot) != DmDeviceState::INVALID) {
-                dm_.DeleteDevice(snapshot);
-            }
-            if (dm_.GetState(snapshot + "-inner") != DmDeviceState::INVALID) {
-                dm_.DeleteDevice(snapshot + "-inner");
-            }
+            DeleteSnapshotDevice(snapshot);
             temp_images_.emplace_back(snapshot + "-cow");
 
             auto status_file = sm->GetSnapshotStatusFilePath(snapshot);
@@ -120,6 +116,16 @@
         return image_manager_->MapImageDevice(name, 10s, path);
     }
 
+    bool DeleteSnapshotDevice(const std::string& snapshot) {
+        if (dm_.GetState(snapshot) != DmDeviceState::INVALID) {
+            if (!dm_.DeleteDevice(snapshot)) return false;
+        }
+        if (dm_.GetState(snapshot + "-inner") != DmDeviceState::INVALID) {
+            if (!dm_.DeleteDevice(snapshot + "-inner")) return false;
+        }
+        return true;
+    }
+
     DeviceMapper& dm_;
     std::unique_ptr<SnapshotManager::LockedFile> lock_;
     std::vector<std::string> temp_images_;
@@ -182,6 +188,117 @@
     ASSERT_TRUE(android::base::StartsWith(snap_device, "/dev/block/dm-"));
 }
 
+TEST_F(SnapshotTest, NoMergeBeforeReboot) {
+    ASSERT_TRUE(AcquireLock());
+
+    // Set the state to Unverified, as if we finished an update.
+    ASSERT_TRUE(sm->WriteUpdateState(lock_.get(), UpdateState::Unverified));
+
+    // Release the lock.
+    lock_ = nullptr;
+
+    // Merge should fail, since we didn't mark the device as rebooted.
+    ASSERT_FALSE(sm->InitiateMerge());
+}
+
+TEST_F(SnapshotTest, Merge) {
+    ASSERT_TRUE(AcquireLock());
+
+    static const uint64_t kDeviceSize = 1024 * 1024;
+    ASSERT_TRUE(sm->CreateSnapshot(lock_.get(), "test-snapshot", kDeviceSize, kDeviceSize,
+                                   kDeviceSize));
+
+    std::string base_device, snap_device;
+    ASSERT_TRUE(CreateTempDevice("base-device", kDeviceSize, &base_device));
+    ASSERT_TRUE(sm->MapSnapshot(lock_.get(), "test-snapshot", base_device, 10s, &snap_device));
+
+    std::string test_string = "This is a test string.";
+    {
+        unique_fd fd(open(snap_device.c_str(), O_RDWR | O_CLOEXEC | O_SYNC));
+        ASSERT_GE(fd, 0);
+        ASSERT_TRUE(android::base::WriteFully(fd, test_string.data(), test_string.size()));
+    }
+
+    // Note: we know the name of the device is test-snapshot because we didn't
+    // request a linear segment.
+    DeviceMapper::TargetInfo target;
+    ASSERT_TRUE(sm->IsSnapshotDevice("test-snapshot", &target));
+    ASSERT_EQ(DeviceMapper::GetTargetType(target.spec), "snapshot");
+
+    // Set the state to Unverified, as if we finished an update.
+    ASSERT_TRUE(sm->WriteUpdateState(lock_.get(), UpdateState::Unverified));
+
+    // Release the lock.
+    lock_ = nullptr;
+
+    test_device->set_is_running_snapshot(true);
+    ASSERT_TRUE(sm->InitiateMerge());
+
+    // The device should have been switched to a snapshot-merge target.
+    ASSERT_TRUE(sm->IsSnapshotDevice("test-snapshot", &target));
+    ASSERT_EQ(DeviceMapper::GetTargetType(target.spec), "snapshot-merge");
+
+    // We should not be able to cancel an update now.
+    ASSERT_FALSE(sm->CancelUpdate());
+
+    ASSERT_EQ(sm->WaitForMerge(), UpdateState::MergeCompleted);
+    ASSERT_EQ(sm->GetUpdateState(), UpdateState::None);
+
+    // The device should no longer be a snapshot or snapshot-merge.
+    ASSERT_FALSE(sm->IsSnapshotDevice("test-snapshot"));
+
+    // Test that we can read back the string we wrote to the snapshot.
+    unique_fd fd(open(base_device.c_str(), O_RDONLY | O_CLOEXEC));
+    ASSERT_GE(fd, 0);
+
+    std::string buffer(test_string.size(), '\0');
+    ASSERT_TRUE(android::base::ReadFully(fd, buffer.data(), buffer.size()));
+    ASSERT_EQ(test_string, buffer);
+}
+
+TEST_F(SnapshotTest, MergeCannotRemoveCow) {
+    ASSERT_TRUE(AcquireLock());
+
+    static const uint64_t kDeviceSize = 1024 * 1024;
+    ASSERT_TRUE(sm->CreateSnapshot(lock_.get(), "test-snapshot", kDeviceSize, kDeviceSize,
+                                   kDeviceSize));
+
+    std::string base_device, snap_device;
+    ASSERT_TRUE(CreateTempDevice("base-device", kDeviceSize, &base_device));
+    ASSERT_TRUE(sm->MapSnapshot(lock_.get(), "test-snapshot", base_device, 10s, &snap_device));
+
+    // Keep an open handle to the cow device. This should cause the merge to
+    // be incomplete.
+    auto cow_path = android::base::GetProperty("gsid.mapped_image.test-snapshot-cow", "");
+    unique_fd fd(open(cow_path.c_str(), O_RDONLY | O_CLOEXEC));
+    ASSERT_GE(fd, 0);
+
+    // Set the state to Unverified, as if we finished an update.
+    ASSERT_TRUE(sm->WriteUpdateState(lock_.get(), UpdateState::Unverified));
+
+    // Release the lock.
+    lock_ = nullptr;
+
+    test_device->set_is_running_snapshot(true);
+    ASSERT_TRUE(sm->InitiateMerge());
+
+    // COW cannot be removed due to open fd, so expect a soft failure.
+    ASSERT_EQ(sm->WaitForMerge(), UpdateState::MergeNeedsReboot);
+
+    // Forcefully delete the snapshot device, so it looks like we just rebooted.
+    ASSERT_TRUE(DeleteSnapshotDevice("test-snapshot"));
+
+    // Map snapshot should fail now, because we're in a merge-complete state.
+    ASSERT_TRUE(AcquireLock());
+    ASSERT_FALSE(sm->MapSnapshot(lock_.get(), "test-snapshot", base_device, 10s, &snap_device));
+
+    // Release everything and now the merge should complete.
+    fd = {};
+    lock_ = nullptr;
+
+    ASSERT_EQ(sm->WaitForMerge(), UpdateState::MergeCompleted);
+}
+
 }  // namespace snapshot
 }  // namespace android