libsnapshot: lock on /metadata/ota

We used to flock() on /metadata/ota/state to ensure
atomic access. However, writing the file itself is
not necessarily atomic and may lead to inconsistent
states.

This change redirects flock() to the enclosing directory,
/metadata/ota, which is very likely to exist (see the exception
below). flock() is now called on this directory instead of on
/metadata/ota/state. This allows a follow-up change to make all
writes to the /metadata partition atomic.

Note: /metadata/ota may not exist during the first boot after a flash
with wipe. However, first_stage_init always checks for the existence
of the boot indicator before even trying to flock() (via
IsSnapshotManagerNeeded() and NeedSnapshotsInFirstStageMount()). If
the boot indicator exists, /metadata/ota must exist as well.

Also add tests to ensure LockExclusive() and LockShared() work as
expected.

Test: libsnapshot_test
Test: apply OTA from older build to this, then reboot

Bug: 144549076

Change-Id: Ib4dd9e9be1a43013c328e181b9398ac0b514dbce
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
index 8e3875f..7450d19 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
@@ -277,6 +277,7 @@
     friend class SnapshotTest;
     friend class SnapshotUpdateTest;
     friend class FlashAfterUpdateTest;
+    friend class LockTestConsumer;
     friend struct AutoDeleteCowImage;
     friend struct AutoDeleteSnapshot;
     friend struct PartitionCowCreator;
@@ -304,9 +305,6 @@
         LockedFile(const std::string& path, android::base::unique_fd&& fd, int lock_mode)
             : path_(path), fd_(std::move(fd)), lock_mode_(lock_mode) {}
         ~LockedFile();
-
-        const std::string& path() const { return path_; }
-        int fd() const { return fd_; }
         int lock_mode() const { return lock_mode_; }
 
       private:
@@ -314,8 +312,7 @@
         android::base::unique_fd fd_;
         int lock_mode_;
     };
-    std::unique_ptr<LockedFile> OpenFile(const std::string& file, int open_flags, int lock_flags);
-    bool Truncate(LockedFile* file);
+    static std::unique_ptr<LockedFile> OpenFile(const std::string& file, int lock_flags);
 
     // Create a new snapshot record. This creates the backing COW store and
     // persists information needed to map the device. The device can be mapped
@@ -381,10 +378,13 @@
     // set the update state to None.
     bool RemoveAllUpdateState(LockedFile* lock);
 
-    // Interact with /metadata/ota/state.
-    std::unique_ptr<LockedFile> OpenStateFile(int open_flags, int lock_flags);
+    // Interact with /metadata/ota.
+    std::unique_ptr<LockedFile> OpenLock(int lock_flags);
     std::unique_ptr<LockedFile> LockShared();
     std::unique_ptr<LockedFile> LockExclusive();
+    std::string GetLockPath() const;
+
+    // Interact with /metadata/ota/state.
     UpdateState ReadUpdateState(LockedFile* file);
     bool WriteUpdateState(LockedFile* file, UpdateState state);
     std::string GetStateFilePath() const;
diff --git a/fs_mgr/libsnapshot/snapshot.cpp b/fs_mgr/libsnapshot/snapshot.cpp
index 33c122b..0ee982c 100644
--- a/fs_mgr/libsnapshot/snapshot.cpp
+++ b/fs_mgr/libsnapshot/snapshot.cpp
@@ -1218,12 +1218,12 @@
         return UpdateState::None;
     }
 
