RTP: Fix non-zero DC in EchoSuppressor caused while aggregating samples.

Rewrite using integer arithmetic to get full 32-bit precision instead of
the 23-bit mantissa of single-precision floating point.

Bug: 3029745
Change-Id: If67dcc403923755f403d08bbafb41ebce26e4e8b
diff --git a/voip/jni/rtp/AudioGroup.cpp b/voip/jni/rtp/AudioGroup.cpp
index 9da560a..0c8a725 100644
--- a/voip/jni/rtp/AudioGroup.cpp
+++ b/voip/jni/rtp/AudioGroup.cpp
@@ -768,7 +768,7 @@
     LOGD("latency: output %d, input %d", track.latency(), record.latency());
 
     // Initialize echo canceler.
-    EchoSuppressor echo(sampleRate, sampleCount, sampleCount * 2 +
+    EchoSuppressor echo(sampleCount,
         (track.latency() + record.latency()) * sampleRate / 1000);
 
     // Give device socket a reasonable buffer size.
diff --git a/voip/jni/rtp/EchoSuppressor.cpp b/voip/jni/rtp/EchoSuppressor.cpp
index ad63cd6..d5cff6e8 100644
--- a/voip/jni/rtp/EchoSuppressor.cpp
+++ b/voip/jni/rtp/EchoSuppressor.cpp
@@ -24,146 +24,163 @@
 
 #include "EchoSuppressor.h"
 
-EchoSuppressor::EchoSuppressor(int sampleRate, int sampleCount, int tailLength)
+// It is very difficult to do echo cancellation at this level due to the lack
+// of timing information for the samples being played and recorded. Therefore,
+// only echo suppression is implemented for the first release.
+
+// The algorithm is derived from the "previous works" summarized in
+//   A new class of doubletalk detectors based on cross-correlation,
+//   J Benesty, DR Morgan, JH Cho, IEEE Trans. on Speech and Audio Processing.
+// The method proposed in that paper is not used because of its high complexity.
+
+// It is well known that cross-correlation can be computed using convolution,
+// but unfortunately not every mobile processor has a (fast enough) FPU. Thus
+// we use integer arithmetic as much as possible and do lots of bookkeeping.
+// The parameters and thresholds were chosen experimentally.
+
+EchoSuppressor::EchoSuppressor(int sampleCount, int tailLength)
 {
-    int scale = 1;
-    while (tailLength > 200 * scale) {
-        scale <<= 1;
-    }
-    if (scale > sampleCount) {
-        scale = sampleCount;
+    tailLength += sampleCount * 4;
+
+    int shift = 0;
+    while ((sampleCount >> shift) > 1 && (tailLength >> shift) > 256) {
+        ++shift;
     }
 
-    mScale = scale;
+    mShift = shift + 4;
+    mScale = 1 << shift;
     mSampleCount = sampleCount;
-    mWindowSize = sampleCount / scale;
-    mTailLength = (tailLength + scale - 1) / scale;
-    mRecordLength = (sampleRate + sampleCount - 1) / sampleCount;
+    mWindowSize = sampleCount >> shift;
+    mTailLength = tailLength >> shift;
+    mRecordLength = tailLength * 2 / sampleCount;
     mRecordOffset = 0;
 
-    mXs = new float[mTailLength + mWindowSize];
-    memset(mXs, 0, sizeof(float) * (mTailLength + mWindowSize));
-    mXYs = new float[mTailLength];
-    memset(mXYs, 0, sizeof(float) * mTailLength);
-    mXXs = new float[mTailLength];
-    memset(mXYs, 0, sizeof(float) * mTailLength);
-    mYY = 0;
+    mXs = new uint16_t[mTailLength + mWindowSize];
+    memset(mXs, 0, sizeof(*mXs) * (mTailLength + mWindowSize));
+    mXSums = new uint32_t[mTailLength];
+    memset(mXSums, 0, sizeof(*mXSums) * mTailLength);
+    mX2Sums = new uint32_t[mTailLength];
+    memset(mX2Sums, 0, sizeof(*mX2Sums) * mTailLength);
+    mXRecords = new uint16_t[mRecordLength * mWindowSize];
+    memset(mXRecords, 0, sizeof(*mXRecords) * mRecordLength * mWindowSize);
 
-    mXYRecords = new float[mRecordLength * mTailLength];
-    memset(mXYRecords, 0, sizeof(float) * mRecordLength * mTailLength);
-    mXXRecords = new float[mRecordLength * mWindowSize];
-    memset(mXXRecords, 0, sizeof(float) * mRecordLength * mWindowSize);
-    mYYRecords = new float[mRecordLength];
-    memset(mYYRecords, 0, sizeof(float) * mRecordLength);
+    mYSum = 0;
+    mY2Sum = 0;
+    mYRecords = new uint32_t[mRecordLength];
+    memset(mYRecords, 0, sizeof(*mYRecords) * mRecordLength);
+    mY2Records = new uint32_t[mRecordLength];
+    memset(mY2Records, 0, sizeof(*mY2Records) * mRecordLength);
+
+    mXYSums = new uint32_t[mTailLength];
+    memset(mXYSums, 0, sizeof(*mXYSums) * mTailLength);
+    mXYRecords = new uint32_t[mRecordLength * mTailLength];
+    memset(mXYRecords, 0, sizeof(*mXYRecords) * mRecordLength * mTailLength);
 
     mLastX = 0;
     mLastY = 0;
+    mWeight = 1.0f / (mRecordLength * mWindowSize);
 }
 
 EchoSuppressor::~EchoSuppressor()
 {
     delete [] mXs;
-    delete [] mXYs;
-    delete [] mXXs;
+    delete [] mXSums;
+    delete [] mX2Sums;
+    delete [] mXRecords;
+    delete [] mYRecords;
+    delete [] mY2Records;
+    delete [] mXYSums;
     delete [] mXYRecords;
-    delete [] mXXRecords;
-    delete [] mYYRecords;
 }
 
 void EchoSuppressor::run(int16_t *playbacked, int16_t *recorded)
 {
-    float *records;
-
     // Update Xs.
-    for (int i = 0; i < mTailLength; ++i) {
-        mXs[i] = mXs[mWindowSize + i];
+    for (int i = mTailLength - 1; i >= 0; --i) {
+        mXs[i + mWindowSize] = mXs[i];
     }
-    for (int i = 0, j = 0; i < mWindowSize; ++i, j += mScale) {
-        float sum = 0;
+    for (int i = mWindowSize - 1, j = 0; i >= 0; --i, j += mScale) {
+        uint32_t sum = 0;
         for (int k = 0; k < mScale; ++k) {
-            float x = playbacked[j + k] >> 8;
+            int32_t x = playbacked[j + k] << 15;
             mLastX += x;
-            sum += (mLastX >= 0) ? mLastX : -mLastX;
-            mLastX = 0.005f * mLastX - x;
+            sum += ((mLastX >= 0) ? mLastX : -mLastX) >> 15;
+            mLastX -= (mLastX >> 10) + x;
         }
-        mXs[mTailLength - 1 + i] = sum;
+        mXs[i] = sum >> mShift;
     }
 
-    // Update XXs and XXRecords.
-    for (int i = 0; i < mTailLength - mWindowSize; ++i) {
-        mXXs[i] = mXXs[mWindowSize + i];
+    // Update XSums, X2Sums, and XRecords.
+    for (int i = mTailLength - mWindowSize - 1; i >= 0; --i) {
+        mXSums[i + mWindowSize] = mXSums[i];
+        mX2Sums[i + mWindowSize] = mX2Sums[i];
     }
-    records = &mXXRecords[mRecordOffset * mWindowSize];
-    for (int i = 0, j = mTailLength - mWindowSize; i < mWindowSize; ++i, ++j) {
-        float xx = mXs[mTailLength - 1 + i] * mXs[mTailLength - 1 + i];
-        mXXs[j] = mXXs[j - 1] + xx - records[i];
-        records[i] = xx;
-        if (mXXs[j] < 0) {
-            mXXs[j] = 0;
-        }
+    uint16_t *xRecords = &mXRecords[mRecordOffset * mWindowSize];
+    for (int i = mWindowSize - 1; i >= 0; --i) {
+        uint16_t x = mXs[i];
+        mXSums[i] = mXSums[i + 1] + x - xRecords[i];
+        mX2Sums[i] = mX2Sums[i + 1] + x * x - xRecords[i] * xRecords[i];
+        xRecords[i] = x;
     }
 
     // Compute Ys.
-    float ys[mWindowSize];
-    for (int i = 0, j = 0; i < mWindowSize; ++i, j += mScale) {
-        float sum = 0;
+    uint16_t ys[mWindowSize];
+    for (int i = mWindowSize - 1, j = 0; i >= 0; --i, j += mScale) {
+        uint32_t sum = 0;
         for (int k = 0; k < mScale; ++k) {
-            float y = recorded[j + k] >> 8;
+            int32_t y = recorded[j + k] << 15;
             mLastY += y;
-            sum += (mLastY >= 0) ? mLastY : -mLastY;
-            mLastY = 0.005f * mLastY - y;
+            sum += ((mLastY >= 0) ? mLastY : -mLastY) >> 15;
+            mLastY -= (mLastY >> 10) + y;
         }
-        ys[i] = sum;
+        ys[i] = sum >> mShift;
     }
 
-    // Update YY and YYRecords.
-    float yy = 0;
-    for (int i = 0; i < mWindowSize; ++i) {
-        yy += ys[i] * ys[i];
+    // Update YSum, Y2Sum, YRecords, and Y2Records.
+    uint32_t ySum = 0;
+    uint32_t y2Sum = 0;
+    for (int i = mWindowSize - 1; i >= 0; --i) {
+        ySum += ys[i];
+        y2Sum += ys[i] * ys[i];
     }
-    mYY += yy - mYYRecords[mRecordOffset];
-    mYYRecords[mRecordOffset] = yy;
-    if (mYY < 0) {
-        mYY = 0;
+    mYSum += ySum - mYRecords[mRecordOffset];
+    mY2Sum += y2Sum - mY2Records[mRecordOffset];
+    mYRecords[mRecordOffset] = ySum;
+    mY2Records[mRecordOffset] = y2Sum;
+
+    // Update XYSums and XYRecords.
+    uint32_t *xyRecords = &mXYRecords[mRecordOffset * mTailLength];
+    for (int i = mTailLength - 1; i >= 0; --i) {
+        uint32_t xySum = 0;
+        for (int j = mWindowSize - 1; j >= 0; --j) {
+            xySum += mXs[i + j] * ys[j];
+        }
+        mXYSums[i] += xySum - xyRecords[i];
+        xyRecords[i] = xySum;
     }
 
-    // Update XYs and XYRecords.
-    records = &mXYRecords[mRecordOffset * mTailLength];
-    for (int i = 0; i < mTailLength; ++i) {
-        float xy = 0;
-        for (int j = 0;j < mWindowSize; ++j) {
-            xy += mXs[i + j] * ys[j];
-        }
-        mXYs[i] += xy - records[i];
-        records[i] = xy;
-        if (mXYs[i] < 0) {
-            mXYs[i] = 0;
-        }
-    }
-
-    // Computes correlations from XYs, XXs, and YY.
-    float weight = 1.0f / (mYY + 1);
-    float correlation = 0;
+    // Compute correlations.
+    float corr2 = 0.0f;
     int latency = 0;
-    for (int i = 0; i < mTailLength; ++i) {
-        float c = mXYs[i] * mXYs[i] * weight / (mXXs[i] + 1);
-        if (c > correlation) {
-            correlation = c;
+    float varY = mY2Sum - mWeight * mYSum * mYSum;
+    for (int i = mTailLength - 1; i >= 0; --i) {
+        float varX = mX2Sums[i] - mWeight * mXSums[i] * mXSums[i];
+        float cov = mXYSums[i] - mWeight * mXSums[i] * mYSum;
+        float c2 = cov * cov / (varX * varY + 1);
+        if (c2 > corr2) {
+            corr2 = c2;
             latency = i;
         }
     }
+    //LOGI("correlation^2 = %.10f, latency = %d", corr2, latency * mScale);
 
-    correlation = sqrtf(correlation);
-    if (correlation > 0.3f) {
-        float factor = 1.0f - correlation;
-        factor *= factor;
-        factor /= 2.0; // suppress harder
+    // Do echo suppression.
+    if (corr2 > 0.1f) {
+        int factor = (corr2 > 1.0f) ? 0 : (1.0f - sqrtf(corr2)) * 4096;
         for (int i = 0; i < mSampleCount; ++i) {
-            recorded[i] *= factor;
+            recorded[i] = recorded[i] * factor >> 16;
         }
     }
-    //LOGI("latency %5d, correlation %.10f", latency, correlation);
-
 
     // Increase RecordOffset.
     ++mRecordOffset;
diff --git a/voip/jni/rtp/EchoSuppressor.h b/voip/jni/rtp/EchoSuppressor.h
index 85decf5..2f3b593 100644
--- a/voip/jni/rtp/EchoSuppressor.h
+++ b/voip/jni/rtp/EchoSuppressor.h
@@ -23,11 +23,12 @@
 {
 public:
     // The sampleCount must be power of 2.
-    EchoSuppressor(int sampleRate, int sampleCount, int tailLength);
+    EchoSuppressor(int sampleCount, int tailLength);
     ~EchoSuppressor();
     void run(int16_t *playbacked, int16_t *recorded);
 
 private:
+    int mShift;
     int mScale;
     int mSampleCount;
     int mWindowSize;
@@ -35,17 +36,23 @@
     int mRecordLength;
     int mRecordOffset;
 
-    float *mXs;
-    float *mXYs;
-    float *mXXs;
-    float mYY;
+    uint16_t *mXs;
+    uint32_t *mXSums;
+    uint32_t *mX2Sums;
+    uint16_t *mXRecords;
 
-    float *mXYRecords;
-    float *mXXRecords;
-    float *mYYRecords;
+    uint32_t mYSum;
+    uint32_t mY2Sum;
+    uint32_t *mYRecords;
+    uint32_t *mY2Records;
 
-    float mLastX;
-    float mLastY;
+    uint32_t *mXYSums;
+    uint32_t *mXYRecords;
+
+    int32_t mLastX;
+    int32_t mLastY;
+
+    float mWeight;
 };
 
 #endif