diff --git a/services/surfaceflinger/RenderEngine/Description.cpp b/services/surfaceflinger/RenderEngine/Description.cpp
index 09414fd..c218e4d 100644
--- a/services/surfaceflinger/RenderEngine/Description.cpp
+++ b/services/surfaceflinger/RenderEngine/Description.cpp
@@ -51,10 +51,6 @@
     mProjectionMatrix = mtx;
 }
 
-void Description::setSaturationMatrix(const mat4& mtx) {
-    mSaturationMatrix = mtx;
-}
-
 void Description::setColorMatrix(const mat4& mtx) {
     mColorMatrix = mtx;
 }
@@ -82,11 +78,6 @@
     return mColorMatrix != identity;
 }
 
-bool Description::hasSaturationMatrix() const {
-    const mat4 identity;
-    return mSaturationMatrix != identity;
-}
-
 const mat4& Description::getColorMatrix() const {
     return mColorMatrix;
 }
diff --git a/services/surfaceflinger/RenderEngine/Description.h b/services/surfaceflinger/RenderEngine/Description.h
index 06eaf35..6ebb340 100644
--- a/services/surfaceflinger/RenderEngine/Description.h
+++ b/services/surfaceflinger/RenderEngine/Description.h
@@ -42,14 +42,12 @@
     void disableTexture();
     void setColor(const half4& color);
     void setProjectionMatrix(const mat4& mtx);
-    void setSaturationMatrix(const mat4& mtx);
     void setColorMatrix(const mat4& mtx);
     void setInputTransformMatrix(const mat3& matrix);
     void setOutputTransformMatrix(const mat4& matrix);
     bool hasInputTransformMatrix() const;
     bool hasOutputTransformMatrix() const;
     bool hasColorMatrix() const;
-    bool hasSaturationMatrix() const;
     const mat4& getColorMatrix() const;
 
     void setY410BT2020(bool enable);
@@ -92,7 +90,6 @@
     // projection matrix
     mat4 mProjectionMatrix;
     mat4 mColorMatrix;
-    mat4 mSaturationMatrix;
     mat3 mInputTransformMatrix;
     mat4 mOutputTransformMatrix;
 };
diff --git a/services/surfaceflinger/RenderEngine/GLES20RenderEngine.cpp b/services/surfaceflinger/RenderEngine/GLES20RenderEngine.cpp
index 0048000..744a70c 100644
--- a/services/surfaceflinger/RenderEngine/GLES20RenderEngine.cpp
+++ b/services/surfaceflinger/RenderEngine/GLES20RenderEngine.cpp
@@ -267,10 +267,6 @@
     mState.setColorMatrix(colorTransform);
 }
 
-void GLES20RenderEngine::setSaturationMatrix(const mat4& saturationMatrix) {
-    mState.setSaturationMatrix(saturationMatrix);
-}
-
 void GLES20RenderEngine::disableTexturing() {
     mState.disableTexture();
 }
@@ -383,11 +379,10 @@
 
         // we need to convert the RGB value to linear space and convert it back when:
         // - there is a color matrix that is not an identity matrix, or
-        // - there is a saturation matrix that is not an identity matrix, or
         // - there is an output transform matrix that is not an identity matrix, or
         // - the input transfer function doesn't match the output transfer function.
