libsnapshot: fix re-flash after update

If device takes an update from slot A to
B, immediately flashes the B slot, and reboot
into B slot, libsnapshot incorrectly considers
the device booted into the new slot and refuses
to clear update states. Fix this by checking
the UPDATED flag in super partition metadata.

Test: libsnapshot_test
Bug: 143551390
Change-Id: I3cd7bb19b394da6399d4bf2f9d013bfaa7f186f1
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
index 260a10c..cd37e2e 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
@@ -254,6 +254,7 @@
     FRIEND_TEST(SnapshotUpdateTest, SnapshotStatusFileWithoutCow);
     friend class SnapshotTest;
     friend class SnapshotUpdateTest;
+    friend class FlashAfterUpdateTest;
     friend struct AutoDeleteCowImage;
     friend struct AutoDeleteSnapshot;
     friend struct PartitionCowCreator;
@@ -351,6 +352,9 @@
     // condition was detected and handled.
     bool HandleCancelledUpdate(LockedFile* lock);
 
+    // Helper for HandleCancelledUpdate. Assumes booting from new slot.
+    bool HandleCancelledUpdateOnNewSlot(LockedFile* lock);
+
     // Remove artifacts created by the update process, such as snapshots, and
     // set the update state to None.
     bool RemoveAllUpdateState(LockedFile* lock);
@@ -369,7 +373,19 @@
     bool MarkSnapshotMergeCompleted(LockedFile* snapshot_lock, const std::string& snapshot_name);
     void AcknowledgeMergeSuccess(LockedFile* lock);
     void AcknowledgeMergeFailure();
-    bool IsCancelledSnapshot(const std::string& snapshot_name);
+    std::unique_ptr<LpMetadata> ReadCurrentMetadata();
+
+    enum class MetadataPartitionState {
+        // Partition does not exist.
+        None,
+        // Partition is flashed.
+        Flashed,
+        // Partition is created by OTA client.
+        Updated,
+    };
+    // Helper function to check the state of a partition as described in metadata.
+    MetadataPartitionState GetMetadataPartitionState(const LpMetadata& metadata,
+                                                     const std::string& name);
 
     // Note that these require the name of the device containing the snapshot,
     // which may be the "inner" device. Use GetsnapshotDeviecName().
diff --git a/fs_mgr/libsnapshot/snapshot.cpp b/fs_mgr/libsnapshot/snapshot.cpp
index 72bd308..2c516a2 100644
--- a/fs_mgr/libsnapshot/snapshot.cpp
+++ b/fs_mgr/libsnapshot/snapshot.cpp
@@ -568,6 +568,27 @@
         }
     }
 
