ART: Move access flags checking to dex file verifier

Actually implement all the access flags checking in the dex file
verifier. Add tests.

Change-Id: I8b797357831b588589d56d6e2e22f7b410f33008
diff --git a/runtime/dex_file_verifier.cc b/runtime/dex_file_verifier.cc
index eec4983..2dd96aa 100644
--- a/runtime/dex_file_verifier.cc
+++ b/runtime/dex_file_verifier.cc
@@ -16,7 +16,9 @@
 
 #include "dex_file_verifier.h"
 
+#include <inttypes.h>
 #include <zlib.h>
+
 #include <memory>
 
 #include "base/stringprintf.h"
@@ -444,66 +446,86 @@
   return true;
 }
 
-bool DexFileVerifier::CheckClassDataItemField(uint32_t idx, uint32_t access_flags,
+bool DexFileVerifier::CheckClassDataItemField(uint32_t idx,
+                                              uint32_t access_flags,
+                                              uint32_t class_access_flags,
+                                              uint32_t class_type_index,
                                               bool expect_static) {
+  // Check that idx does not overflow the field_ids table.
   if (!CheckIndex(idx, header_->field_ids_size_, "class_data_item field_idx")) {
     return false;
   }
 
+  // Check that the field_id entry names the class the caller expects (class_type_index).
+  uint16_t my_class_index =
+      (reinterpret_cast<const DexFile::FieldId*>(begin_ + header_->field_ids_off_) + idx)->
+          class_idx_;
+  if (class_type_index != my_class_index) {
+    ErrorStringPrintf("Field's class index unexpected, %" PRIu16 " vs %" PRIu32,
+                      my_class_index,
+                      class_type_index);
+    return false;
+  }
+
+  // Check that it falls into the right class-data list.
   bool is_static = (access_flags & kAccStatic) != 0;
   if (UNLIKELY(is_static != expect_static)) {
     ErrorStringPrintf("Static/instance field not in expected list");
     return false;
   }
 
-  if (UNLIKELY((access_flags & ~kAccJavaFlagsMask) != 0)) {
-    ErrorStringPrintf("Bad class_data_item field access_flags %x", access_flags);
+  // Check the field access flags, also against the access flags of the declaring class.
+  std::string error_msg;
+  if (!CheckFieldAccessFlags(access_flags, class_access_flags, &error_msg)) {
+    ErrorStringPrintf("%s", error_msg.c_str());
     return false;
   }
 
   return true;
 }
 
-bool DexFileVerifier::CheckClassDataItemMethod(uint32_t idx, uint32_t access_flags,
+bool DexFileVerifier::CheckClassDataItemMethod(uint32_t idx,
+                                               uint32_t access_flags,
+                                               uint32_t class_access_flags,
+                                               uint32_t class_type_index,
                                                uint32_t code_offset,
-                                               std::unordered_set<uint32_t>& direct_method_indexes,
+                                               std::unordered_set<uint32_t>* direct_method_indexes,
                                                bool expect_direct) {
+  DCHECK(direct_method_indexes != nullptr);
+  // Check for overflow.
   if (!CheckIndex(idx, header_->method_ids_size_, "class_data_item method_idx")) {
     return false;
   }
 
-  bool is_direct = (access_flags & (kAccStatic | kAccPrivate | kAccConstructor)) != 0;
-  bool expect_code = (access_flags & (kAccNative | kAccAbstract)) == 0;
-  bool is_synchronized = (access_flags & kAccSynchronized) != 0;
-  bool allow_synchronized = (access_flags & kAccNative) != 0;
-
-  if (UNLIKELY(is_direct != expect_direct)) {
-    ErrorStringPrintf("Direct/virtual method not in expected list");
+  // Check that it's the right class.
+  uint16_t my_class_index =
+      (reinterpret_cast<const DexFile::MethodId*>(begin_ + header_->method_ids_off_) + idx)->
+          class_idx_;
+  if (class_type_index != my_class_index) {
+    ErrorStringPrintf("Method's class index unexpected, %" PRIu16 "vs %" PRIu16,
+                      my_class_index,
+                      class_type_index);
     return false;
   }
 
+  // Check that it's not defined as both direct and virtual.
   if (expect_direct) {
-    direct_method_indexes.insert(idx);
-  } else if (direct_method_indexes.find(idx) != direct_method_indexes.end()) {
+    direct_method_indexes->insert(idx);
+  } else if (direct_method_indexes->find(idx) != direct_method_indexes->end()) {
     ErrorStringPrintf("Found virtual method with same index as direct method: %d", idx);
     return false;
   }
 
-  constexpr uint32_t access_method_mask = kAccJavaFlagsMask | kAccConstructor |
-      kAccDeclaredSynchronized;
-  if (UNLIKELY(((access_flags & ~access_method_mask) != 0) ||
-               (is_synchronized && !allow_synchronized))) {
-    ErrorStringPrintf("Bad class_data_item method access_flags %x", access_flags);
-    return false;
-  }
-
-  if (UNLIKELY(expect_code && (code_offset == 0))) {
-    ErrorStringPrintf("Unexpected zero value for class_data_item method code_off with access "
-                      "flags %x", access_flags);
-    return false;
-  } else if (UNLIKELY(!expect_code && (code_offset != 0))) {
-    ErrorStringPrintf("Unexpected non-zero value %x for class_data_item method code_off"
-                      " with access flags %x", code_offset, access_flags);
+  // Check method access flags.
+  bool has_code = (code_offset != 0);
+  std::string error_msg;
+  if (!CheckMethodAccessFlags(idx,
+                              access_flags,
+                              class_access_flags,
+                              has_code,
+                              expect_direct,
+                              &error_msg)) {
+    ErrorStringPrintf("%s", error_msg.c_str());
     return false;
   }
 
@@ -689,60 +711,185 @@
   return true;
 }
 
+bool DexFileVerifier::FindClassFlags(uint32_t index,
+                                     bool is_field,
+                                     uint16_t* class_type_index,
+                                     uint32_t* class_access_flags) {
+  DCHECK(class_type_index != nullptr);
+  DCHECK(class_access_flags != nullptr);
+
+  // First check that the index is below the field_ids or method_ids size, respectively.
+  if (index >= (is_field ? header_->field_ids_size_ : header_->method_ids_size_)) {
+    return false;
+  }
+
+  // Next read the class type index from the FieldId or MethodId entry.
+  if (is_field) {
+    *class_type_index =
+        (reinterpret_cast<const DexFile::FieldId*>(begin_ + header_->field_ids_off_) + index)->
+            class_idx_;
+  } else {
+    *class_type_index =
+        (reinterpret_cast<const DexFile::MethodId*>(begin_ + header_->method_ids_off_) + index)->
+            class_idx_;
+  }
+
+  // Check that the type index is below the type_ids size.
+  if (*class_type_index >= header_->type_ids_size_) {
+    return false;
+  }
+
+  // Now search for the class def. This is basically a specialized version of the DexFile code, as
+  // we should not trust that this is a valid DexFile just yet.
+  const DexFile::ClassDef* class_def_begin =
+      reinterpret_cast<const DexFile::ClassDef*>(begin_ + header_->class_defs_off_);
+  for (size_t i = 0; i < header_->class_defs_size_; ++i) {
+    const DexFile::ClassDef* class_def = class_def_begin + i;
+    if (class_def->class_idx_ == *class_type_index) {
+      *class_access_flags = class_def->access_flags_;
+      return true;
+    }
+  }
+
+  // No class_def has that class index; *class_access_flags is left unwritten in this case.
+  return false;
+}
+
+bool DexFileVerifier::CheckOrderAndGetClassFlags(bool is_field,
+                                                 const char* type_descr,
+                                                 uint32_t curr_index,
+                                                 uint32_t prev_index,
+                                                 bool* have_class,
+                                                 uint16_t* class_type_index,
+                                                 uint32_t* class_access_flags) {
+  if (curr_index < prev_index) {
+    ErrorStringPrintf("out-of-order %s indexes %" PRIu32 " and %" PRIu32,
+                      type_descr,
+                      prev_index,
+                      curr_index);
+    return false;
+  }
+
+  if (!*have_class) {
+    *have_class = FindClassFlags(curr_index, is_field, class_type_index, class_access_flags);
+    if (!*have_class) {
+      // Should have really found one.
+      ErrorStringPrintf("could not find declaring class for %s index %" PRIu32,
+                        type_descr,
+                        curr_index);
+      return false;
+    }
+  }
+  return true;
+}
+
+template <bool kStatic>
+bool DexFileVerifier::CheckIntraClassDataItemFields(ClassDataItemIterator* it,
+                                                    bool* have_class,
+                                                    uint16_t* class_type_index,
+                                                    uint32_t* class_access_flags) {
+  DCHECK(it != nullptr);
+  // These calls use the raw access flags to check whether the whole dex field is valid.
+  uint32_t prev_index = 0;
+  for (; kStatic ? it->HasNextStaticField() : it->HasNextInstanceField(); it->Next()) {
+    uint32_t curr_index = it->GetMemberIndex();
+    if (!CheckOrderAndGetClassFlags(true,
+                                    kStatic ? "static field" : "instance field",
+                                    curr_index,
+                                    prev_index,
+                                    have_class,
+                                    class_type_index,
+                                    class_access_flags)) {
+      return false;
+    }
+    prev_index = curr_index;
+
+    if (!CheckClassDataItemField(curr_index,
+                                 it->GetRawMemberAccessFlags(),
+                                 *class_access_flags,
+                                 *class_type_index,
+                                 kStatic)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+template <bool kDirect>
+bool DexFileVerifier::CheckIntraClassDataItemMethods(
+    ClassDataItemIterator* it,
+    std::unordered_set<uint32_t>* direct_method_indexes,
+    bool* have_class,
+    uint16_t* class_type_index,
+    uint32_t* class_access_flags) {
+  uint32_t prev_index = 0;
+  for (; kDirect ? it->HasNextDirectMethod() : it->HasNextVirtualMethod(); it->Next()) {
+    uint32_t curr_index = it->GetMemberIndex();
+    if (!CheckOrderAndGetClassFlags(false,
+                                    kDirect ? "direct method" : "virtual method",
+                                    curr_index,
+                                    prev_index,
+                                    have_class,
+                                    class_type_index,
+                                    class_access_flags)) {
+      return false;
+    }
+    prev_index = curr_index;
+
+    if (!CheckClassDataItemMethod(curr_index,
+                                  it->GetRawMemberAccessFlags(),
+                                  *class_access_flags,
+                                  *class_type_index,
+                                  it->GetMethodCodeItemOffset(),
+                                  direct_method_indexes,
+                                  kDirect)) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
 bool DexFileVerifier::CheckIntraClassDataItem() {
   ClassDataItemIterator it(*dex_file_, ptr_);
   std::unordered_set<uint32_t> direct_method_indexes;
 
-  // These calls use the raw access flags to check whether the whole dex field is valid.
-  uint32_t prev_index = 0;
-  for (; it.HasNextStaticField(); it.Next()) {
-    uint32_t curr_index = it.GetMemberIndex();
-    if (curr_index < prev_index) {
-      ErrorStringPrintf("out-of-order static field indexes %d and %d", prev_index, curr_index);
-      return false;
-    }
-    prev_index = curr_index;
-    if (!CheckClassDataItemField(curr_index, it.GetRawMemberAccessFlags(), true)) {
-      return false;
-    }
+  // This code is complicated by the fact that we don't directly know which class this belongs to.
+  // So we need to explicitly search with the first item we find (either field or method), and then,
+  // as the lookup is expensive, cache the result.
+  bool have_class = false;
+  uint16_t class_type_index;
+  uint32_t class_access_flags;
+
+  // Check fields.
+  if (!CheckIntraClassDataItemFields<true>(&it,
+                                           &have_class,
+                                           &class_type_index,
+                                           &class_access_flags)) {
+    return false;
   }
-  prev_index = 0;
-  for (; it.HasNextInstanceField(); it.Next()) {
-    uint32_t curr_index = it.GetMemberIndex();
-    if (curr_index < prev_index) {
-      ErrorStringPrintf("out-of-order instance field indexes %d and %d", prev_index, curr_index);
-      return false;
-    }
-    prev_index = curr_index;
-    if (!CheckClassDataItemField(curr_index, it.GetRawMemberAccessFlags(), false)) {
-      return false;
-    }
+  if (!CheckIntraClassDataItemFields<false>(&it,
+                                            &have_class,
+                                            &class_type_index,
+                                            &class_access_flags)) {
+    return false;
   }
-  prev_index = 0;
-  for (; it.HasNextDirectMethod(); it.Next()) {
-    uint32_t curr_index = it.GetMemberIndex();
-    if (curr_index < prev_index) {
-      ErrorStringPrintf("out-of-order direct method indexes %d and %d", prev_index, curr_index);
-      return false;
-    }
-    prev_index = curr_index;
-    if (!CheckClassDataItemMethod(curr_index, it.GetRawMemberAccessFlags(),
-        it.GetMethodCodeItemOffset(), direct_method_indexes, true)) {
-      return false;
-    }
+
+  // Check methods.
+  if (!CheckIntraClassDataItemMethods<true>(&it,
+                                            &direct_method_indexes,
+                                            &have_class,
+                                            &class_type_index,
+                                            &class_access_flags)) {
+    return false;
   }
-  prev_index = 0;
-  for (; it.HasNextVirtualMethod(); it.Next()) {
-    uint32_t curr_index = it.GetMemberIndex();
-    if (curr_index < prev_index) {
-      ErrorStringPrintf("out-of-order virtual method indexes %d and %d", prev_index, curr_index);
-      return false;
-    }
-    prev_index = curr_index;
-    if (!CheckClassDataItemMethod(curr_index, it.GetRawMemberAccessFlags(),
-        it.GetMethodCodeItemOffset(), direct_method_indexes, false)) {
-      return false;
-    }
+  if (!CheckIntraClassDataItemMethods<false>(&it,
+                                             &direct_method_indexes,
+                                             &have_class,
+                                             &class_type_index,
+                                             &class_access_flags)) {
+    return false;
   }
 
   ptr_ = it.EndDataPointer();
@@ -2149,4 +2296,259 @@
   va_end(ap);
 }
 
+// Returns true if at most one of kAccPublic, kAccProtected and kAccPrivate is set in flags.
+static bool CheckAtMostOneOfPublicProtectedPrivate(uint32_t flags) {
+  size_t count = (((flags & kAccPublic) == 0) ? 0 : 1) +
+                 (((flags & kAccProtected) == 0) ? 0 : 1) +
+                 (((flags & kAccPrivate) == 0) ? 0 : 1);
+  return count <= 1;
+}
+
+bool DexFileVerifier::CheckFieldAccessFlags(uint32_t field_access_flags,
+                                            uint32_t class_access_flags,
+                                            std::string* error_msg) {
+  // Reject any flag outside of kAccJavaFlagsMask.
+  if ((field_access_flags & ~kAccJavaFlagsMask) != 0) {
+    *error_msg = StringPrintf("Bad class_data_item field access_flags %x", field_access_flags);
+    return false;
+  }
+
+  // Flags allowed on fields, in general. Other lower-16-bit flags are to be ignored.
+  constexpr uint32_t kFieldAccessFlags = kAccPublic |
+                                         kAccPrivate |
+                                         kAccProtected |
+                                         kAccStatic |
+                                         kAccFinal |
+                                         kAccVolatile |
+                                         kAccTransient |
+                                         kAccSynthetic |
+                                         kAccEnum;
+
+  // Fields may have only one of public/protected/private.
+  if (!CheckAtMostOneOfPublicProtectedPrivate(field_access_flags)) {
+    *error_msg = StringPrintf("Field may have only one of public/protected/private, %x",
+                              field_access_flags);
+    return false;
+  }
+
+  // Interfaces have a pretty restricted list.
+  if ((class_access_flags & kAccInterface) != 0) {
+    // Interface fields must be public final static.
+    constexpr uint32_t kPublicFinalStatic = kAccPublic | kAccFinal | kAccStatic;
+    if ((field_access_flags & kPublicFinalStatic) != kPublicFinalStatic) {
+      *error_msg = StringPrintf("Interface field is not public final static: %x",
+                                field_access_flags);
+      return false;
+    }
+    // Interface fields may be synthetic, but may not have other flags.
+    constexpr uint32_t kDisallowed = ~(kPublicFinalStatic | kAccSynthetic);
+    if ((field_access_flags & kFieldAccessFlags & kDisallowed) != 0) {
+      *error_msg = StringPrintf("Interface field has disallowed flag: %x", field_access_flags);
+      return false;
+    }
+    return true;
+  }
+
+  // Volatile fields may not be final.
+  constexpr uint32_t kVolatileFinal = kAccVolatile | kAccFinal;
+  if ((field_access_flags & kVolatileFinal) == kVolatileFinal) {
+    *error_msg = "Fields may not be volatile and final";
+    return false;
+  }
+
+  return true;
+}
+
+// Try to find the name of the method with the given index. We do not want to rely on DexFile
+// infrastructure at this point, so do it all by hand. begin and header correspond to begin_ and
+// header_ of the DexFileVerifier. On success (return value true), str points just past the ULEB128
+// value that starts the string data; the characters are not validated. Else error_msg is set.
+static bool FindMethodName(uint32_t method_index,
+                           const uint8_t* begin,
+                           const DexFile::Header* header,
+                           const char** str,
+                           std::string* error_msg) {
+  if (method_index >= header->method_ids_size_) {
+    *error_msg = "Method index not available for method flags verification";
+    return false;
+  }
+  uint32_t string_idx =
+      (reinterpret_cast<const DexFile::MethodId*>(begin + header->method_ids_off_) +
+          method_index)->name_idx_;
+  if (string_idx >= header->string_ids_size_) {
+    *error_msg = "String index not available for method flags verification";
+    return false;
+  }
+  uint32_t string_off =
+      (reinterpret_cast<const DexFile::StringId*>(begin + header->string_ids_off_) + string_idx)->
+          string_data_off_;
+  if (string_off >= header->file_size_) {
+    *error_msg = "String offset out of bounds for method flags verification";
+    return false;
+  }
+  const uint8_t* str_data_ptr = begin + string_off;
+  DecodeUnsignedLeb128(&str_data_ptr);
+  *str = reinterpret_cast<const char*>(str_data_ptr);
+  return true;
+}
+
+// Checks a method's access flags against its name, its list, its code and the declaring class.
+bool DexFileVerifier::CheckMethodAccessFlags(uint32_t method_index,
+                                             uint32_t method_access_flags,
+                                             uint32_t class_access_flags,
+                                             bool has_code,
+                                             bool expect_direct,
+                                             std::string* error_msg) {
+  // Generally sort out >16-bit flags, except dex knows Constructor and DeclaredSynchronized.
+  constexpr uint32_t kAllMethodFlags =
+      kAccJavaFlagsMask | kAccConstructor | kAccDeclaredSynchronized;
+  if ((method_access_flags & ~kAllMethodFlags) != 0) {
+    *error_msg = StringPrintf("Bad class_data_item method access_flags %x", method_access_flags);
+    return false;
+  }
+
+  // Flags allowed on methods, in general. Other lower-16-bit flags are to be ignored.
+  constexpr uint32_t kMethodAccessFlags = kAccPublic |
+                                          kAccPrivate |
+                                          kAccProtected |
+                                          kAccStatic |
+                                          kAccFinal |
+                                          kAccSynthetic |
+                                          kAccSynchronized |
+                                          kAccBridge |
+                                          kAccVarargs |
+                                          kAccNative |
+                                          kAccAbstract |
+                                          kAccStrict;
+
+  // Methods may have only one of public/protected/private.
+  if (!CheckAtMostOneOfPublicProtectedPrivate(method_access_flags)) {
+    *error_msg = StringPrintf("Method may have only one of public/protected/private, %x",
+                              method_access_flags);
+    return false;
+  }
+
+  // Try to find the name, to check for constructor properties.
+  const char* str;
+  if (!FindMethodName(method_index, begin_, header_, &str, error_msg)) {
+    return false;
+  }
+  // Arrays, not pointers: sizeof() must yield the name length including the NUL terminator, so
+  // that the bounds checks below cover exactly the bytes strcmp may read.
+  constexpr char kInitName[] = "<init>";
+  constexpr char kClinitName[] = "<clinit>";
+  size_t str_offset = (reinterpret_cast<const uint8_t*>(str) - begin_);
+  // Avoid unsigned underflow if the name pointer lies at or past the end of the file.
+  size_t remaining = (str_offset < header_->file_size_) ? header_->file_size_ - str_offset : 0;
+  bool is_init_by_name = (remaining >= sizeof(kInitName)) && (strcmp(kInitName, str) == 0);
+  bool is_clinit_by_name =
+      (remaining >= sizeof(kClinitName)) && (strcmp(kClinitName, str) == 0);
+  bool is_constructor = is_init_by_name || is_clinit_by_name;
+
+  // Only methods named "<clinit>" or "<init>" may be marked constructor. Note: we cannot enforce
+  // the reverse for backwards compatibility reasons.
+  if (((method_access_flags & kAccConstructor) != 0) && !is_constructor) {
+    *error_msg = StringPrintf("Method %" PRIu32 " is marked constructor, but doesn't match name",
+                              method_index);
+    return false;
+  }
+  // Check that the static constructor (= static initializer) is named "<clinit>" and that the
+  // instance constructor is called "<init>".
+  if (is_constructor) {
+    bool is_static = (method_access_flags & kAccStatic) != 0;
+    if (is_static ^ is_clinit_by_name) {
+      *error_msg = StringPrintf("Constructor %" PRIu32 " is not flagged correctly wrt/ static.",
+                                method_index);
+      return false;
+    }
+  }
+  // Check that static and private methods, as well as constructors, are in the direct methods list,
+  // and other methods in the virtual methods list.
+  bool is_direct = (method_access_flags & (kAccStatic | kAccPrivate)) != 0 || is_constructor;
+  if (is_direct != expect_direct) {
+    *error_msg = StringPrintf("Direct/virtual method %" PRIu32 " not in expected list %d",
+                              method_index,
+                              expect_direct);
+    return false;
+  }
+
+
+  // From here on out it is easier to mask out the bits we're supposed to ignore.
+  method_access_flags &= kMethodAccessFlags;
+
+  // If there aren't any instructions, make sure that's expected.
+  if (!has_code) {
+    // Only native or abstract methods may not have code.
+    if ((method_access_flags & (kAccNative | kAccAbstract)) == 0) {
+      *error_msg = StringPrintf("Method %" PRIu32 " has no code, but is not marked native or "
+                                "abstract",
+                                method_index);
+      return false;
+    }
+    // Constructors must always have code.
+    if (is_constructor) {
+      *error_msg = StringPrintf("Constructor %u must not be abstract or native", method_index);
+      return false;
+    }
+    if ((method_access_flags & kAccAbstract) != 0) {
+      // Abstract methods are not allowed to have the following flags.
+      constexpr uint32_t kForbidden =
+          kAccPrivate | kAccStatic | kAccFinal | kAccNative | kAccStrict | kAccSynchronized;
+      if ((method_access_flags & kForbidden) != 0) {
+        *error_msg = StringPrintf("Abstract method %" PRIu32 " has disallowed access flags %x",
+                                  method_index,
+                                  method_access_flags);
+        return false;
+      }
+      // Abstract methods must be in an abstract class or interface.
+      if ((class_access_flags & (kAccInterface | kAccAbstract)) == 0) {
+        *error_msg = StringPrintf("Method %" PRIu32 " is abstract, but the declaring class "
+                                  "is neither abstract nor an interface", method_index);
+        return false;
+      }
+    }
+    // Interfaces are special.
+    if ((class_access_flags & kAccInterface) != 0) {
+      // Interface methods must be public and abstract.
+      if ((method_access_flags & (kAccPublic | kAccAbstract)) != (kAccPublic | kAccAbstract)) {
+        *error_msg = StringPrintf("Interface method %" PRIu32 " is not public and abstract",
+                                  method_index);
+        return false;
+      }
+      // At this point, we know the method is public and abstract. This means that all the checks
+      // for invalid combinations above apply. In addition, interface methods must not be
+      // protected. This is caught by the check for only-one-of-public-protected-private.
+    }
+    return true;
+  }
+
+  // When there's code, the method must not be native or abstract.
+  if ((method_access_flags & (kAccNative | kAccAbstract)) != 0) {
+    *error_msg = StringPrintf("Method %" PRIu32 " has code, but is marked native or abstract",
+                              method_index);
+    return false;
+  }
+
+  // Only the static initializer may have code in an interface.
+  if (((class_access_flags & kAccInterface) != 0) && !is_clinit_by_name) {
+    *error_msg = StringPrintf("Non-clinit interface method %" PRIu32 " should not have code",
+                              method_index);
+    return false;
+  }
+
+  // Instance constructors may only carry the flags in kInitAllowed (so e.g. not synchronized).
+  if (is_init_by_name) {
+    static constexpr uint32_t kInitAllowed =
+        kAccPrivate | kAccProtected | kAccPublic | kAccStrict | kAccVarargs | kAccSynthetic;
+    if ((method_access_flags & ~kInitAllowed) != 0) {
+      *error_msg = StringPrintf("Constructor %" PRIu32 " flagged inappropriately %x",
+                                method_index,
+                                method_access_flags);
+      return false;
+    }
+  }
+
+  return true;
+}
+
 }  // namespace art