ART: Add JNI function table manipulation

Add support for a JNI function table override. When set, the override
is installed into every thread's JNIEnv and takes precedence over the
choice between the regular and the CheckJNI function tables.

Bug: 34343708
Test: m test-art-host-gtest-jni_internal_test
Change-Id: I0e95b0cbd21f4efdcd8c3d312781d9aeeff54a1e
diff --git a/runtime/base/mutex.cc b/runtime/base/mutex.cc
index 9116097..e05a85a 100644
--- a/runtime/base/mutex.cc
+++ b/runtime/base/mutex.cc
@@ -46,6 +46,7 @@
 ReaderWriterMutex* Locks::heap_bitmap_lock_ = nullptr;
 Mutex* Locks::instrument_entrypoints_lock_ = nullptr;
 Mutex* Locks::intern_table_lock_ = nullptr;
+Mutex* Locks::jni_function_table_lock_ = nullptr;
 Mutex* Locks::jni_libraries_lock_ = nullptr;
 Mutex* Locks::logging_lock_ = nullptr;
 Mutex* Locks::mem_maps_lock_ = nullptr;
@@ -957,6 +958,7 @@
     DCHECK(verifier_deps_lock_ != nullptr);
     DCHECK(host_dlopen_handles_lock_ != nullptr);
     DCHECK(intern_table_lock_ != nullptr);
+    DCHECK(jni_function_table_lock_ != nullptr);
     DCHECK(jni_libraries_lock_ != nullptr);
     DCHECK(logging_lock_ != nullptr);
     DCHECK(mutator_lock_ != nullptr);
@@ -1098,6 +1100,10 @@
     DCHECK(jni_weak_globals_lock_ == nullptr);
     jni_weak_globals_lock_ = new Mutex("JNI weak global reference table lock", current_lock_level);
 
+    UPDATE_CURRENT_LOCK_LEVEL(kJniFunctionTableLock);
+    DCHECK(jni_function_table_lock_ == nullptr);
+    jni_function_table_lock_ = new Mutex("JNI function table lock", current_lock_level);
+
     UPDATE_CURRENT_LOCK_LEVEL(kAbortLock);
     DCHECK(abort_lock_ == nullptr);
     abort_lock_ = new Mutex("abort lock", current_lock_level, true);
diff --git a/runtime/base/mutex.h b/runtime/base/mutex.h
index 2adeb8c..21dd437 100644
--- a/runtime/base/mutex.h
+++ b/runtime/base/mutex.h
@@ -68,6 +68,7 @@
   kRosAllocBulkFreeLock,
   kMarkSweepMarkStackLock,
   kTransactionLogLock,
+  kJniFunctionTableLock,
   kJniWeakGlobalsLock,
   kJniGlobalsLock,
   kReferenceQueueSoftReferencesLock,
@@ -698,8 +699,11 @@
   // Guard accesses to the JNI Weak Global Reference table.
   static Mutex* jni_weak_globals_lock_ ACQUIRED_AFTER(jni_globals_lock_);
 
+  // Guard accesses to the JNI function table override.
+  static Mutex* jni_function_table_lock_ ACQUIRED_AFTER(jni_weak_globals_lock_);
+
   // Have an exclusive aborting thread.
-  static Mutex* abort_lock_ ACQUIRED_AFTER(jni_weak_globals_lock_);
+  static Mutex* abort_lock_ ACQUIRED_AFTER(jni_function_table_lock_);
 
   // Allow mutual exclusion when manipulating Thread::suspend_count_.
   // TODO: Does the trade-off of a per-thread lock make sense?
diff --git a/runtime/jni_env_ext.cc b/runtime/jni_env_ext.cc
index 5a3fafa..0148a1c 100644
--- a/runtime/jni_env_ext.cc
+++ b/runtime/jni_env_ext.cc
@@ -29,6 +29,7 @@
 #include "mirror/object-inl.h"
 #include "nth_caller_visitor.h"
 #include "thread-inl.h"