-        if (wideColorState.hasColorMatrix() || wideColorState.hasSaturationMatrix() ||
-            wideColorState.hasOutputTransformMatrix() || inputTransfer != outputTransfer) {
+        if (wideColorState.hasColorMatrix() || wideColorState.hasOutputTransformMatrix() ||
+            inputTransfer != outputTransfer) {
             switch (inputTransfer) {
                 case Dataspace::TRANSFER_ST2084:
                     wideColorState.setInputTransferFunction(Description::TransferFunction::ST2084);
diff --git a/services/surfaceflinger/RenderEngine/GLES20RenderEngine.h b/services/surfaceflinger/RenderEngine/GLES20RenderEngine.h
index de5761b..cc8eb1d 100644
--- a/services/surfaceflinger/RenderEngine/GLES20RenderEngine.h
+++ b/services/surfaceflinger/RenderEngine/GLES20RenderEngine.h
@@ -81,7 +81,6 @@
     virtual void setupLayerBlackedOut();
     virtual void setupFillWithColor(float r, float g, float b, float a);
     virtual void setupColorTransform(const mat4& colorTransform);
-    virtual void setSaturationMatrix(const mat4& saturationMatrix);
     virtual void disableTexturing();
     virtual void disableBlending();
 
diff --git a/services/surfaceflinger/RenderEngine/Program.cpp b/services/surfaceflinger/RenderEngine/Program.cpp
index 95adaca..fe536f0 100644
--- a/services/surfaceflinger/RenderEngine/Program.cpp
+++ b/services/surfaceflinger/RenderEngine/Program.cpp
@@ -135,22 +135,13 @@
         glUniform4fv(mColorLoc, 1, color);
     }
     if (mInputTransformMatrixLoc >= 0) {
-        // If the input transform matrix is not identity matrix, we want to merge
-        // the saturation matrix with input transform matrix so that the saturation
-        // matrix is applied at the correct stage.
-        mat4 inputTransformMatrix = mat4(desc.mInputTransformMatrix) * desc.mSaturationMatrix;
+        mat4 inputTransformMatrix = mat4(desc.mInputTransformMatrix);
         glUniformMatrix4fv(mInputTransformMatrixLoc, 1, GL_FALSE, inputTransformMatrix.asArray());
     }
     if (mOutputTransformMatrixLoc >= 0) {
         // The output transform matrix and color matrix can be combined as one matrix
         // that is applied right before applying OETF.
         mat4 outputTransformMatrix = desc.mColorMatrix * desc.mOutputTransformMatrix;
-        // If there is no input transform matrix, we want to merge the saturation
-        // matrix with output transform matrix to avoid extra matrix multiplication
-        // in shader.
-        if (mInputTransformMatrixLoc < 0) {
-            outputTransformMatrix *= desc.mSaturationMatrix;
-        }
         glUniformMatrix4fv(mOutputTransformMatrixLoc, 1, GL_FALSE,
                            outputTransformMatrix.asArray());
     }
diff --git a/services/surfaceflinger/RenderEngine/ProgramCache.cpp b/services/surfaceflinger/RenderEngine/ProgramCache.cpp
index 796901a..9dc6858 100644
--- a/services/surfaceflinger/RenderEngine/ProgramCache.cpp
+++ b/services/surfaceflinger/RenderEngine/ProgramCache.cpp
@@ -149,8 +149,7 @@
                  description.hasInputTransformMatrix() ?
                      Key::INPUT_TRANSFORM_MATRIX_ON : Key::INPUT_TRANSFORM_MATRIX_OFF)
             .set(Key::Key::OUTPUT_TRANSFORM_MATRIX_MASK,
-                 description.hasOutputTransformMatrix() || description.hasColorMatrix() ||
-                 (!description.hasInputTransformMatrix() && description.hasSaturationMatrix()) ?
+                 description.hasOutputTransformMatrix() || description.hasColorMatrix() ?
                      Key::OUTPUT_TRANSFORM_MATRIX_ON : Key::OUTPUT_TRANSFORM_MATRIX_OFF);
 
     needs.set(Key::Y410_BT2020_MASK,
diff --git a/services/surfaceflinger/RenderEngine/RenderEngine.h b/services/surfaceflinger/RenderEngine/RenderEngine.h
index 1196216..1786155 100644
--- a/services/surfaceflinger/RenderEngine/RenderEngine.h
+++ b/services/surfaceflinger/RenderEngine/RenderEngine.h
@@ -113,7 +113,6 @@
     virtual void setupFillWithColor(float r, float g, float b, float a) = 0;
 
     virtual void setupColorTransform(const mat4& /* colorTransform */) = 0;
-    virtual void setSaturationMatrix(const mat4& /* saturationMatrix */) = 0;
 
     virtual void disableTexturing() = 0;
     virtual void disableBlending() = 0;
@@ -228,7 +227,6 @@
     void checkErrors() const override;
 
     void setupColorTransform(const mat4& /* colorTransform */) override {}
-    void setSaturationMatrix(const mat4& /* saturationMatrix */) override {}
 
     // internal to RenderEngine
     EGLDisplay getEGLDisplay() const;
diff --git a/services/surfaceflinger/SurfaceFlinger.cpp b/services/surfaceflinger/SurfaceFlinger.cpp
index 488bc20..7680f2a 100644
--- a/services/surfaceflinger/SurfaceFlinger.cpp
+++ b/services/surfaceflinger/SurfaceFlinger.cpp
@@ -37,6 +37,7 @@
 
 #include <dvr/vr_flinger.h>
 
+#include <ui/ColorSpace.h>
 #include <ui/DebugUtils.h>
 #include <ui/DisplayInfo.h>
 #include <ui/DisplayStatInfo.h>
@@ -763,8 +764,23 @@
         ALOGE("Run StartPropertySetThread failed!");
     }
 
-    mLegacySrgbSaturationMatrix =
-            getHwComposer().getDataspaceSaturationMatrix(display->getId(), Dataspace::SRGB_LINEAR);
+    // This is a hack. Per definition of getDataspaceSaturationMatrix, the returned matrix
+    // is used to saturate legacy sRGB content. However, to make sure the same color under
+    // Display P3 will be saturated to the same color, we intentionally break the API spec
+    // and apply this saturation matrix on Display P3 content. Unless the risk of applying
+    // such saturation matrix on Display P3 is understood fully, the API should always return
+    // identify matrix.
+    mEnhancedSaturationMatrix = getBE().mHwc->getDataspaceSaturationMatrix(display->getId(),
+            Dataspace::SRGB_LINEAR);
+
+    // we will apply this on Display P3.
+    if (mEnhancedSaturationMatrix != mat4()) {
+        ColorSpace srgb(ColorSpace::sRGB());
+        ColorSpace displayP3(ColorSpace::DisplayP3());
+        mat4 srgbToP3 = mat4(ColorSpaceConnector(srgb, displayP3).getTransform());
+        mat4 p3ToSrgb = mat4(ColorSpaceConnector(displayP3, srgb).getTransform());
+        mEnhancedSaturationMatrix = srgbToP3 * mEnhancedSaturationMatrix * p3ToSrgb;
+    }
 
     ALOGV("Done initializing");
 }
