AAPT2: Ensure style strings are always first in StringPool

Move the styled strings to a separate section of the StringPool, placed
before all plain strings, so that sorting can never disturb the order of
the styles.

Bug: 63570514
Test: make aapt2_tests
Change-Id: Id2ce1355b92be1bb31ce0daa7e54ae9b5b6c2ffe
diff --git a/tools/aapt2/StringPool.cpp b/tools/aapt2/StringPool.cpp
index 57da5f0..705b1ab 100644
--- a/tools/aapt2/StringPool.cpp
+++ b/tools/aapt2/StringPool.cpp
@@ -27,7 +27,7 @@
 #include "util/BigBuffer.h"
 #include "util/Util.h"
 
-using android::StringPiece;
+using ::android::StringPiece;
 
 namespace aapt {
 
@@ -75,9 +75,14 @@
   return &entry_->value;
 }
 
-const std::string& StringPool::Ref::operator*() const { return entry_->value; }
+const std::string& StringPool::Ref::operator*() const {
+  return entry_->value;
+}
 
-size_t StringPool::Ref::index() const { return entry_->index; }
+size_t StringPool::Ref::index() const {
+  // Account for the styles, which *always* come first.
+  return entry_->pool_->styles_.size() + entry_->index_;
+}
 
 const StringPool::Context& StringPool::Ref::GetContext() const {
   return entry_->context;
@@ -104,8 +109,7 @@
   }
 }
 