+    auto metadata = ReadCurrentMetadata();
+    for (auto it = snapshots.begin(); it != snapshots.end();) {
+        switch (GetMetadataPartitionState(*metadata, *it)) {
+            case MetadataPartitionState::Flashed:
+                LOG(WARNING) << "Detected re-flashing for partition " << *it
+                             << ". Skip merging it.";
+                [[fallthrough]];
+            case MetadataPartitionState::None: {
+                LOG(WARNING) << "Deleting snapshot for partition " << *it;
+                if (!DeleteSnapshot(lock.get(), *it)) {
+                    LOG(WARNING) << "Cannot delete snapshot for partition " << *it
+                                 << ". Skip merging it anyways.";
+                }
+                it = snapshots.erase(it);
+            } break;
+            case MetadataPartitionState::Updated: {
+                ++it;
+            } break;
+        }
+    }
+
     // Point of no return - mark that we're starting a merge. From now on every
     // snapshot must be a merge target.
     if (!WriteUpdateState(lock.get(), UpdateState::Merging)) {
@@ -855,8 +876,15 @@
 
     std::string dm_name = GetSnapshotDeviceName(name, snapshot_status);
 
+    std::unique_ptr<LpMetadata> current_metadata;
+
     if (!IsSnapshotDevice(dm_name)) {
-        if (IsCancelledSnapshot(name)) {
+        if (!current_metadata) {
+            current_metadata = ReadCurrentMetadata();
+        }
+
+        if (!current_metadata ||
+            GetMetadataPartitionState(*current_metadata, name) != MetadataPartitionState::Updated) {
             DeleteSnapshot(lock, name);
             return UpdateState::Cancelled;
         }
@@ -877,7 +905,8 @@
     }
 
     // This check is expensive so it is only enabled for debugging.
-    DCHECK(!IsCancelledSnapshot(name));
+    DCHECK((current_metadata = ReadCurrentMetadata()) &&
+           GetMetadataPartitionState(*current_metadata, name) == MetadataPartitionState::Updated);
 
     std::string target_type;
     DmTargetSnapshot::Status status;
@@ -1106,13 +1135,17 @@
     if (device_->GetSlotSuffix() != old_slot) {
         // We're booted into the target slot, which means we just rebooted
         // after applying the update.
-        return false;
+        if (!HandleCancelledUpdateOnNewSlot(lock)) {
+            return false;
+        }
     }
 
     // The only way we can get here is if:
     //  (1) The device rolled back to the previous slot.
     //  (2) This function was called prematurely before rebooting the device.
     //  (3) fastboot set_active was used.
+    //  (4) The device updates to the new slot but re-flashed *all* partitions
+    //      in the new slot.
     //
     // In any case, delete the snapshots. It may be worth using the boot_control
     // HAL to differentiate case (2).
@@ -1120,18 +1153,66 @@
     return true;
 }
 
-bool SnapshotManager::IsCancelledSnapshot(const std::string& snapshot_name) {
+std::unique_ptr<LpMetadata> SnapshotManager::ReadCurrentMetadata() {
     const auto& opener = device_->GetPartitionOpener();
     uint32_t slot = SlotNumberForSlotSuffix(device_->GetSlotSuffix());
     auto super_device = device_->GetSuperDevice(slot);
     auto metadata = android::fs_mgr::ReadMetadata(opener, super_device, slot);
     if (!metadata) {
         LOG(ERROR) << "Could not read dynamic partition metadata for device: " << super_device;
-        return false;
+        return nullptr;
     }
-    auto partition = android::fs_mgr::FindPartition(*metadata.get(), snapshot_name);
-    if (!partition) return false;
-    return (partition->attributes & LP_PARTITION_ATTR_UPDATED) == 0;
+    return metadata;
+}
+
+SnapshotManager::MetadataPartitionState SnapshotManager::GetMetadataPartitionState(
+        const LpMetadata& metadata, const std::string& name) {
+    auto partition = android::fs_mgr::FindPartition(metadata, name);
+    if (!partition) return MetadataPartitionState::None;
+    if (partition->attributes & LP_PARTITION_ATTR_UPDATED) {
+        return MetadataPartitionState::Updated;
+    }
+    return MetadataPartitionState::Flashed;
+}
+
+bool SnapshotManager::HandleCancelledUpdateOnNewSlot(LockedFile* lock) {
+    std::vector<std::string> snapshots;
+    if (!ListSnapshots(lock, &snapshots)) {
+        LOG(WARNING) << "Failed to list snapshots to determine whether device has been flashed "
+                     << "after applying an update. Assuming no snapshots.";
+        // Let HandleCancelledUpdate resets UpdateState.
+        return true;
+    }
+
+    // Attempt to detect re-flashing on each partition.
+    // - If all partitions are re-flashed, we can proceed to cancel the whole update.
+    // - If only some of the partitions are re-flashed, snapshots for re-flashed partitions are
+    //   deleted. Caller is responsible for merging the rest of the snapshots.
+    // - If none of the partitions are re-flashed, caller is responsible for merging the snapshots.
+    auto metadata = ReadCurrentMetadata();
+    if (!metadata) return false;
+    bool all_snapshot_cancelled = true;
+    for (const auto& snapshot_name : snapshots) {
+        if (GetMetadataPartitionState(*metadata, snapshot_name) ==
+            MetadataPartitionState::Updated) {
+            LOG(WARNING) << "Cannot cancel update because snapshot" << snapshot_name
+                         << " is in use.";
+            all_snapshot_cancelled = false;
+            continue;
+        }
+        // Delete snapshots for partitions that are re-flashed after the update.
+        LOG(INFO) << "Detected re-flashing of partition " << snapshot_name << ".";
+        if (!DeleteSnapshot(lock, snapshot_name)) {
+            // This is an error, but it is okay to leave the snapshot in the short term.
+            // However, if all_snapshot_cancelled == false after exiting the loop, caller may
+            // initiate merge for this unused snapshot, which is likely to fail.
+            LOG(WARNING) << "Failed to delete snapshot for re-flashed partition " << snapshot_name;
+        }
+    }
+    if (!all_snapshot_cancelled) return false;
+
+    LOG(INFO) << "All partitions are re-flashed after update, removing all update states.";
+    return true;
 }
 
 bool SnapshotManager::RemoveAllSnapshots(LockedFile* lock) {
@@ -2090,7 +2171,9 @@
 }
 
 UpdateState SnapshotManager::InitiateMergeAndWait() {
-    auto state = GetUpdateState();
+    LOG(INFO) << "Waiting for any previous merge request to complete. "
+              << "This can take up to several minutes.";
+    auto state = ProcessUpdateState();
     if (state == UpdateState::None) {
         LOG(INFO) << "Can't find any snapshot to merge.";
         return state;
@@ -2100,11 +2183,13 @@
             LOG(ERROR) << "Failed to initiate merge.";
             return state;
         }
+        // All other states can be handled by ProcessUpdateState.
+        LOG(INFO) << "Waiting for merge to complete. This can take up to several minutes.";
+        state = ProcessUpdateState();
     }
 
-    // All other states can be handled by ProcessUpdateState.
-    LOG(INFO) << "Waiting for any merge to complete. This can take up to 1 minute.";
-    return ProcessUpdateState();
+    LOG(INFO) << "Merge finished with state \"" << state << "\".";
+    return state;
 }
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/snapshot_test.cpp b/fs_mgr/libsnapshot/snapshot_test.cpp
index c9bf5b8..36dcd01 100644
--- a/fs_mgr/libsnapshot/snapshot_test.cpp
+++ b/fs_mgr/libsnapshot/snapshot_test.cpp
@@ -56,6 +56,7 @@
 using android::fs_mgr::GetPartitionName;
 using android::fs_mgr::Interval;
 using android::fs_mgr::MetadataBuilder;
+using android::fs_mgr::SlotSuffixForSlotNumber;
 using chromeos_update_engine::DeltaArchiveManifest;
 using chromeos_update_engine::DynamicPartitionGroup;
 using chromeos_update_engine::PartitionUpdate;
@@ -680,7 +681,6 @@
         // Initialize source partition metadata using |manifest_|.
         src_ = MetadataBuilder::New(*opener_, "super", 0);
         ASSERT_TRUE(FillFakeMetadata(src_.get(), manifest_, "_a"));
-        ASSERT_NE(nullptr, src_);
         // Add sys_b which is like system_other.
         auto partition = src_->AddPartition("sys_b", 0);
         ASSERT_NE(nullptr, partition);
@@ -731,8 +731,12 @@
         if (!hash.has_value()) {
             return AssertionFailure() << "Cannot read partition " << name << ": " << path;
         }
-        if (hashes_[name] != *hash) {
-            return AssertionFailure() << "Content of " << name << " has changed after the merge";
+        auto it = hashes_.find(name);
+        if (it == hashes_.end()) {
+            return AssertionFailure() << "No existing hash for " << name << ". Bad test code?";
+        }
+        if (it->second != *hash) {
+            return AssertionFailure() << "Content of " << name << " has changed";
         }
         return AssertionSuccess();
     }
@@ -1218,6 +1222,121 @@
     EXPECT_FALSE(IsMetadataMounted());
 }
 
