ARM64: Tune SIMD loop unrolling factor heuristic.
Improve the SIMD loop unrolling factor heuristic for ARM64 by
accounting for the maximum desired loop body size, the trip
count and a maximum unroll factor. The following example shows
a 21% performance increase:
for (int i = 0; i < LENGTH; i++) {
  bc[i] = ba[i];  // Byte arrays
}
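For a rough sense of the numbers (assuming the vectorized byte-copy body
is about 10 HIR instructions and vector_length_ is 16 for a 128-bit SIMD
unit; both figures are illustrative), a large LENGTH now gives
  TruncToPowerOfTwo(min(50 / 10, trip_count / 16, 8)) = TruncToPowerOfTwo(5) = 4
instead of the previous fixed factor of 2.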
Test: test-art-host, test-art-target.
Change-Id: Ic587759c51aa4354df621ffb1c7ce4ebd798dfc1
diff --git a/compiler/optimizing/loop_optimization.cc b/compiler/optimizing/loop_optimization.cc
index a249cac..e150b65 100644
--- a/compiler/optimizing/loop_optimization.cc
+++ b/compiler/optimizing/loop_optimization.cc
@@ -1761,21 +1761,33 @@
vector_peeling_candidate_ = candidate;
}
+static constexpr uint32_t ARM64_SIMD_MAXIMUM_UNROLL_FACTOR = 8;
+static constexpr uint32_t ARM64_SIMD_HEURISTIC_MAX_BODY_SIZE = 50;
+
uint32_t HLoopOptimization::GetUnrollingFactor(HBasicBlock* block, int64_t trip_count) {
- // Current heuristic: unroll by 2 on ARM64/X86 for large known trip
- // counts and small loop bodies.
- // TODO: refine with operation count, remaining iterations, etc.
- // Artem had some really cool ideas for this already.
switch (compiler_driver_->GetInstructionSet()) {
- case kArm64:
- case kX86:
- case kX86_64: {
- size_t num_instructions = block->GetInstructions().CountSize();
- if (num_instructions <= 10 && trip_count >= 4 * vector_length_) {
- return 2;
+ case kArm64: {
+ DCHECK_NE(vector_length_, 0u);
+ // TODO: Unroll loops with unknown trip count.
+ if (trip_count < 2 * vector_length_) {
+ return 1;
}
- return 1;
+
+ uint32_t instruction_count = block->GetInstructions().CountSize();
+
+ // Find a beneficial unroll factor with the following restrictions:
+ // - At least one iteration of the transformed loop should be executed.
+ // - The loop body shouldn't be "too big" (heuristic).
+ uint32_t uf1 = ARM64_SIMD_HEURISTIC_MAX_BODY_SIZE / instruction_count;
+ uint32_t uf2 = trip_count / vector_length_;
+ uint32_t unroll_factor =
+ TruncToPowerOfTwo(std::min({uf1, uf2, ARM64_SIMD_MAXIMUM_UNROLL_FACTOR}));
+ DCHECK_GE(unroll_factor, 1u);
+
+ return unroll_factor;
}
+ case kX86:
+ case kX86_64:
default:
return 1;
}
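A minimal standalone sketch of the factor selection above, outside the
compiler (the names kMaxUnroll, kMaxBodySize, TruncPow2 and UnrollFactor
are illustrative stand-ins, not ART APIs):

#include <algorithm>
#include <cstdint>

// Stand-ins for the ARM64 heuristic constants in the hunk above.
constexpr uint32_t kMaxUnroll = 8;     // ARM64_SIMD_MAXIMUM_UNROLL_FACTOR
constexpr uint32_t kMaxBodySize = 50;  // ARM64_SIMD_HEURISTIC_MAX_BODY_SIZE

// Round down to a power of two (stand-in for art::TruncToPowerOfTwo).
constexpr uint32_t TruncPow2(uint32_t v) {
  return v == 0u ? 0u : uint32_t{1} << (31 - __builtin_clz(v));
}

// body_size must be non-zero; in the compiler a loop body always has at
// least its control instructions.
uint32_t UnrollFactor(uint32_t body_size, int64_t trip_count, uint32_t vector_length) {
  if (trip_count < 2 * static_cast<int64_t>(vector_length)) {
    return 1;  // Too few known iterations to unroll.
  }
  // Keep the unrolled body under the size budget (uf1) and make sure at
  // least one iteration of the transformed loop executes (uf2).
  uint32_t uf1 = kMaxBodySize / body_size;
  uint32_t uf2 = static_cast<uint32_t>(trip_count / vector_length);
  return TruncPow2(std::min({uf1, uf2, kMaxUnroll}));
}

For the byte-copy example in the commit message, UnrollFactor(10, 1024, 16)
yields min(5, 64, 8) truncated to a power of two, i.e. 4.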
diff --git a/runtime/base/bit_utils.h b/runtime/base/bit_utils.h
index 0844678..87dac02 100644
--- a/runtime/base/bit_utils.h
+++ b/runtime/base/bit_utils.h
@@ -127,6 +127,14 @@
return (x < 2u) ? x : static_cast<T>(1u) << (std::numeric_limits<T>::digits - CLZ(x - 1u));
}
+// Return the largest power of two N such that N <= val, or 0 if val is 0.
+template <typename T>
+constexpr T TruncToPowerOfTwo(T val) {
+ static_assert(std::is_integral<T>::value, "T must be integral");
+ static_assert(std::is_unsigned<T>::value, "T must be unsigned");
+ return (val != 0) ? static_cast<T>(1u) << (BitSizeOf<T>() - CLZ(val) - 1u) : 0;
+}
+
template<typename T>
constexpr bool IsPowerOfTwo(T x) {
static_assert(std::is_integral<T>::value, "T must be integral");
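For intuition on the expression above (32-bit case): BitSizeOf<T>() - CLZ(val) - 1
is the index of the most significant set bit of val, so the shift reproduces just
that bit. For example, val = 0x3aaaa has CLZ(val) = 14, giving index
32 - 14 - 1 = 17 and a result of 1u << 17 = 0x20000, which is what
TestTruncToPowerOfTwo32#6 below checks.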
diff --git a/runtime/base/bit_utils_test.cc b/runtime/base/bit_utils_test.cc
index 9f22fb4..c96c6dc 100644
--- a/runtime/base/bit_utils_test.cc
+++ b/runtime/base/bit_utils_test.cc
@@ -122,6 +122,32 @@
static_assert(33u == MinimumBitsToStore<uint64_t>(UINT64_C(0x1FFFFFFFF)), "TestMinBits2Store64#10");
static_assert(64u == MinimumBitsToStore<uint64_t>(~UINT64_C(0)), "TestMinBits2Store64#11");
+static_assert(0 == TruncToPowerOfTwo<uint32_t>(0u), "TestTruncToPowerOfTwo32#1");
+static_assert(1 == TruncToPowerOfTwo<uint32_t>(1u), "TestTruncToPowerOfTwo32#2");
+static_assert(2 == TruncToPowerOfTwo<uint32_t>(2u), "TestTruncToPowerOfTwo32#3");
+static_assert(2 == TruncToPowerOfTwo<uint32_t>(3u), "TestTruncToPowerOfTwo32#4");
+static_assert(4 == TruncToPowerOfTwo<uint32_t>(7u), "TestTruncToPowerOfTwo32#5");
+static_assert(0x20000u == TruncToPowerOfTwo<uint32_t>(0x3aaaau),
+ "TestTruncToPowerOfTwo32#6");
+static_assert(0x40000000u == TruncToPowerOfTwo<uint32_t>(0x40000001u),
+ "TestTruncToPowerOfTwo32#7");
+static_assert(0x80000000u == TruncToPowerOfTwo<uint32_t>(0x80000000u),
+ "TestTruncToPowerOfTwo32#8");
+
+static_assert(0 == TruncToPowerOfTwo<uint64_t>(UINT64_C(0)), "TestTruncToPowerOfTwo64#1");
+static_assert(1 == TruncToPowerOfTwo<uint64_t>(UINT64_C(1)), "TestTruncToPowerOfTwo64#2");
+static_assert(2 == TruncToPowerOfTwo<uint64_t>(UINT64_C(2)), "TestTruncToPowerOfTwo64#3");
+static_assert(2 == TruncToPowerOfTwo<uint64_t>(UINT64_C(3)), "TestTruncToPowerOfTwo64#4");
+static_assert(4 == TruncToPowerOfTwo<uint64_t>(UINT64_C(7)), "TestTruncToPowerOfTwo64#5");
+static_assert(UINT64_C(0x20000) == TruncToPowerOfTwo<uint64_t>(UINT64_C(0x3aaaa)),
+ "TestTruncToPowerOfTwo64#6");
+static_assert(
+ UINT64_C(0x4000000000000000) == TruncToPowerOfTwo<uint64_t>(UINT64_C(0x4000000000000001)),
+ "TestTruncToPowerOfTwo64#7");
+static_assert(
+ UINT64_C(0x8000000000000000) == TruncToPowerOfTwo<uint64_t>(UINT64_C(0x8000000000000000)),
+ "TestTruncToPowerOfTwo64#8");
+
static_assert(0 == RoundUpToPowerOfTwo<uint32_t>(0u), "TestRoundUpPowerOfTwo32#1");
static_assert(1 == RoundUpToPowerOfTwo<uint32_t>(1u), "TestRoundUpPowerOfTwo32#2");
static_assert(2 == RoundUpToPowerOfTwo<uint32_t>(2u), "TestRoundUpPowerOfTwo32#3");