-    auto file = LockShared();
-    if (!file) {
+    auto lock = LockShared();
+    if (!lock) {
         return UpdateState::None;
     }
 
-    auto state = ReadUpdateState(file.get());
+    auto state = ReadUpdateState(lock.get());
     if (progress) {
         *progress = 0.0;
         if (state == UpdateState::Merging) {
@@ -1595,9 +1595,9 @@
     return true;
 }
 
-auto SnapshotManager::OpenFile(const std::string& file, int open_flags, int lock_flags)
+auto SnapshotManager::OpenFile(const std::string& file, int lock_flags)
         -> std::unique_ptr<LockedFile> {
-    unique_fd fd(open(file.c_str(), open_flags | O_CLOEXEC | O_NOFOLLOW | O_SYNC, 0660));
+    unique_fd fd(open(file.c_str(), O_RDONLY | O_CLOEXEC | O_NOFOLLOW));
     if (fd < 0) {
         PLOG(ERROR) << "Open failed: " << file;
         return nullptr;
@@ -1622,29 +1622,28 @@
     return metadata_dir_ + "/state"s;
 }
 
-std::unique_ptr<SnapshotManager::LockedFile> SnapshotManager::OpenStateFile(int open_flags,
-                                                                            int lock_flags) {
-    auto state_file = GetStateFilePath();
-    return OpenFile(state_file, open_flags, lock_flags);
+std::string SnapshotManager::GetLockPath() const {
+    return metadata_dir_;
+}
+
+std::unique_ptr<SnapshotManager::LockedFile> SnapshotManager::OpenLock(int lock_flags) {
+    auto lock_file = GetLockPath();
+    return OpenFile(lock_file, lock_flags);
 }
 
 std::unique_ptr<SnapshotManager::LockedFile> SnapshotManager::LockShared() {
-    return OpenStateFile(O_RDONLY, LOCK_SH);
+    return OpenLock(LOCK_SH);
 }
 
 std::unique_ptr<SnapshotManager::LockedFile> SnapshotManager::LockExclusive() {
-    return OpenStateFile(O_RDWR | O_CREAT, LOCK_EX);
+    return OpenLock(LOCK_EX);
 }
 
-UpdateState SnapshotManager::ReadUpdateState(LockedFile* file) {
-    // Reset position since some calls read+write.
-    if (lseek(file->fd(), 0, SEEK_SET) < 0) {
-        PLOG(ERROR) << "lseek state file failed";
-        return UpdateState::None;
-    }
+UpdateState SnapshotManager::ReadUpdateState(LockedFile* lock) {
+    CHECK(lock);
 
     std::string contents;
-    if (!android::base::ReadFdToString(file->fd(), &contents)) {
+    if (!android::base::ReadFileToString(GetStateFilePath(), &contents)) {
         PLOG(ERROR) << "Read state file failed";
         return UpdateState::None;
     }
@@ -1691,14 +1690,15 @@
     }
 }
 
-bool SnapshotManager::WriteUpdateState(LockedFile* file, UpdateState state) {
+bool SnapshotManager::WriteUpdateState(LockedFile* lock, UpdateState state) {
+    CHECK(lock);
+    CHECK(lock->lock_mode() == LOCK_EX);
+
     std::stringstream ss;
     ss << state;
     std::string contents = ss.str();
     if (contents.empty()) return false;
 
-    if (!Truncate(file)) return false;
-
 #ifdef LIBSNAPSHOT_USE_HAL
     auto merge_status = MergeStatus::UNKNOWN;
     switch (state) {
@@ -1731,7 +1731,7 @@
     }
 #endif
 
-    if (!android::base::WriteStringToFd(contents, file->fd())) {
+    if (!android::base::WriteStringToFile(contents, GetStateFilePath())) {
         PLOG(ERROR) << "Could not write to state file";
         return false;
     }
@@ -1795,18 +1795,6 @@
     return true;
 }
 
-bool SnapshotManager::Truncate(LockedFile* file) {
-    if (lseek(file->fd(), 0, SEEK_SET) < 0) {
-        PLOG(ERROR) << "lseek file failed: " << file->path();
-        return false;
-    }
-    if (ftruncate(file->fd(), 0) < 0) {
-        PLOG(ERROR) << "truncate failed: " << file->path();
-        return false;
-    }
-    return true;
-}
-
 std::string SnapshotManager::GetSnapshotDeviceName(const std::string& snapshot_name,
                                                    const SnapshotStatus& status) {
     if (status.device_size() != status.snapshot_size()) {
@@ -2164,7 +2152,7 @@
 bool SnapshotManager::Dump(std::ostream& os) {
     // Don't actually lock. Dump() is for debugging purposes only, so it is okay
     // if it is racy.
-    auto file = OpenStateFile(O_RDONLY, 0);
+    auto file = OpenLock(0 /* lock flag */);
     if (!file) return false;
 
     std::stringstream ss;
diff --git a/fs_mgr/libsnapshot/snapshot_test.cpp b/fs_mgr/libsnapshot/snapshot_test.cpp
index 1d2a1f1..9e5fef3 100644
--- a/fs_mgr/libsnapshot/snapshot_test.cpp
+++ b/fs_mgr/libsnapshot/snapshot_test.cpp
@@ -20,6 +20,8 @@
 #include <sys/types.h>
 
 #include <chrono>
+#include <deque>
+#include <future>
 #include <iostream>
 
 #include <android-base/file.h>
@@ -142,7 +144,7 @@
     }
 
     bool AcquireLock() {
-        lock_ = sm->OpenStateFile(O_RDWR, LOCK_EX);
+        lock_ = sm->LockExclusive();
         return !!lock_;
     }
 
@@ -595,6 +597,154 @@
     ASSERT_EQ(test_device->merge_status(), MergeStatus::MERGING);
 }
 
+enum class Request { UNKNOWN, LOCK_SHARED, LOCK_EXCLUSIVE, UNLOCK, EXIT };
+std::ostream& operator<<(std::ostream& os, Request request) {
+    switch (request) {
+        case Request::LOCK_SHARED:
+            return os << "LOCK_SHARED";
+        case Request::LOCK_EXCLUSIVE:
+            return os << "LOCK_EXCLUSIVE";
+        case Request::UNLOCK:
+            return os << "UNLOCK";
+        case Request::EXIT:
+            return os << "EXIT";
+        case Request::UNKNOWN:
+            [[fallthrough]];
+        default:
+            return os << "UNKNOWN";
+    }
+}
+
+class LockTestConsumer {
+  public:
+    AssertionResult MakeRequest(Request new_request) {
+        {
+            std::unique_lock<std::mutex> ulock(mutex_);
+            requests_.push_back(new_request);
+        }
+        cv_.notify_all();
+        return AssertionSuccess() << "Request " << new_request << " successful";
+    }
+
+    template <typename R, typename P>
+    AssertionResult WaitFulfill(std::chrono::duration<R, P> timeout) {
+        std::unique_lock<std::mutex> ulock(mutex_);
+        if (cv_.wait_for(ulock, timeout, [this] { return requests_.empty(); })) {
+            return AssertionSuccess() << "All requests_ fulfilled.";
+        }
+        return AssertionFailure() << "Timeout waiting for fulfilling " << requests_.size()
+                                  << " request(s), first one is "
+                                  << (requests_.empty() ? Request::UNKNOWN : requests_.front());
+    }
+
+    void StartHandleRequestsInBackground() {
+        future_ = std::async(std::launch::async, &LockTestConsumer::HandleRequests, this);
+    }
+
+  private:
+    void HandleRequests() {
+        static constexpr auto consumer_timeout = 3s;
+
+        auto next_request = Request::UNKNOWN;
+        do {
+            // Peek next request.
+            {
+                std::unique_lock<std::mutex> ulock(mutex_);
+                if (cv_.wait_for(ulock, consumer_timeout, [this] { return !requests_.empty(); })) {
+                    next_request = requests_.front();
+                } else {
+                    next_request = Request::EXIT;
+                }
+            }
+
+            // Handle next request.
+            switch (next_request) {
+                case Request::LOCK_SHARED: {
+                    lock_ = sm->LockShared();
+                } break;
+                case Request::LOCK_EXCLUSIVE: {
+                    lock_ = sm->LockExclusive();
+                } break;
+                case Request::EXIT:
+                    [[fallthrough]];
+                case Request::UNLOCK: {
+                    lock_.reset();
+                } break;
+                case Request::UNKNOWN:
+                    [[fallthrough]];
+                default:
+                    break;
+            }
+
+            // Pop next request. This thread is the only thread that
+            // pops from the front of the requests_ deque.
+            {
+                std::unique_lock<std::mutex> ulock(mutex_);
+                if (next_request == Request::EXIT) {
+                    requests_.clear();
+                } else {
+                    requests_.pop_front();
+                }
+            }
+            cv_.notify_all();
+        } while (next_request != Request::EXIT);
+    }
+
+    std::mutex mutex_;
+    std::condition_variable cv_;
+    std::deque<Request> requests_;
+    std::unique_ptr<SnapshotManager::LockedFile> lock_;
+    std::future<void> future_;
+};
+
+class LockTest : public ::testing::Test {
+  public:
+    void SetUp() {
+        first_consumer.StartHandleRequestsInBackground();
+        second_consumer.StartHandleRequestsInBackground();
+    }
+
+    void TearDown() {
+        EXPECT_TRUE(first_consumer.MakeRequest(Request::EXIT));
+        EXPECT_TRUE(second_consumer.MakeRequest(Request::EXIT));
+    }
+
+    static constexpr auto request_timeout = 500ms;
+    LockTestConsumer first_consumer;
+    LockTestConsumer second_consumer;
+};
+
+TEST_F(LockTest, SharedShared) {
+    ASSERT_TRUE(first_consumer.MakeRequest(Request::LOCK_SHARED));
+    ASSERT_TRUE(first_consumer.WaitFulfill(request_timeout));
+    ASSERT_TRUE(second_consumer.MakeRequest(Request::LOCK_SHARED));
+    ASSERT_TRUE(second_consumer.WaitFulfill(request_timeout));
+}
+
+using LockTestParam = std::pair<Request, Request>;
+class LockTestP : public LockTest, public ::testing::WithParamInterface<LockTestParam> {};
+TEST_P(LockTestP, Test) {
+    ASSERT_TRUE(first_consumer.MakeRequest(GetParam().first));
+    ASSERT_TRUE(first_consumer.WaitFulfill(request_timeout));
+    ASSERT_TRUE(second_consumer.MakeRequest(GetParam().second));
+    ASSERT_FALSE(second_consumer.WaitFulfill(request_timeout))
+            << "Should not be able to " << GetParam().second << " while separate thread "
+            << GetParam().first;
+    ASSERT_TRUE(first_consumer.MakeRequest(Request::UNLOCK));
+    ASSERT_TRUE(second_consumer.WaitFulfill(request_timeout))
+            << "Should be able to hold lock that is released by separate thread";
+}
+INSTANTIATE_TEST_SUITE_P(
+        LockTest, LockTestP,
+        testing::Values(LockTestParam{Request::LOCK_EXCLUSIVE, Request::LOCK_EXCLUSIVE},
+                        LockTestParam{Request::LOCK_EXCLUSIVE, Request::LOCK_SHARED},
+                        LockTestParam{Request::LOCK_SHARED, Request::LOCK_EXCLUSIVE}),
+        [](const testing::TestParamInfo<LockTestP::ParamType>& info) {
+            std::stringstream ss;
+            ss << info.param.first << "_" << info.param.second;
+            return ss.str();
+        });
+
 class SnapshotUpdateTest : public SnapshotTest {
   public:
     void SetUp() override {