Refactor pthread_key.cpp to be lock-free.

Change-Id: I20dfb9d3cdc40eed10ea12ac34f03caaa94f7a49
diff --git a/libc/bionic/pthread_create.cpp b/libc/bionic/pthread_create.cpp
index c6d8494..2bca43f 100644
--- a/libc/bionic/pthread_create.cpp
+++ b/libc/bionic/pthread_create.cpp
@@ -56,7 +56,8 @@
   if (thread->mmap_size == 0) {
     // If the TLS area was not allocated by mmap(), it may not have been cleared to zero.
     // So assume the worst and zero the TLS area.
-    memset(&thread->tls[0], 0, BIONIC_TLS_SLOTS * sizeof(void*));
+    memset(thread->tls, 0, sizeof(thread->tls));
+    memset(thread->key_data, 0, sizeof(thread->key_data));
   }
 
   // Slot 0 must point to itself. The x86 Linux kernel reads the TLS from %fs:0.
@@ -155,7 +156,7 @@
   }
 
   // Mapped space(or user allocated stack) is used for:
-  //   thread_internal_t (including tls array)
+  //   thread_internal_t
   //   thread stack (including guard page)
   stack_top -= sizeof(pthread_internal_t);
   pthread_internal_t* thread = reinterpret_cast<pthread_internal_t*>(stack_top);
diff --git a/libc/bionic/pthread_internal.h b/libc/bionic/pthread_internal.h
index 8fbaf22..f131d7a 100644
--- a/libc/bionic/pthread_internal.h
+++ b/libc/bionic/pthread_internal.h
@@ -44,6 +44,11 @@
 /* Is this the main thread? */
 #define PTHREAD_ATTR_FLAG_MAIN_THREAD 0x80000000
 