+#include "thread_list.h"
 
 namespace art {
 
@@ -37,6 +38,8 @@
 static constexpr size_t kMonitorsInitial = 32;  // Arbitrary.
 static constexpr size_t kMonitorsMax = 4096;  // Arbitrary sanity check.
 
+const JNINativeInterface* JNIEnvExt::table_override_ = nullptr;
+
 // Checking "locals" requires the mutator lock, but at creation time we're really only interested
 // in validity, which isn't changing. To avoid grabbing the mutator lock, factored out and tagged
 // with NO_THREAD_SAFETY_ANALYSIS.
@@ -78,10 +81,10 @@
       runtime_deleted(false),
       critical(0),
       monitors("monitors", kMonitorsInitial, kMonitorsMax) {
-  functions = unchecked_functions = GetJniNativeInterface();
-  if (vm->IsCheckJniEnabled()) {
-    SetCheckJniEnabled(true);
-  }
+  MutexLock mu(Thread::Current(), *Locks::jni_function_table_lock_);
+  check_jni = vm->IsCheckJniEnabled();
+  functions = GetFunctionTable(check_jni);
+  unchecked_functions = GetJniNativeInterface();
 }
 
 void JNIEnvExt::SetFunctionsToRuntimeShutdownFunctions() {
@@ -107,7 +110,12 @@
 
 void JNIEnvExt::SetCheckJniEnabled(bool enabled) {
-  check_jni = enabled;
-  functions = enabled ? GetCheckJniNativeInterface() : GetJniNativeInterface();
+  MutexLock mu(Thread::Current(), *Locks::jni_function_table_lock_);
+  check_jni = enabled;  // Written under the lock; ThreadResetFunctionTable reads it.
+  functions = GetFunctionTable(enabled);
+  // With an override installed, the CheckJNI table is never used: warn that enabling is a no-op.
+  if (enabled && JNIEnvExt::table_override_ != nullptr) {
+    LOG(WARNING) << "Enabling CheckJNI after a JNIEnv function table override is not functional.";
+  }
 }
 
 void JNIEnvExt::DumpReferenceTables(std::ostream& os) {
@@ -269,4 +277,41 @@
   }
 }
 
+// ThreadList::ForEach callback: reinstalls the function table selected by GetFunctionTable()
+// into |thread|'s JNIEnv, based on that env's own CheckJNI flag.
+static void ThreadResetFunctionTable(Thread* thread, void* arg ATTRIBUTE_UNUSED)
+    REQUIRES(Locks::jni_function_table_lock_) {
+  JNIEnvExt* const jni_env = thread->GetJniEnv();
+  jni_env->functions = JNIEnvExt::GetFunctionTable(jni_env->check_jni);
+}
+
+void JNIEnvExt::SetTableOverride(const JNINativeInterface* table_override) {
+  Thread* const self = Thread::Current();
+  // Lock order: thread list lock first, then the function table lock guarding table_override_.
+  MutexLock mu(self, *Locks::thread_list_lock_);
+  MutexLock mu2(self, *Locks::jni_function_table_lock_);
+
+  JNIEnvExt::table_override_ = table_override;
+
+  // Propagate to existing threads only if a runtime is present. Note: we cannot run other code
+  // (like JavaVMExt's CheckJNI install code), as we'd have to recursively lock the mutex.
+  Runtime* const runtime = Runtime::Current();
+  if (runtime == nullptr) {
+    return;
+  }
+  runtime->GetThreadList()->ForEach(ThreadResetFunctionTable, nullptr);
+}
+
+// The override, when non-null, takes precedence over both the regular and CheckJNI tables.
+const JNINativeInterface* JNIEnvExt::GetFunctionTable(bool check_jni) {
+  const JNINativeInterface* const table = JNIEnvExt::table_override_;
+  if (table != nullptr) {
+    return table;
+  }
+  if (check_jni) {
+    return GetCheckJniNativeInterface();
+  }
+  return GetJniNativeInterface();
+}
+
 }  // namespace art
diff --git a/runtime/jni_env_ext.h b/runtime/jni_env_ext.h
index 5cca0ae..4004c45 100644
--- a/runtime/jni_env_ext.h
+++ b/runtime/jni_env_ext.h
@@ -43,7 +43,7 @@
   void DumpReferenceTables(std::ostream& os)
       REQUIRES_SHARED(Locks::mutator_lock_);
 