@@ -2882,8 +2898,7 @@
     ATRACE_INT("hasClientComposition", hasClientComposition);
 
     bool applyColorMatrix = false;
-    bool needsLegacyColorMatrix = false;
-    bool legacyColorMatrixApplied = false;
+    bool needsEnhancedColorMatrix = false;
 
     if (hasClientComposition) {
         ALOGV("hasClientComposition");
@@ -2900,14 +2915,23 @@
         const bool skipClientColorTransform = getBE().mHwc->hasCapability(
             HWC2::Capability::SkipClientColorTransform);
 
+        mat4 colorMatrix;
         applyColorMatrix = !hasDeviceComposition && !skipClientColorTransform;
         if (applyColorMatrix) {
-            getRenderEngine().setupColorTransform(mDrawingState.colorMatrix);
+            colorMatrix = mDrawingState.colorMatrix;
         }
 
-        needsLegacyColorMatrix =
-                (display->getActiveRenderIntent() >= RenderIntent::ENHANCE &&
-                 outputDataspace != Dataspace::UNKNOWN && outputDataspace != Dataspace::SRGB);
+        // The current enhanced saturation matrix is designed to enhance Display P3,
+        // thus we only apply this matrix when the render intent is not colorimetric
+        // and the output color space is Display P3.
+        needsEnhancedColorMatrix =
+            (display->getActiveRenderIntent() >= RenderIntent::ENHANCE &&
+             outputDataspace == Dataspace::DISPLAY_P3);
+        if (needsEnhancedColorMatrix) {
+            colorMatrix *= mEnhancedSaturationMatrix;
+        }
+
+        getRenderEngine().setupColorTransform(colorMatrix);
 
         if (!display->makeCurrent()) {
             ALOGW("DisplayDevice::makeCurrent failed. Aborting surface composition for display %s",
@@ -2993,17 +3017,6 @@
                     break;
                 }
                 case HWC2::Composition::Client: {
-                    // switch color matrices lazily
-                    if (layer->isLegacyDataSpace() && needsLegacyColorMatrix) {
-                        if (!legacyColorMatrixApplied) {
-                            getRenderEngine().setSaturationMatrix(mLegacySrgbSaturationMatrix);
-                            legacyColorMatrixApplied = true;
-                        }
-                    } else if (legacyColorMatrixApplied) {
-                        getRenderEngine().setSaturationMatrix(mat4());
-                        legacyColorMatrixApplied = false;
-                    }
-
                     layer->draw(renderArea, clip);
                     break;
                 }
@@ -3016,12 +3029,9 @@
         firstLayer = false;
     }
 
-    if (applyColorMatrix) {
+    if (applyColorMatrix || needsEnhancedColorMatrix) {
         getRenderEngine().setupColorTransform(mat4());
     }
-    if (needsLegacyColorMatrix && legacyColorMatrixApplied) {
-        getRenderEngine().setSaturationMatrix(mat4());
-    }
 
     // disable scissor at the end of the frame
     getBE().mRenderEngine->disableScissor();
diff --git a/services/surfaceflinger/SurfaceFlinger.h b/services/surfaceflinger/SurfaceFlinger.h
index e04db5d..32dc4f6 100644
--- a/services/surfaceflinger/SurfaceFlinger.h
+++ b/services/surfaceflinger/SurfaceFlinger.h
@@ -889,8 +889,8 @@
     std::thread::id mMainThreadId;
 
     DisplayColorSetting mDisplayColorSetting = DisplayColorSetting::MANAGED;
-    // Applied on sRGB layers when the render intent is non-colorimetric.
-    mat4 mLegacySrgbSaturationMatrix;
+    // Applied on Display P3 layers when the render intent is non-colorimetric.
+    mat4 mEnhancedSaturationMatrix;
 
     using CreateBufferQueueFunction =
             std::function<void(sp<IGraphicBufferProducer>* /* outProducer */,