+struct pthread_key_data_t {
+  uintptr_t seq; // Use uintptr_t just for alignment, as we use a pointer below.
+  void* data;
+};
+
 struct pthread_internal_t {
   struct pthread_internal_t* next;
   struct pthread_internal_t* prev;
@@ -86,6 +91,8 @@
 
   void* tls[BIONIC_TLS_SLOTS];
 
+  pthread_key_data_t key_data[BIONIC_PTHREAD_KEY_COUNT];
+
   /*
    * The dynamic linker implements dlerror(3), which makes it hard for us to implement this
    * per-thread buffer by simply using malloc(3) and free(3).
diff --git a/libc/bionic/pthread_key.cpp b/libc/bionic/pthread_key.cpp
index 49b72e9..65e0879 100644
--- a/libc/bionic/pthread_key.cpp
+++ b/libc/bionic/pthread_key.cpp
@@ -28,175 +28,98 @@
 
 #include <errno.h>
 #include <pthread.h>
+#include <stdatomic.h>
 
 #include "private/bionic_tls.h"
 #include "pthread_internal.h"
 
-/* A technical note regarding our thread-local-storage (TLS) implementation:
- *
- * There can be up to BIONIC_TLS_SLOTS independent TLS keys in a given process,
- * The keys below TLS_SLOT_FIRST_USER_SLOT are reserved for Bionic to hold
- * special thread-specific variables like errno or a pointer to
- * the current thread's descriptor. These entries cannot be accessed through
- * pthread_getspecific() / pthread_setspecific() or pthread_key_delete()
- *
- * The 'tls_map_t' type defined below implements a shared global map of
- * currently created/allocated TLS keys and the destructors associated
- * with them.
- *
- * The global TLS map simply contains a bitmap of allocated keys, and
- * an array of destructors.
- *
- * Each thread has a TLS area that is a simple array of BIONIC_TLS_SLOTS void*
- * pointers. the TLS area of the main thread is stack-allocated in
- * __libc_init_common, while the TLS area of other threads is placed at
- * the top of their stack in pthread_create.
- *
- * When pthread_key_delete() is called it will erase the key's bitmap bit
- * and its destructor, and will also clear the key data in the TLS area of
- * all created threads. As mandated by Posix, it is the responsibility of
- * the caller of pthread_key_delete() to properly reclaim the objects that
- * were pointed to by these data fields (either before or after the call).
- */
-
-#define TLSMAP_BITS       32
-#define TLSMAP_WORDS      ((BIONIC_TLS_SLOTS+TLSMAP_BITS-1)/TLSMAP_BITS)
-#define TLSMAP_WORD(m,k)  (m).map[(k)/TLSMAP_BITS]
-#define TLSMAP_MASK(k)    (1U << ((k)&(TLSMAP_BITS-1)))
-
-static inline bool IsValidUserKey(pthread_key_t key) {
-  return (key >= TLS_SLOT_FIRST_USER_SLOT && key < BIONIC_TLS_SLOTS);
-}
-
 typedef void (*key_destructor_t)(void*);
 
-struct tls_map_t {
-  bool is_initialized;
+#define SEQ_KEY_IN_USE_BIT     0
 
-  /* bitmap of allocated keys */
-  uint32_t map[TLSMAP_WORDS];
+#define SEQ_INCREMENT_STEP  (1 << SEQ_KEY_IN_USE_BIT)
 
-  key_destructor_t key_destructors[BIONIC_TLS_SLOTS];
+// pthread_key_internal_t records the use of each pthread key slot:
+//   seq records the state of the slot.
+//      bit 0 is 1 when the key is in use, 0 when it is unused. Each time we create or delete the
+//      pthread key in the slot, we increase the seq by 1 (which inverts bit 0). The reason to use
+//      a sequence number instead of a boolean value here is that when the key slot is deleted and
+//      reused for a new key, pthread_getspecific will not return stale data.
+//   key_destructor records the destructor called at thread exit.
+struct pthread_key_internal_t {
+  atomic_uintptr_t seq;
+  atomic_uintptr_t key_destructor;
 };
 
-class ScopedTlsMapAccess {
- public:
-  ScopedTlsMapAccess() {
-    Lock();
+static pthread_key_internal_t key_map[BIONIC_PTHREAD_KEY_COUNT];
 
-    // If this is the first time the TLS map has been accessed,
-    // mark the slots belonging to well-known keys as being in use.
-    // This isn't currently necessary because the well-known keys
-    // can only be accessed directly by bionic itself, do not have
-    // destructors, and all the functions that touch the TLS map
-    // start after the maximum well-known slot.
-    if (!s_tls_map_.is_initialized) {
-      for (pthread_key_t key = 0; key < TLS_SLOT_FIRST_USER_SLOT; ++key) {
-        SetInUse(key, NULL);
-      }
-      s_tls_map_.is_initialized = true;
-    }
-  }
+static inline bool SeqOfKeyInUse(uintptr_t seq) {
+  return seq & (1 << SEQ_KEY_IN_USE_BIT);
+}
 
-  ~ScopedTlsMapAccess() {
-    Unlock();
-  }
+static inline bool KeyInValidRange(pthread_key_t key) {
+  return key >= 0 && key < BIONIC_PTHREAD_KEY_COUNT;
+}
 
-  int CreateKey(pthread_key_t* result, void (*key_destructor)(void*)) {
-    // Take the first unallocated key.
-    for (int key = 0; key < BIONIC_TLS_SLOTS; ++key) {
-      if (!IsInUse(key)) {
-        SetInUse(key, key_destructor);
-        *result = key;
-        return 0;
-      }
-    }
-
-    // We hit PTHREAD_KEYS_MAX. POSIX says EAGAIN for this case.
-    return EAGAIN;
-  }
-
-  void DeleteKey(pthread_key_t key) {
-    TLSMAP_WORD(s_tls_map_, key) &= ~TLSMAP_MASK(key);
-    s_tls_map_.key_destructors[key] = NULL;
-  }
-
-  bool IsInUse(pthread_key_t key) {
-    return (TLSMAP_WORD(s_tls_map_, key) & TLSMAP_MASK(key)) != 0;
-  }
-
-  void SetInUse(pthread_key_t key, void (*key_destructor)(void*)) {
-    TLSMAP_WORD(s_tls_map_, key) |= TLSMAP_MASK(key);
-    s_tls_map_.key_destructors[key] = key_destructor;
-  }
-
-  // Called from pthread_exit() to remove all TLS key data
-  // from this thread's TLS area. This must call the destructor of all keys
-  // that have a non-NULL data value and a non-NULL destructor.
-  void CleanAll() {
-    void** tls = __get_tls();
-
-    // Because destructors can do funky things like deleting/creating other
-    // keys, we need to implement this in a loop.
-    for (int rounds = PTHREAD_DESTRUCTOR_ITERATIONS; rounds > 0; --rounds) {
-      size_t called_destructor_count = 0;
-      for (int key = 0; key < BIONIC_TLS_SLOTS; ++key) {
-        if (IsInUse(key)) {
-          void* data = tls[key];
-          void (*key_destructor)(void*) = s_tls_map_.key_destructors[key];
-
-          if (data != NULL && key_destructor != NULL) {
-            // we need to clear the key data now, this will prevent the
-            // destructor (or a later one) from seeing the old value if
-            // it calls pthread_getspecific() for some odd reason
-
-            // we do not do this if 'key_destructor == NULL' just in case another
-            // destructor function might be responsible for manually
-            // releasing the corresponding data.
-            tls[key] = NULL;
-
-            // because the destructor is free to call pthread_key_create
-            // and/or pthread_key_delete, we need to temporarily unlock
-            // the TLS map
-            Unlock();
-            (*key_destructor)(data);
-            Lock();
-            ++called_destructor_count;
-          }
-        }
-      }
-
-      // If we didn't call any destructors, there is no need to check the TLS data again.
-      if (called_destructor_count == 0) {
-        break;
-      }
-    }
-  }
-
- private:
-  static tls_map_t s_tls_map_;
-  static pthread_mutex_t s_tls_map_lock_;
-
-  void Lock() {
-    pthread_mutex_lock(&s_tls_map_lock_);
-  }
-
-  void Unlock() {
-    pthread_mutex_unlock(&s_tls_map_lock_);
-  }
-};
-
-__LIBC_HIDDEN__ tls_map_t ScopedTlsMapAccess::s_tls_map_;
-__LIBC_HIDDEN__ pthread_mutex_t ScopedTlsMapAccess::s_tls_map_lock_;
-
+// Called from pthread_exit() to remove all pthread keys. This must call the destructor of
+// all keys that have a non-NULL data value and a non-NULL destructor.
 __LIBC_HIDDEN__ void pthread_key_clean_all() {
-  ScopedTlsMapAccess tls_map;
-  tls_map.CleanAll();
+  // Because destructors can do funky things like deleting/creating other keys,
+  // we need to implement this in a loop.
+  pthread_key_data_t* key_data = __get_thread()->key_data;
+  for (size_t rounds = PTHREAD_DESTRUCTOR_ITERATIONS; rounds > 0; --rounds) {
+    size_t called_destructor_count = 0;
+    for (size_t i = 0; i < BIONIC_PTHREAD_KEY_COUNT; ++i) {
+      uintptr_t seq = atomic_load_explicit(&key_map[i].seq, memory_order_relaxed);
+      if (SeqOfKeyInUse(seq) && seq == key_data[i].seq && key_data[i].data != NULL) {
+        // Other threads may be calling pthread_key_delete/pthread_key_create while the current
+        // is exiting. So we need to ensure we read the right key_destructor.
+        // We can rely on a user-established happens-before relationship between the creation and
+        // use of pthread key to ensure that we're not getting an earlier key_destructor.
+        // To avoid using the key_destructor of the newly created key in the same slot, we need to
+        // recheck the sequence number after reading key_destructor. As a result, we either see the
+        // right key_destructor, or the sequence number must have changed when we reread it below.
+        key_destructor_t key_destructor = reinterpret_cast<key_destructor_t>(
+          atomic_load_explicit(&key_map[i].key_destructor, memory_order_relaxed));
+        if (key_destructor == NULL) {
+          continue;
+        }
+        atomic_thread_fence(memory_order_acquire);
+        if (atomic_load_explicit(&key_map[i].seq, memory_order_relaxed) != seq) {
+           continue;
+        }
+
+        // We need to clear the key data now, this will prevent the destructor (or a later one)
+        // from seeing the old value if it calls pthread_getspecific().
+        // We don't do this if 'key_destructor == NULL' just in case another destructor
+        // function is responsible for manually releasing the corresponding data.
+        void* data = key_data[i].data;
+        key_data[i].data = NULL;
+
+        (*key_destructor)(data);
+        ++called_destructor_count;
+      }
+    }
+
+    // If we didn't call any destructors, there is no need to check the pthread keys again.
+    if (called_destructor_count == 0) {
+      break;
+    }
+  }
 }
 
 int pthread_key_create(pthread_key_t* key, void (*key_destructor)(void*)) {
-  ScopedTlsMapAccess tls_map;
-  return tls_map.CreateKey(key, key_destructor);
+  for (size_t i = 0; i < BIONIC_PTHREAD_KEY_COUNT; ++i) {
+    uintptr_t seq = atomic_load_explicit(&key_map[i].seq, memory_order_relaxed);
+    while (!SeqOfKeyInUse(seq)) {
+      if (atomic_compare_exchange_weak(&key_map[i].seq, &seq, seq + SEQ_INCREMENT_STEP)) {
+        atomic_store(&key_map[i].key_destructor, reinterpret_cast<uintptr_t>(key_destructor));
+        *key = i;
+        return 0;
+      }
+    }
+  }
+  return EAGAIN;
 }
 
 // Deletes a pthread_key_t. note that the standard mandates that this does
@@ -204,42 +127,44 @@
 // responsibility of the caller to properly dispose of the corresponding data
 // and resources, using any means it finds suitable.
 int pthread_key_delete(pthread_key_t key) {
-  ScopedTlsMapAccess tls_map;
-
-  if (!IsValidUserKey(key) || !tls_map.IsInUse(key)) {
+  if (!KeyInValidRange(key)) {
     return EINVAL;
   }
-
-  // Clear value in all threads.
-  pthread_mutex_lock(&g_thread_list_lock);
-  for (pthread_internal_t*  t = g_thread_list; t != NULL; t = t->next) {
-    t->tls[key] = NULL;
+  // Increase seq to invalidate values in all threads.
+  uintptr_t seq = atomic_load_explicit(&key_map[key].seq, memory_order_relaxed);
+  if (SeqOfKeyInUse(seq)) {
+    if (atomic_compare_exchange_strong(&key_map[key].seq, &seq, seq + SEQ_INCREMENT_STEP)) {
+      return 0;
+    }
   }
-  tls_map.DeleteKey(key);
-
-  pthread_mutex_unlock(&g_thread_list_lock);
-  return 0;
+  return EINVAL;
 }
 
 void* pthread_getspecific(pthread_key_t key) {
-  if (!IsValidUserKey(key)) {
+  if (!KeyInValidRange(key)) {
     return NULL;
   }
-
-  // For performance reasons, we do not lock/unlock the global TLS map
-  // to check that the key is properly allocated. If the key was not
-  // allocated, the value read from the TLS should always be NULL
-  // due to pthread_key_delete() clearing the values for all threads.
-  return __get_tls()[key];
+  uintptr_t seq = atomic_load_explicit(&key_map[key].seq, memory_order_relaxed);
+  pthread_key_data_t* data = &(__get_thread()->key_data[key]);
+  // It is the user's responsibility to synchronize between the creation and use of pthread keys,
+  // so we use memory_order_relaxed when checking the sequence number.
+  if (__predict_true(SeqOfKeyInUse(seq) && data->seq == seq)) {
+    return data->data;
+  }
+  data->data = NULL;
+  return NULL;
 }
 
 int pthread_setspecific(pthread_key_t key, const void* ptr) {
-  ScopedTlsMapAccess tls_map;
-
-  if (!IsValidUserKey(key) || !tls_map.IsInUse(key)) {
+  if (!KeyInValidRange(key)) {
     return EINVAL;
   }
-
-  __get_tls()[key] = const_cast<void*>(ptr);
-  return 0;
+  uintptr_t seq = atomic_load_explicit(&key_map[key].seq, memory_order_relaxed);
+  if (SeqOfKeyInUse(seq)) {
+    pthread_key_data_t* data = &(__get_thread()->key_data[key]);
+    data->seq = seq;
+    data->data = const_cast<void*>(ptr);
+    return 0;
+  }
+  return EINVAL;
 }