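Add an optional global limit to OperationLimiter

In addition to the existing per-key limit, OperationLimiter can now cap
the total number of concurrent operations across all keys. The global
limit defaults to INT_MAX, so existing callers are unaffected.

TODO: add unit tests.

Change-Id: Id98157ba9b973e0a4f7b82c4e2bc9249a7da628c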
diff --git a/OperationLimiter.h b/OperationLimiter.h
index df39cd0..1fa1bf2 100644
--- a/OperationLimiter.h
+++ b/OperationLimiter.h
@@ -33,39 +33,44 @@
// The intended usage pattern is:
// OperationLimiter<UserId> connections_per_user;
// ...
-// // Before opening a new connection
-// if (!limiter.start(user)) {
-// return error;
-// } else {
-// // open the connection
-// // ...do some work...
-// // close the connection
-// limiter.finish(user);
+// int connectToSomeResource(int user) {
+// if (!connections_per_user.start(user)) return TRY_AGAIN_LATER;
+// // ...do expensive work here...
+// connections_per_user.finish(user);
// }
//
// This class is thread-safe.
template <typename KeyType>
class OperationLimiter {
public:
- explicit OperationLimiter(int limit) : mLimitPerKey(limit) {}
+ explicit OperationLimiter(int limitPerKey, int globalLimit = INT_MAX)
+ : mLimitPerKey(limitPerKey), mGlobalLimit(globalLimit) {}
~OperationLimiter() {
DCHECK(mCounters.empty()) << "Destroying OperationLimiter with active operations";
}
- // Returns false if |key| has reached the maximum number of concurrent
- // operations, otherwise increments the counter and returns true.
+ // Returns false if |key| has reached its per-key limit, or if the total number of
+ // operations across all keys has reached the global limit. Otherwise, increments
+ // both the per-key and the global counters and returns true.
//
// Note: each successful start(key) must be matched by exactly one call to
// finish(key).
bool start(KeyType key) EXCLUDES(mMutex) {
std::lock_guard lock(mMutex);
+
+ if (mGlobalCounter >= mGlobalLimit) {
+ // Checked before operator[] below, so a globally rejected call never creates a map entry.
+ return false;
+ }
+
auto& cnt = mCounters[key]; // operator[] creates new entries as needed.
if (cnt >= mLimitPerKey) {
// Oh, no!
return false;
}
+
++cnt;
+ ++mGlobalCounter;
return true;
}
@@ -73,6 +78,13 @@
// See usage notes on start().
void finish(KeyType key) EXCLUDES(mMutex) {
std::lock_guard lock(mMutex);
+
+ // Give back a global slot only for a key that holds one (unknown keys are logged below), and
+ // never go below zero: a negative counter would silently raise the effective global limit.
+ if (mCounters.count(key) != 0) {
+ if (mGlobalCounter <= 0) { LOG(FATAL_WITHOUT_ABORT) << "Global operations counter going negative, this is a bug."; return; }
+ --mGlobalCounter;
+ }
auto it = mCounters.find(key);
if (it == mCounters.end()) {
LOG(FATAL_WITHOUT_ABORT) << "Decremented non-existent counter for key=" << key;
@@ -93,8 +105,13 @@
// Tracks the number of outstanding queries by key.
std::unordered_map<KeyType, int> mCounters GUARDED_BY(mMutex);
+ // Total number of outstanding queries across all keys; start() compares it against mGlobalLimit.
+ int mGlobalCounter GUARDED_BY(mMutex) = 0;
// Maximum number of outstanding queries from a single key.
const int mLimitPerKey;
+
+ // Maximum number of outstanding queries summed across all keys (the INT_MAX default disables it).
+ const int mGlobalLimit;
};
} // namespace netdutils