-StringPool::StyleRef& StringPool::StyleRef::operator=(
-    const StringPool::StyleRef& rhs) {
+StringPool::StyleRef& StringPool::StyleRef::operator=(const StringPool::StyleRef& rhs) {
   if (rhs.entry_ != nullptr) {
     rhs.entry_->ref_++;
   }
@@ -118,7 +122,7 @@
 }
 
 bool StringPool::StyleRef::operator==(const StyleRef& rhs) const {
-  if (entry_->str != rhs.entry_->str) {
+  if (entry_->value != rhs.entry_->value) {
     return false;
   }
 
@@ -137,7 +141,9 @@
   return true;
 }
 
-bool StringPool::StyleRef::operator!=(const StyleRef& rhs) const { return !operator==(rhs); }
+bool StringPool::StyleRef::operator!=(const StyleRef& rhs) const {
+  return !operator==(rhs);
+}
 
 const StringPool::StyleEntry* StringPool::StyleRef::operator->() const {
   return entry_;
@@ -147,23 +153,24 @@
   return *entry_;
 }
 
-size_t StringPool::StyleRef::index() const { return entry_->str.index(); }
+size_t StringPool::StyleRef::index() const {
+  return entry_->index_;
+}
 
 const StringPool::Context& StringPool::StyleRef::GetContext() const {
-  return entry_->str.GetContext();
+  return entry_->context;
 }
 
 StringPool::Ref StringPool::MakeRef(const StringPiece& str) {
   return MakeRefImpl(str, Context{}, true);
 }
 
-StringPool::Ref StringPool::MakeRef(const StringPiece& str,
-                                    const Context& context) {
+StringPool::Ref StringPool::MakeRef(const StringPiece& str, const Context& context) {
   return MakeRefImpl(str, context, true);
 }
 
-StringPool::Ref StringPool::MakeRefImpl(const StringPiece& str,
-                                        const Context& context, bool unique) {
+StringPool::Ref StringPool::MakeRefImpl(const StringPiece& str, const Context& context,
+                                        bool unique) {
   if (unique) {
     auto iter = indexed_strings_.find(str);
     if (iter != std::end(indexed_strings_)) {
@@ -171,82 +178,87 @@
     }
   }
 
-  Entry* entry = new Entry();
+  std::unique_ptr<Entry> entry(new Entry());
   entry->value = str.to_string();
   entry->context = context;
-  entry->index = strings_.size();
+  entry->index_ = strings_.size();
   entry->ref_ = 0;
-  strings_.emplace_back(entry);
-  indexed_strings_.insert(std::make_pair(StringPiece(entry->value), entry));
-  return Ref(entry);
+  entry->pool_ = this;
+
+  Entry* borrow = entry.get();
+  strings_.emplace_back(std::move(entry));
+  indexed_strings_.insert(std::make_pair(StringPiece(borrow->value), borrow));
+  return Ref(borrow);
 }
 
 StringPool::StyleRef StringPool::MakeRef(const StyleString& str) {
   return MakeRef(str, Context{});
 }
 
-StringPool::StyleRef StringPool::MakeRef(const StyleString& str,
-                                         const Context& context) {
-  Entry* entry = new Entry();
+StringPool::StyleRef StringPool::MakeRef(const StyleString& str, const Context& context) {
+  std::unique_ptr<StyleEntry> entry(new StyleEntry());
   entry->value = str.str;
   entry->context = context;
-  entry->index = strings_.size();
+  entry->index_ = styles_.size();
   entry->ref_ = 0;
-  strings_.emplace_back(entry);
-  indexed_strings_.insert(std::make_pair(StringPiece(entry->value), entry));
-
-  StyleEntry* style_entry = new StyleEntry();
-  style_entry->str = Ref(entry);
   for (const aapt::Span& span : str.spans) {
-    style_entry->spans.emplace_back(
-        Span{MakeRef(span.name), span.first_char, span.last_char});
+    entry->spans.emplace_back(Span{MakeRef(span.name), span.first_char, span.last_char});
   }
-  style_entry->ref_ = 0;
-  styles_.emplace_back(style_entry);
-  return StyleRef(style_entry);
+
+  StyleEntry* borrow = entry.get();
+  styles_.emplace_back(std::move(entry));
+  return StyleRef(borrow);
 }
 
 StringPool::StyleRef StringPool::MakeRef(const StyleRef& ref) {
-  Entry* entry = new Entry();
-  entry->value = *ref.entry_->str;
-  entry->context = ref.entry_->str.entry_->context;
-  entry->index = strings_.size();
+  std::unique_ptr<StyleEntry> entry(new StyleEntry());
+  entry->value = ref.entry_->value;
+  entry->context = ref.entry_->context;
+  entry->index_ = styles_.size();
   entry->ref_ = 0;
-  strings_.emplace_back(entry);
-  indexed_strings_.insert(std::make_pair(StringPiece(entry->value), entry));
-
-  StyleEntry* style_entry = new StyleEntry();
-  style_entry->str = Ref(entry);
   for (const Span& span : ref.entry_->spans) {
-    style_entry->spans.emplace_back(
-        Span{MakeRef(*span.name), span.first_char, span.last_char});
+    entry->spans.emplace_back(Span{MakeRef(*span.name), span.first_char, span.last_char});
   }
-  style_entry->ref_ = 0;
-  styles_.emplace_back(style_entry);
-  return StyleRef(style_entry);
+
+  StyleEntry* borrow = entry.get();
+  styles_.emplace_back(std::move(entry));
+  return StyleRef(borrow);
+}
+
+void StringPool::ReAssignIndices() {
+  // Assign the style indices.
+  const size_t style_len = styles_.size();
+  for (size_t index = 0; index < style_len; index++) {
+    styles_[index]->index_ = index;
+  }
+
+  // Assign the string indices.
+  const size_t string_len = strings_.size();
+  for (size_t index = 0; index < string_len; index++) {
+    strings_[index]->index_ = index;
+  }
 }
 
 void StringPool::Merge(StringPool&& pool) {
-  indexed_strings_.insert(pool.indexed_strings_.begin(),
-                          pool.indexed_strings_.end());
-  pool.indexed_strings_.clear();
-  std::move(pool.strings_.begin(), pool.strings_.end(),
-            std::back_inserter(strings_));
-  pool.strings_.clear();
-  std::move(pool.styles_.begin(), pool.styles_.end(),
-            std::back_inserter(styles_));
-  pool.styles_.clear();
-
-  // Assign the indices.
-  const size_t len = strings_.size();
-  for (size_t index = 0; index < len; index++) {
-    strings_[index]->index = index;
+  // First, change the owning pool for the incoming strings.
+  for (std::unique_ptr<Entry>& entry : pool.strings_) {
+    entry->pool_ = this;
   }
+
+  // Now move the styles, strings, and indices over.
+  std::move(pool.styles_.begin(), pool.styles_.end(), std::back_inserter(styles_));
+  pool.styles_.clear();
+  std::move(pool.strings_.begin(), pool.strings_.end(), std::back_inserter(strings_));
+  pool.strings_.clear();
+  indexed_strings_.insert(pool.indexed_strings_.begin(), pool.indexed_strings_.end());
+  pool.indexed_strings_.clear();
+
+  ReAssignIndices();
 }
 
-void StringPool::HintWillAdd(size_t stringCount, size_t styleCount) {
-  strings_.reserve(strings_.size() + stringCount);
-  styles_.reserve(styles_.size() + styleCount);
+void StringPool::HintWillAdd(size_t string_count, size_t style_count) {
+  strings_.reserve(strings_.size() + string_count);
+  styles_.reserve(styles_.size() + style_count);
 }
 
 void StringPool::Prune() {
@@ -262,47 +274,42 @@
 
   auto end_iter2 =
       std::remove_if(strings_.begin(), strings_.end(),
-                     [](const std::unique_ptr<Entry>& entry) -> bool {
-                       return entry->ref_ <= 0;
-                     });
+                     [](const std::unique_ptr<Entry>& entry) -> bool { return entry->ref_ <= 0; });
+  auto end_iter3 = std::remove_if(
+      styles_.begin(), styles_.end(),
+      [](const std::unique_ptr<StyleEntry>& entry) -> bool { return entry->ref_ <= 0; });
 
-  auto end_iter3 =
-      std::remove_if(styles_.begin(), styles_.end(),
-                     [](const std::unique_ptr<StyleEntry>& entry) -> bool {
-                       return entry->ref_ <= 0;
-                     });
-
-  // Remove the entries at the end or else we'll be accessing
-  // a deleted string from the StyleEntry.
+  // Remove the entries at the end or else we'll be accessing a deleted string from the StyleEntry.
   strings_.erase(end_iter2, strings_.end());
   styles_.erase(end_iter3, styles_.end());
 
-  // Reassign the indices.
-  const size_t len = strings_.size();
-  for (size_t index = 0; index < len; index++) {
-    strings_[index]->index = index;
+  ReAssignIndices();
+}
+
+template <typename E>
+static void SortEntries(
+    std::vector<std::unique_ptr<E>>& entries,
+    const std::function<int(const StringPool::Context&, const StringPool::Context&)>& cmp) {
+  using UEntry = std::unique_ptr<E>;
+
+  if (cmp != nullptr) {
+    std::sort(entries.begin(), entries.end(), [&cmp](const UEntry& a, const UEntry& b) -> bool {
+      int r = cmp(a->context, b->context);
+      if (r == 0) {
+        r = a->value.compare(b->value);
+      }
+      return r < 0;
+    });
+  } else {
+    std::sort(entries.begin(), entries.end(),
+              [](const UEntry& a, const UEntry& b) -> bool { return a->value < b->value; });
   }
 }
 
-void StringPool::Sort(
-    const std::function<bool(const Entry&, const Entry&)>& cmp) {
-  std::sort(
-      strings_.begin(), strings_.end(),
-      [&cmp](const std::unique_ptr<Entry>& a,
-             const std::unique_ptr<Entry>& b) -> bool { return cmp(*a, *b); });
-
-  // Assign the indices.
-  const size_t len = strings_.size();
-  for (size_t index = 0; index < len; index++) {
-    strings_[index]->index = index;
-  }
-
-  // Reorder the styles.
-  std::sort(styles_.begin(), styles_.end(),
-            [](const std::unique_ptr<StyleEntry>& lhs,
-               const std::unique_ptr<StyleEntry>& rhs) -> bool {
-              return lhs->str.index() < rhs->str.index();
-            });
+void StringPool::Sort(const std::function<int(const Context&, const Context&)>& cmp) {
+  SortEntries(styles_, cmp);
+  SortEntries(strings_, cmp);
+  ReAssignIndices();
 }
 
 template <typename T>
@@ -327,60 +334,31 @@
   return length > kMaxSize ? 2 : 1;
 }
 
-bool StringPool::Flatten(BigBuffer* out, const StringPool& pool, bool utf8) {
-  const size_t start_index = out->size();
-  android::ResStringPool_header* header =
-      out->NextBlock<android::ResStringPool_header>();
-  header->header.type = android::RES_STRING_POOL_TYPE;
-  header->header.headerSize = sizeof(*header);
-  header->stringCount = pool.size();
+static void EncodeString(const std::string& str, const bool utf8, BigBuffer* out) {
   if (utf8) {
-    header->flags |= android::ResStringPool_header::UTF8_FLAG;
-  }
+    const std::string& encoded = str;
+    const ssize_t utf16_length =
+        utf8_to_utf16_length(reinterpret_cast<const uint8_t*>(str.data()), str.size());
+    CHECK(utf16_length >= 0);
 
-  uint32_t* indices =
-      pool.size() != 0 ? out->NextBlock<uint32_t>(pool.size()) : nullptr;
+    const size_t total_size = EncodedLengthUnits<char>(utf16_length) +
+                              EncodedLengthUnits<char>(encoded.length()) + encoded.size() + 1;
 
-  uint32_t* style_indices = nullptr;
-  if (!pool.styles_.empty()) {
-    header->styleCount = pool.styles_.back()->str.index() + 1;
-    style_indices = out->NextBlock<uint32_t>(header->styleCount);
-  }
+    char* data = out->NextBlock<char>(total_size);
 
-  const size_t before_strings_index = out->size();
-  header->stringsStart = before_strings_index - start_index;
+    // First encode the UTF16 string length.
+    data = EncodeLength(data, utf16_length);
 
-  for (const auto& entry : pool) {
-    *indices = out->size() - before_strings_index;
-    indices++;
-
-    if (utf8) {
-      const std::string& encoded = entry->value;
-      const ssize_t utf16_length = utf8_to_utf16_length(
-          reinterpret_cast<const uint8_t*>(entry->value.data()),
-          entry->value.size());
-      CHECK(utf16_length >= 0);
-
-      const size_t total_size = EncodedLengthUnits<char>(utf16_length) +
-                                EncodedLengthUnits<char>(encoded.length()) +
-                                encoded.size() + 1;
-
-      char* data = out->NextBlock<char>(total_size);
-
-      // First encode the UTF16 string length.
-      data = EncodeLength(data, utf16_length);
-
-      // Now encode the size of the real UTF8 string.
-      data = EncodeLength(data, encoded.length());
-      strncpy(data, encoded.data(), encoded.size());
+    // Now encode the size of the real UTF8 string.
+    data = EncodeLength(data, encoded.length());
+    strncpy(data, encoded.data(), encoded.size());
 
     } else {
-      const std::u16string encoded = util::Utf8ToUtf16(entry->value);
+      const std::u16string encoded = util::Utf8ToUtf16(str);
       const ssize_t utf16_length = encoded.size();
 
       // Total number of 16-bit words to write.
-      const size_t total_size =
-          EncodedLengthUnits<char16_t>(utf16_length) + encoded.size() + 1;
+      const size_t total_size = EncodedLengthUnits<char16_t>(utf16_length) + encoded.size() + 1;
 
       char16_t* data = out->NextBlock<char16_t>(total_size);
 
@@ -395,31 +373,55 @@
       // The null-terminating character is already here due to the block of data
       // being set to 0s on allocation.
     }
+}
+
+bool StringPool::Flatten(BigBuffer* out, const StringPool& pool, bool utf8) {
+  const size_t start_index = out->size();
+  android::ResStringPool_header* header = out->NextBlock<android::ResStringPool_header>();
+  header->header.type = util::HostToDevice16(android::RES_STRING_POOL_TYPE);
+  header->header.headerSize = util::HostToDevice16(sizeof(*header));
+  header->stringCount = util::HostToDevice32(pool.size());
+  header->styleCount = util::HostToDevice32(pool.styles_.size());
+  if (utf8) {
+    header->flags |= android::ResStringPool_header::UTF8_FLAG;
+  }
+
+  uint32_t* indices = pool.size() != 0 ? out->NextBlock<uint32_t>(pool.size()) : nullptr;
+  uint32_t* style_indices =
+      pool.styles_.size() != 0 ? out->NextBlock<uint32_t>(pool.styles_.size()) : nullptr;
+
+  const size_t before_strings_index = out->size();
+  header->stringsStart = before_strings_index - start_index;
+
+  // Styles always come first.
+  for (const std::unique_ptr<StyleEntry>& entry : pool.styles_) {
+    *indices++ = out->size() - before_strings_index;
+    EncodeString(entry->value, utf8, out);
+  }
+
+  for (const std::unique_ptr<Entry>& entry : pool.strings_) {
+    *indices++ = out->size() - before_strings_index;
+    EncodeString(entry->value, utf8, out);
   }
 
   out->Align4();
 
-  if (!pool.styles_.empty()) {
+  if (style_indices != nullptr) {
     const size_t before_styles_index = out->size();
-    header->stylesStart = before_styles_index - start_index;
+    header->stylesStart = util::HostToDevice32(before_styles_index - start_index);
 
-    size_t current_index = 0;
-    for (const auto& entry : pool.styles_) {
-      while (entry->str.index() > current_index) {
-        style_indices[current_index++] = out->size() - before_styles_index;
+    for (const std::unique_ptr<StyleEntry>& entry : pool.styles_) {
+      *style_indices++ = out->size() - before_styles_index;
 
-        uint32_t* span_offset = out->NextBlock<uint32_t>();
-        *span_offset = android::ResStringPool_span::END;
-      }
-      style_indices[current_index++] = out->size() - before_styles_index;
-
-      android::ResStringPool_span* span =
-          out->NextBlock<android::ResStringPool_span>(entry->spans.size());
-      for (const auto& s : entry->spans) {
-        span->name.index = s.name.index();
-        span->firstChar = s.first_char;
-        span->lastChar = s.last_char;
-        span++;
+      if (!entry->spans.empty()) {
+        android::ResStringPool_span* span =
+            out->NextBlock<android::ResStringPool_span>(entry->spans.size());
+        for (const Span& s : entry->spans) {
+          span->name.index = util::HostToDevice32(s.name.index());
+          span->firstChar = util::HostToDevice32(s.first_char);
+          span->lastChar = util::HostToDevice32(s.last_char);
+          span++;
+        }
       }
 
       uint32_t* spanEnd = out->NextBlock<uint32_t>();
@@ -436,7 +438,7 @@
     memset(padding, 0xff, padding_length);
     out->Align4();
   }
-  header->header.size = out->size() - start_index;
+  header->header.size = util::HostToDevice32(out->size() - start_index);
   return true;
 }