Check the size of the strings in the StringPool before flattening.

Test: Tested for normal functionality when string does not exceed
maximum length and tests for detection of string that is too long for
UTF-8.
Bug: b/74176037

Change-Id: Ic71d3671a069e7012e8ca107e79e071499eebbf6
(cherry picked from commit a15c2a8957b9883cb293fdacaeabd7f2e037a0a5)
diff --git a/tools/aapt2/LoadedApk.cpp b/tools/aapt2/LoadedApk.cpp
index 4a4260d..1dd46ba 100644
--- a/tools/aapt2/LoadedApk.cpp
+++ b/tools/aapt2/LoadedApk.cpp
@@ -225,7 +225,7 @@
       }
     } else if (format_ == ApkFormat::kProto && path == kProtoResourceTablePath) {
       pb::ResourceTable pb_table;
-      SerializeTableToPb(*split_table, &pb_table);
+      SerializeTableToPb(*split_table, &pb_table, context->GetDiagnostics());
       if (!io::CopyProtoToArchive(context,
                                   &pb_table,
                                   path,
diff --git a/tools/aapt2/StringPool.cpp b/tools/aapt2/StringPool.cpp
index 3a1a18c..b0ce9e1 100644
--- a/tools/aapt2/StringPool.cpp
+++ b/tools/aapt2/StringPool.cpp
@@ -332,6 +332,25 @@
   return data;
 }
 
+/**
+ * Returns the maximum string length that can be encoded in 2 units of T.
+ * One bit is reserved, so the limit is 2^(2 * bits(T) - 1) - 1:
+ *    EncodeLengthMax<char> -> maximum unit length of 0x7FFF
+ *    EncodeLengthMax<char16_t> -> maximum unit length of 0x7FFFFFFF
+ **/
+template <typename T>
+static size_t EncodeLengthMax() {
+  static_assert(std::is_integral<T>::value, "wat.");
+  // size_t{1}: an int `1 << 31` (T = char16_t) goes negative and disables the check.
+  constexpr size_t kMask = size_t{1} << ((sizeof(T) * 8 * 2) - 1);
+  constexpr size_t max = kMask - 1;
+  return max;
+}
+
+/**
+ * Returns the number of units (1 or 2) needed to encode the string length
+ * before writing the string.
+ */
 template <typename T>
 static size_t EncodedLengthUnits(size_t length) {
   static_assert(std::is_integral<T>::value, "wat.");
@@ -341,15 +360,30 @@
   return length > kMaxSize ? 2 : 1;
 }
 
-static void EncodeString(const std::string& str, const bool utf8, BigBuffer* out) {
+const std::string kStringTooLarge = "STRING_TOO_LARGE";
+
+static bool EncodeString(const std::string& str, const bool utf8, BigBuffer* out,
+                         IDiagnostics* diag) {
   if (utf8) {
     const std::string& encoded = str;
-    const ssize_t utf16_length =
-        utf8_to_utf16_length(reinterpret_cast<const uint8_t*>(str.data()), str.size());
+    const ssize_t utf16_length = utf8_to_utf16_length(
+        reinterpret_cast<const uint8_t*>(encoded.data()), encoded.size());
     CHECK(utf16_length >= 0);
 
-    const size_t total_size = EncodedLengthUnits<char>(utf16_length) +
-                              EncodedLengthUnits<char>(encoded.length()) + encoded.size() + 1;
+    // Both the UTF-16 length and the UTF-8 byte length are written as char
+    // units, so each must fit in what two chars can encode.
+    if (encoded.size() > EncodeLengthMax<char>()
+        || static_cast<size_t>(utf16_length) > EncodeLengthMax<char>()) {
+      diag->Error(DiagMessage() << "string too large to encode using UTF-8, "
+          << "written instead as '" << kStringTooLarge << "'");
+
+      // Write a placeholder so this string's pool index still resolves.
+      EncodeString(kStringTooLarge, utf8, out, diag);
+      return false;
+    }
+
+    const size_t total_size = EncodedLengthUnits<char>(utf16_length)
+        + EncodedLengthUnits<char>(encoded.size()) + encoded.size() + 1;
 
     char* data = out->NextBlock<char>(total_size);
 
@@ -357,32 +391,47 @@
     data = EncodeLength(data, utf16_length);
 
     // Now encode the size of the real UTF8 string.
-    data = EncodeLength(data, encoded.length());
+    data = EncodeLength(data, encoded.size());
     strncpy(data, encoded.data(), encoded.size());
 
-    } else {
-      const std::u16string encoded = util::Utf8ToUtf16(str);
-      const ssize_t utf16_length = encoded.size();
+  } else {
+    const std::u16string encoded = util::Utf8ToUtf16(str);
+    const ssize_t utf16_length = encoded.size();
 
-      // Total number of 16-bit words to write.
-      const size_t total_size = EncodedLengthUnits<char16_t>(utf16_length) + encoded.size() + 1;
+    // The UTF-16 length is written as char16_t units, so it must fit in what
+    // two char16_t units can encode.
+    if (static_cast<size_t>(utf16_length) > EncodeLengthMax<char16_t>()) {
+      diag->Error(DiagMessage() << "string too large to encode using UTF-16, "
+          << "written instead as '" << kStringTooLarge << "'");
 
-      char16_t* data = out->NextBlock<char16_t>(total_size);
-
-      // Encode the actual UTF16 string length.
-      data = EncodeLength(data, utf16_length);
-      const size_t byte_length = encoded.size() * sizeof(char16_t);
-
-      // NOTE: For some reason, strncpy16(data, entry->value.data(),
-      // entry->value.size()) truncates the string.
-      memcpy(data, encoded.data(), byte_length);
-
-      // The null-terminating character is already here due to the block of data
-      // being set to 0s on allocation.
+      EncodeString(kStringTooLarge, utf8, out, diag);
+      return false;
     }
+
+    // Total number of 16-bit words to write.
+    const size_t total_size = EncodedLengthUnits<char16_t>(utf16_length)
+        + encoded.size() + 1;
+
+    char16_t* data = out->NextBlock<char16_t>(total_size);
+
+    // Encode the actual UTF16 string length.
+    data = EncodeLength(data, utf16_length);
+    const size_t byte_length = encoded.size() * sizeof(char16_t);
+
+    // NOTE: For some reason, strncpy16(data, entry->value.data(),
+    // entry->value.size()) truncates the string.
+    memcpy(data, encoded.data(), byte_length);
+
+    // The null-terminating character is already here due to the block of data
+    // being set to 0s on allocation.
+  }
+
+  return true;
 }
 
-bool StringPool::Flatten(BigBuffer* out, const StringPool& pool, bool utf8) {
+bool StringPool::Flatten(BigBuffer* out, const StringPool& pool, bool utf8,
+                         IDiagnostics* diag) {
+  bool no_error = true;
   const size_t start_index = out->size();
   android::ResStringPool_header* header = out->NextBlock<android::ResStringPool_header>();
   header->header.type = util::HostToDevice16(android::RES_STRING_POOL_TYPE);
@@ -403,12 +452,12 @@
   // Styles always come first.
   for (const std::unique_ptr<StyleEntry>& entry : pool.styles_) {
     *indices++ = out->size() - before_strings_index;
-    EncodeString(entry->value, utf8, out);
+    no_error = EncodeString(entry->value, utf8, out, diag) && no_error;
   }
 
   for (const std::unique_ptr<Entry>& entry : pool.strings_) {
     *indices++ = out->size() - before_strings_index;
-    EncodeString(entry->value, utf8, out);
+    no_error = EncodeString(entry->value, utf8, out, diag) && no_error;
   }
 
   out->Align4();
@@ -446,15 +495,15 @@
     out->Align4();
   }
   header->header.size = util::HostToDevice32(out->size() - start_index);
-  return true;
+  return no_error;
 }
 
-bool StringPool::FlattenUtf8(BigBuffer* out, const StringPool& pool) {
-  return Flatten(out, pool, true);
+bool StringPool::FlattenUtf8(BigBuffer* out, const StringPool& pool, IDiagnostics* diag) {
+  return Flatten(out, pool, true, diag);
 }
 
-bool StringPool::FlattenUtf16(BigBuffer* out, const StringPool& pool) {
-  return Flatten(out, pool, false);
+bool StringPool::FlattenUtf16(BigBuffer* out, const StringPool& pool, IDiagnostics* diag) {
+  return Flatten(out, pool, false, diag);
 }
 
 }  // namespace aapt
diff --git a/tools/aapt2/StringPool.h b/tools/aapt2/StringPool.h
index 3c1f3dc..f5b464d 100644
--- a/tools/aapt2/StringPool.h
+++ b/tools/aapt2/StringPool.h
@@ -27,6 +27,7 @@
 #include "androidfw/StringPiece.h"
 
 #include "ConfigDescription.h"
+#include "Diagnostics.h"
 #include "util/BigBuffer.h"
 
 namespace aapt {
@@ -152,8 +153,8 @@
     int ref_;
   };
 
-  static bool FlattenUtf8(BigBuffer* out, const StringPool& pool);
-  static bool FlattenUtf16(BigBuffer* out, const StringPool& pool);
+  static bool FlattenUtf8(BigBuffer* out, const StringPool& pool, IDiagnostics* diag);
+  static bool FlattenUtf16(BigBuffer* out, const StringPool& pool, IDiagnostics* diag);
 
   StringPool() = default;
   StringPool(StringPool&&) = default;
@@ -207,7 +208,7 @@
  private:
   DISALLOW_COPY_AND_ASSIGN(StringPool);
 
-  static bool Flatten(BigBuffer* out, const StringPool& pool, bool utf8);
+  static bool Flatten(BigBuffer* out, const StringPool& pool, bool utf8, IDiagnostics* diag);
 
   Ref MakeRefImpl(const android::StringPiece& str, const Context& context, bool unique);
   void ReAssignIndices();
diff --git a/tools/aapt2/StringPool_test.cpp b/tools/aapt2/StringPool_test.cpp
index b1e5ce2..58a03de 100644
--- a/tools/aapt2/StringPool_test.cpp
+++ b/tools/aapt2/StringPool_test.cpp
@@ -20,6 +20,7 @@
 
 #include "androidfw/StringPiece.h"
 
+#include "Diagnostics.h"
 #include "test/Test.h"
 #include "util/Util.h"
 
@@ -188,10 +189,11 @@
 
 TEST(StringPoolTest, FlattenEmptyStringPoolUtf8) {
   using namespace android;  // For NO_ERROR on Windows.
+  StdErrDiagnostics diag;
 
   StringPool pool;
   BigBuffer buffer(1024);
-  StringPool::FlattenUtf8(&buffer, pool);
+  StringPool::FlattenUtf8(&buffer, pool, &diag);
 
   std::unique_ptr<uint8_t[]> data = util::Copy(buffer);
   ResStringPool test;
@@ -200,11 +202,12 @@
 
 TEST(StringPoolTest, FlattenOddCharactersUtf16) {
   using namespace android;  // For NO_ERROR on Windows.
+  StdErrDiagnostics diag;
 
   StringPool pool;
   pool.MakeRef("\u093f");
   BigBuffer buffer(1024);
-  StringPool::FlattenUtf16(&buffer, pool);
+  StringPool::FlattenUtf16(&buffer, pool, &diag);
 
   std::unique_ptr<uint8_t[]> data = util::Copy(buffer);
   ResStringPool test;
@@ -225,6 +228,7 @@
 
 TEST(StringPoolTest, Flatten) {
   using namespace android;  // For NO_ERROR on Windows.
+  StdErrDiagnostics diag;
 
   StringPool pool;
 
@@ -244,8 +248,8 @@
   EXPECT_THAT(ref_d.index(), Eq(4u));
 
   BigBuffer buffers[2] = {BigBuffer(1024), BigBuffer(1024)};
-  StringPool::FlattenUtf8(&buffers[0], pool);
-  StringPool::FlattenUtf16(&buffers[1], pool);
+  StringPool::FlattenUtf8(&buffers[0], pool, &diag);
+  StringPool::FlattenUtf16(&buffers[1], pool, &diag);
 
   // Test both UTF-8 and UTF-16 buffers.
   for (const BigBuffer& buffer : buffers) {
@@ -288,4 +292,53 @@
   }
 }
 
+// Over-long strings must fail flattening and be written as "STRING_TOO_LARGE".
+TEST(StringPoolTest, MaxEncodingLength) {
+  using namespace android;  // For NO_ERROR on Windows.
+  StdErrDiagnostics diag;
+  ResStringPool test;
+
+  StringPool pool;
+  pool.MakeRef("aaaaaaaaaa");
+  BigBuffer buffers[2] = {BigBuffer(1024), BigBuffer(1024)};
+
+  // Make sure a UTF-8 string under the maximum length does not produce an error
+  EXPECT_THAT(StringPool::FlattenUtf8(&buffers[0], pool, &diag), Eq(true));
+  std::unique_ptr<uint8_t[]> data = util::Copy(buffers[0]);
+  test.setTo(data.get(), buffers[0].size());
+  EXPECT_THAT(util::GetString(test, 0), Eq("aaaaaaaaaa"));
+
+  // Make sure a UTF-16 string under the maximum length does not produce an error
+  EXPECT_THAT(StringPool::FlattenUtf16(&buffers[1], pool, &diag), Eq(true));
+  data = util::Copy(buffers[1]);
+  test.setTo(data.get(), buffers[1].size());
+  EXPECT_THAT(util::GetString16(test, 0), Eq(u"aaaaaaaaaa"));
+
+  StringPool pool2;
+  std::string longStr(50000, 'a');
+  pool2.MakeRef("this fits1");
+  pool2.MakeRef(longStr);
+  pool2.MakeRef("this fits2");
+  BigBuffer buffers2[2] = {BigBuffer(1024), BigBuffer(1024)};
+
+  // Make sure a string that exceeds the maximum length of UTF-8 produces an
+  // error and writes a shorter error string instead
+  EXPECT_THAT(StringPool::FlattenUtf8(&buffers2[0], pool2, &diag), Eq(false));
+  data = util::Copy(buffers2[0]);
+  test.setTo(data.get(), buffers2[0].size());
+  EXPECT_THAT(util::GetString(test, 0), Eq("this fits1"));
+  EXPECT_THAT(util::GetString(test, 1), Eq("STRING_TOO_LARGE"));
+  EXPECT_THAT(util::GetString(test, 2), Eq("this fits2"));
+
+  // Make sure a string that exceeds the maximum length of UTF-8 but not UTF-16
+  // does not produce an error when flattened as UTF-16
+  StringPool pool3;
+  std::u16string longStr16(50000, u'a');
+  pool3.MakeRef(longStr);
+  EXPECT_THAT(StringPool::FlattenUtf16(&buffers2[1], pool3, &diag), Eq(true));
+  data = util::Copy(buffers2[1]);
+  test.setTo(data.get(), buffers2[1].size());
+  EXPECT_THAT(util::GetString16(test, 0), Eq(longStr16));
+}
+
 }  // namespace aapt
diff --git a/tools/aapt2/cmd/Compile.cpp b/tools/aapt2/cmd/Compile.cpp
index 101f74e..2d83a14 100644
--- a/tools/aapt2/cmd/Compile.cpp
+++ b/tools/aapt2/cmd/Compile.cpp
@@ -258,7 +258,7 @@
     ContainerWriter container_writer(&copying_adaptor, 1u);
 
     pb::ResourceTable pb_table;
-    SerializeTableToPb(table, &pb_table);
+    SerializeTableToPb(table, &pb_table, context->GetDiagnostics());
     if (!container_writer.AddResTableEntry(pb_table)) {
       context->GetDiagnostics()->Error(DiagMessage(output_path) << "failed to write");
       return false;
diff --git a/tools/aapt2/cmd/Convert.cpp b/tools/aapt2/cmd/Convert.cpp
index 7f956c5..eb307fb 100644
--- a/tools/aapt2/cmd/Convert.cpp
+++ b/tools/aapt2/cmd/Convert.cpp
@@ -226,7 +226,7 @@
 
   bool SerializeTable(ResourceTable* table, IArchiveWriter* writer) override {
     pb::ResourceTable pb_table;
-    SerializeTableToPb(*table, &pb_table);
+    SerializeTableToPb(*table, &pb_table, context_->GetDiagnostics());
     return io::CopyProtoToArchive(context_, &pb_table, kProtoResourceTablePath,
                                   ArchiveEntry::kCompress, writer);
   }
diff --git a/tools/aapt2/cmd/Link.cpp b/tools/aapt2/cmd/Link.cpp
index 0839f6f..818ad20 100644
--- a/tools/aapt2/cmd/Link.cpp
+++ b/tools/aapt2/cmd/Link.cpp
@@ -1073,7 +1073,7 @@
 
       case OutputFormat::kProto: {
         pb::ResourceTable pb_table;
-        SerializeTableToPb(*table, &pb_table);
+        SerializeTableToPb(*table, &pb_table, context_->GetDiagnostics());
         return io::CopyProtoToArchive(context_, &pb_table, kProtoResourceTablePath,
                                       ArchiveEntry::kCompress, writer);
       } break;
diff --git a/tools/aapt2/format/binary/TableFlattener.cpp b/tools/aapt2/format/binary/TableFlattener.cpp
index 24a4112..6fd4c8d 100644
--- a/tools/aapt2/format/binary/TableFlattener.cpp
+++ b/tools/aapt2/format/binary/TableFlattener.cpp
@@ -255,10 +255,10 @@
     FlattenTypes(&type_buffer);
 
     pkg_header->typeStrings = util::HostToDevice32(pkg_writer.size());
-    StringPool::FlattenUtf16(pkg_writer.buffer(), type_pool_);
+    StringPool::FlattenUtf16(pkg_writer.buffer(), type_pool_, diag_);
 
     pkg_header->keyStrings = util::HostToDevice32(pkg_writer.size());
-    StringPool::FlattenUtf8(pkg_writer.buffer(), key_pool_);
+    StringPool::FlattenUtf8(pkg_writer.buffer(), key_pool_, diag_);
 
     // Append the types.
     buffer->AppendBuffer(std::move(type_buffer));
@@ -590,7 +590,8 @@
   table_header->packageCount = util::HostToDevice32(table->packages.size());
 
   // Flatten the values string pool.
-  StringPool::FlattenUtf8(table_writer.buffer(), table->string_pool);
+  StringPool::FlattenUtf8(table_writer.buffer(), table->string_pool,
+                          context->GetDiagnostics());
 
   BigBuffer package_buffer(1024);
 
diff --git a/tools/aapt2/format/binary/XmlFlattener.cpp b/tools/aapt2/format/binary/XmlFlattener.cpp
index 781b9fe..d897941 100644
--- a/tools/aapt2/format/binary/XmlFlattener.cpp
+++ b/tools/aapt2/format/binary/XmlFlattener.cpp
@@ -333,9 +333,9 @@
 
   // Flatten the StringPool.
   if (options_.use_utf16) {
-    StringPool::FlattenUtf16(buffer_, visitor.pool);
+    StringPool::FlattenUtf16(buffer_, visitor.pool, context->GetDiagnostics());
   } else {
-    StringPool::FlattenUtf8(buffer_, visitor.pool);
+    StringPool::FlattenUtf8(buffer_, visitor.pool, context->GetDiagnostics());
   }
 
   {
diff --git a/tools/aapt2/format/proto/ProtoSerialize.cpp b/tools/aapt2/format/proto/ProtoSerialize.cpp
index 1d00852..2e56359 100644
--- a/tools/aapt2/format/proto/ProtoSerialize.cpp
+++ b/tools/aapt2/format/proto/ProtoSerialize.cpp
@@ -21,9 +21,10 @@
 
 namespace aapt {
 
-void SerializeStringPoolToPb(const StringPool& pool, pb::StringPool* out_pb_pool) {
+void SerializeStringPoolToPb(const StringPool& pool, pb::StringPool* out_pb_pool,
+                             IDiagnostics* diag) {
   BigBuffer buffer(1024);
-  StringPool::FlattenUtf8(&buffer, pool);
+  StringPool::FlattenUtf8(&buffer, pool, diag);
 
   std::string* data = out_pb_pool->mutable_data();
   data->reserve(buffer.size());
@@ -270,7 +271,8 @@
   out_pb_config->set_sdk_version(config.sdkVersion);
 }
 
-void SerializeTableToPb(const ResourceTable& table, pb::ResourceTable* out_table) {
+void SerializeTableToPb(const ResourceTable& table, pb::ResourceTable* out_table,
+                        IDiagnostics* diag) {
   StringPool source_pool;
   for (const std::unique_ptr<ResourceTablePackage>& package : table.packages) {
     pb::Package* pb_package = out_table->add_package();
@@ -323,7 +325,7 @@
       }
     }
   }
-  SerializeStringPoolToPb(source_pool, out_table->mutable_source_pool());
+  SerializeStringPoolToPb(source_pool, out_table->mutable_source_pool(), diag);
 }
 
 static pb::Reference_Type SerializeReferenceTypeToPb(Reference::Type type) {
diff --git a/tools/aapt2/format/proto/ProtoSerialize.h b/tools/aapt2/format/proto/ProtoSerialize.h
index 95dd413..951494c 100644
--- a/tools/aapt2/format/proto/ProtoSerialize.h
+++ b/tools/aapt2/format/proto/ProtoSerialize.h
@@ -46,13 +46,15 @@
 
 // Serializes a StringPool into its protobuf representation, which is really just the binary
 // ResStringPool representation stuffed into a bytes field.
-void SerializeStringPoolToPb(const StringPool& pool, pb::StringPool* out_pb_pool);
+void SerializeStringPoolToPb(const StringPool& pool, pb::StringPool* out_pb_pool,
+                             IDiagnostics* diag);
 
 // Serializes a ConfigDescription into its protobuf representation.
 void SerializeConfig(const ConfigDescription& config, pb::Configuration* out_pb_config);
 
 // Serializes a ResourceTable into its protobuf representation.
-void SerializeTableToPb(const ResourceTable& table, pb::ResourceTable* out_table);
+void SerializeTableToPb(const ResourceTable& table, pb::ResourceTable* out_table,
+                        IDiagnostics* diag);
 
 // Serializes a ResourceFile into its protobuf representation.
 void SerializeCompiledFileToPb(const ResourceFile& file, pb::internal::CompiledFile* out_file);
diff --git a/tools/aapt2/format/proto/ProtoSerialize_test.cpp b/tools/aapt2/format/proto/ProtoSerialize_test.cpp
index 9081ab6..6366a3d 100644
--- a/tools/aapt2/format/proto/ProtoSerialize_test.cpp
+++ b/tools/aapt2/format/proto/ProtoSerialize_test.cpp
@@ -95,7 +95,7 @@
                                      Overlayable{}, test::GetDiagnostics()));
 
   pb::ResourceTable pb_table;
-  SerializeTableToPb(*table, &pb_table);
+  SerializeTableToPb(*table, &pb_table, context->GetDiagnostics());
 
   test::TestFile file_a("res/layout/main.xml");
   MockFileCollection files;
@@ -255,6 +255,7 @@
 }
 
 TEST(ProtoSerializeTest, SerializeAndDeserializePrimitives) {
+  std::unique_ptr<IAaptContext> context = test::ContextBuilder().Build();
   std::unique_ptr<ResourceTable> table =
       test::ResourceTableBuilder()
           .AddValue("android:bool/boolean_true",
@@ -274,7 +275,7 @@
           .Build();
 
   pb::ResourceTable pb_table;
-  SerializeTableToPb(*table, &pb_table);
+  SerializeTableToPb(*table, &pb_table, context->GetDiagnostics());
 
   test::TestFile file_a("res/layout/main.xml");
   MockFileCollection files;