-  void SetCheckJniEnabled(bool enabled);
+  void SetCheckJniEnabled(bool enabled) REQUIRES(!Locks::jni_function_table_lock_);
 
   void PushFrame(int capacity) REQUIRES_SHARED(Locks::mutator_lock_);
   void PopFrame() REQUIRES_SHARED(Locks::mutator_lock_);
@@ -104,10 +104,27 @@
   // Set the functions to the runtime shutdown functions.
   void SetFunctionsToRuntimeShutdownFunctions();
 
+  // Set the function table override. This will install the override (or original table, if null)
+  // to all threads.
+  // Note: JNI function table overrides are sensitive to the order of operations with respect to
+  //       CheckJNI. While an override is set, CheckJNI toggling does not change the table.
+  static void SetTableOverride(const JNINativeInterface* table_override)
+      REQUIRES(!Locks::thread_list_lock_, !Locks::jni_function_table_lock_);
+
+  // Return either the regular, or the CheckJNI function table. Will return table_override_ instead
+  // if it is not null.
+  static const JNINativeInterface* GetFunctionTable(bool check_jni)
+      REQUIRES(Locks::jni_function_table_lock_);
+
  private:
+  // Override of function tables. This applies to both default as well as instrumented (CheckJNI)
+  // function tables.
+  static const JNINativeInterface* table_override_ GUARDED_BY(Locks::jni_function_table_lock_);
+
   // The constructor should not be called directly. It may leave the object in an erroneous state,
   // and the result needs to be checked.
-  JNIEnvExt(Thread* self, JavaVMExt* vm, std::string* error_msg);
+  JNIEnvExt(Thread* self, JavaVMExt* vm, std::string* error_msg)
+      REQUIRES(!Locks::jni_function_table_lock_);
 
   // All locked objects, with the (Java caller) stack frame that locked them. Used in CheckJNI
   // to ensure that only monitors locked in this native frame are being unlocked, and that at
diff --git a/runtime/jni_internal_test.cc b/runtime/jni_internal_test.cc
index 4da5e23..08d1eeb 100644
--- a/runtime/jni_internal_test.cc
+++ b/runtime/jni_internal_test.cc
@@ -2346,4 +2346,46 @@
   EXPECT_EQ(segment_state_now, segment_state_computed);
 }
 
+static size_t gGlobalRefCount = 0;
+static const JNINativeInterface* gOriginalEnv = nullptr;
+
+// NewGlobalRef replacement installed via the table override: counts calls, then delegates to the
+// table that was active before the override.
+static jobject CountNewGlobalRef(JNIEnv* env, jobject o) {
+  ++gGlobalRefCount;
+  return gOriginalEnv->NewGlobalRef(env, o);
+}
+
+// Test the table override: it must be installed into the JNIEnv, must route calls through the
+// overriding functions, and clearing it must restore the previously active table.
+TEST_F(JniInternalTest, JNIEnvExtTableOverride) {
+  // Start from a copy of the currently active table so every other entry keeps working.
+  JNINativeInterface env_override = *env_->functions;
+
+  gOriginalEnv = env_->functions;
+  env_override.NewGlobalRef = CountNewGlobalRef;
+  gGlobalRefCount = 0;
+
+  jclass local = env_->FindClass("java/lang/Object");
+  ASSERT_TRUE(local != nullptr);
+
+  // Set the table, add a global ref, see whether the counter increases.
+  JNIEnvExt::SetTableOverride(&env_override);
+  EXPECT_EQ(&env_override, env_->functions);
+
+  jobject global = env_->NewGlobalRef(local);
+  EXPECT_EQ(1u, gGlobalRefCount);
+  env_->DeleteGlobalRef(global);
+
+  // Reset: the original table must be reinstalled and the counter must stay unchanged.
+  JNIEnvExt::SetTableOverride(nullptr);
+  EXPECT_EQ(gOriginalEnv, env_->functions);
+
+  jobject global2 = env_->NewGlobalRef(local);
+  EXPECT_EQ(1u, gGlobalRefCount);
+  env_->DeleteGlobalRef(global2);
+
+  env_->DeleteLocalRef(local);
+}
+
 }  // namespace art