+class FlashAfterUpdateTest : public SnapshotUpdateTest,
+                             public WithParamInterface<std::tuple<uint32_t, bool>> {
+  public:
+    AssertionResult InitiateMerge(const std::string& slot_suffix) {
+        auto sm = SnapshotManager::New(new TestDeviceInfo(fake_super, slot_suffix));
+        if (!sm->CreateLogicalAndSnapshotPartitions("super")) {
+            return AssertionFailure() << "Cannot CreateLogicalAndSnapshotPartitions";
+        }
+        if (!sm->InitiateMerge()) {
+            return AssertionFailure() << "Cannot initiate merge";
+        }
+        return AssertionSuccess();
+    }
+};
+
+TEST_P(FlashAfterUpdateTest, FlashSlotAfterUpdate) {
+    // OTA client blindly unmaps all partitions that are possibly mapped.
+    for (const auto& name : {"sys_b", "vnd_b", "prd_b"}) {
+        ASSERT_TRUE(sm->UnmapUpdateSnapshot(name));
+    }
+
+    // Execute the update.
+    ASSERT_TRUE(sm->BeginUpdate());
+    ASSERT_TRUE(sm->CreateUpdateSnapshots(manifest_));
+
+    ASSERT_TRUE(sm->FinishedSnapshotWrites());
+
+    // Simulate shutting down the device.
+    ASSERT_TRUE(UnmapAll());
+
+    if (std::get<1>(GetParam()) /* merge */) {
+        ASSERT_TRUE(InitiateMerge("_b"));
+        // Simulate shutting down the device after merge has initiated.
+        ASSERT_TRUE(UnmapAll());
+    }
+
+    auto flashed_slot = std::get<0>(GetParam());
+    auto flashed_slot_suffix = SlotSuffixForSlotNumber(flashed_slot);
+
+    // Simulate flashing |flashed_slot|. This clears the UPDATED flag.
+    auto flashed_builder = MetadataBuilder::New(*opener_, "super", flashed_slot);
+    flashed_builder->RemoveGroupAndPartitions(group_->name() + flashed_slot_suffix);
+    flashed_builder->RemoveGroupAndPartitions(kCowGroupName);
+    ASSERT_TRUE(FillFakeMetadata(flashed_builder.get(), manifest_, flashed_slot_suffix));
+
+    // Deliberately remove a partition from this build so that
+    // InitiateMerge do not switch state to "merging". This is possible in
+    // practice because the list of dynamic partitions may change.
+    ASSERT_NE(nullptr, flashed_builder->FindPartition("prd" + flashed_slot_suffix));
+    flashed_builder->RemovePartition("prd" + flashed_slot_suffix);
+
+    auto flashed_metadata = flashed_builder->Export();
+    ASSERT_NE(nullptr, flashed_metadata);
+    ASSERT_TRUE(UpdatePartitionTable(*opener_, "super", *flashed_metadata, flashed_slot));
+
+    std::string path;
+    for (const auto& name : {"sys", "vnd"}) {
+        ASSERT_TRUE(CreateLogicalPartition(
+                CreateLogicalPartitionParams{
+                        .block_device = fake_super,
+                        .metadata_slot = flashed_slot,
+                        .partition_name = name + flashed_slot_suffix,
+                        .timeout_ms = 1s,
+                        .partition_opener = opener_.get(),
+                },
+                &path));
+        ASSERT_TRUE(WriteRandomData(path));
+        auto hash = GetHash(path);
+        ASSERT_TRUE(hash.has_value());
+        hashes_[name + flashed_slot_suffix] = *hash;
+    }
+
+    // Simulate shutting down the device after flash.
+    ASSERT_TRUE(UnmapAll());
+
+    // Simulate reboot. After reboot, init does first stage mount.
+    auto init = SnapshotManager::NewForFirstStageMount(
+            new TestDeviceInfo(fake_super, flashed_slot_suffix));
+    ASSERT_NE(init, nullptr);
+    if (init->NeedSnapshotsInFirstStageMount()) {
+        ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super"));
+    } else {
+        for (const auto& name : {"sys", "vnd"}) {
+            ASSERT_TRUE(CreateLogicalPartition(
+                    CreateLogicalPartitionParams{
+                            .block_device = fake_super,
+                            .metadata_slot = flashed_slot,
+                            .partition_name = name + flashed_slot_suffix,
+                            .timeout_ms = 1s,
+                            .partition_opener = opener_.get(),
+                    },
+                    &path));
+        }
+    }
+
+    // Check that the target partitions have the same content.
+    for (const auto& name : {"sys", "vnd"}) {
+        ASSERT_TRUE(IsPartitionUnchanged(name + flashed_slot_suffix));
+    }
+
+    // There should be no snapshot to merge.
+    auto new_sm = SnapshotManager::New(new TestDeviceInfo(fake_super, flashed_slot_suffix));
+    ASSERT_EQ(UpdateState::Cancelled, new_sm->InitiateMergeAndWait());
+
+    // Next OTA calls CancelUpdate no matter what.
+    ASSERT_TRUE(new_sm->CancelUpdate());
+}
+
+INSTANTIATE_TEST_SUITE_P(, FlashAfterUpdateTest, Combine(Values(0, 1), Bool()),
+                         [](const TestParamInfo<FlashAfterUpdateTest::ParamType>& info) {
+                             return "Flash"s + (std::get<0>(info.param) ? "New"s : "Old"s) +
+                                    "Slot"s + (std::get<1>(info.param) ? "After"s : "Before"s) +
+                                    "Merge"s;
+                         });
+
 }  // namespace snapshot
 }  // namespace android