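Add Context::PushState, a scoped holder that saves and restores context state.
Fix rsForEach corrupting the parent context's bound program state.
Remove the now-unneeded state-rebinding workaround from the Balls sample.
Rename Context's program setters/getters to set/getProgram*().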

Change-Id: I43a948536e70d44645d1c2ef7b97e1c5906f6943
diff --git a/libs/rs/java/Balls/src/com/android/balls/BallsRS.java b/libs/rs/java/Balls/src/com/android/balls/BallsRS.java
index 359f334..76c23b7 100644
--- a/libs/rs/java/Balls/src/com/android/balls/BallsRS.java
+++ b/libs/rs/java/Balls/src/com/android/balls/BallsRS.java
@@ -33,8 +33,6 @@
     private ProgramFragment mPFLines;
     private ProgramFragment mPFPoints;
     private ProgramVertex mPV;
-    private ProgramRaster mPR;
-    private ProgramStore mPS;
     private ScriptField_Point mPoints;
     private ScriptField_Point mArcs;
     private ScriptField_VpConsts mVpConsts;
@@ -48,12 +46,6 @@
         mVpConsts.set(i, 0, true);
     }
 
-    private void createProgramRaster() {
-        ProgramRaster.Builder b = new ProgramRaster.Builder(mRS);
-        mPR = b.create();
-        mScript.set_gPR(mPR);
-    }
-
     private void createProgramVertex() {
         updateProjectionMatrices();
 
@@ -71,7 +63,7 @@
         sb.addInput(mPoints.getElement());
         ProgramVertex pvs = sb.create();
         pvs.bindConstants(mVpConsts.getAllocation(), 0);
-        mScript.set_gPV(pvs);
+        mRS.contextBindProgramVertex(pvs);
     }
 
     private Allocation loadTexture(int id) {
@@ -125,10 +117,8 @@
         mScript.set_gPFLines(mPFLines);
         mScript.set_gPFPoints(mPFPoints);
         createProgramVertex();
-        createProgramRaster();
 
-        mPS = ProgramStore.BLEND_ADD_DEPTH_NO_DEPTH(mRS);
-        mScript.set_gPS(mPS);
+        mRS.contextBindProgramStore(ProgramStore.BLEND_ADD_DEPTH_NO_DEPTH(mRS));
 
         mPhysicsScript.set_gMinPos(new Float2(5, 5));
         mPhysicsScript.set_gMaxPos(new Float2(width - 5, height - 5));
diff --git a/libs/rs/java/Balls/src/com/android/balls/balls.rs b/libs/rs/java/Balls/src/com/android/balls/balls.rs
index bbd03cf..3edbe2d 100644
--- a/libs/rs/java/Balls/src/com/android/balls/balls.rs
+++ b/libs/rs/java/Balls/src/com/android/balls/balls.rs
@@ -4,13 +4,11 @@
 
 #include "balls.rsh"
 
-#pragma stateFragment(parent)
+#pragma stateVertex(parent)
+#pragma stateStore(parent)
 
 rs_program_fragment gPFPoints;
 rs_program_fragment gPFLines;
-rs_program_vertex gPV;
-rs_program_raster gPR;
-rs_program_store gPS;
 rs_mesh partMesh;
 rs_mesh arcMesh;
 
@@ -95,9 +93,6 @@
 
     frame++;
     rsgBindProgramFragment(gPFLines);
-    rsgBindProgramVertex(gPV);
-    rsgBindProgramRaster(gPR);
-    rsgBindProgramStore(gPS);
     rsgDrawMesh(arcMesh, 0, 0, arcIdx);
     rsgBindProgramFragment(gPFPoints);
     rsgDrawMesh(partMesh);
diff --git a/libs/rs/rsContext.cpp b/libs/rs/rsContext.cpp
index 18bf9fa..143c4dc 100644
--- a/libs/rs/rsContext.cpp
+++ b/libs/rs/rsContext.cpp
@@ -262,21 +262,27 @@
     }
 }
 
+Context::PushState::PushState(Context *con) : mRsc(con) {  // snapshot the bound state
+    if (!con->mIsGraphicsContext) return;  // nothing to save; ~PushState() skips it too
+    mFragment.set(con->getProgramFragment());
+    mVertex.set(con->getProgramVertex());
+    mStore.set(con->getProgramStore());
+    mRaster.set(con->getProgramRaster());
+    mFont.set(con->getFont());  // ~PushState() re-binds mFont, so it must be saved too
+}
+Context::PushState::~PushState() {  // re-bind the state saved by the constructor
+    if (!mRsc->mIsGraphicsContext) return;  // setProgram*() rsAssert(mIsGraphicsContext)
+    mRsc->setProgramFragment(mFragment.get());
+    mRsc->setProgramVertex(mVertex.get());
+    mRsc->setProgramStore(mStore.get());
+    mRsc->setProgramRaster(mRaster.get());
+    mRsc->setFont(mFont.get());
+}
 
 uint32_t Context::runScript(Script *s) {
-    ObjectBaseRef<ProgramFragment> frag(mFragment);
-    ObjectBaseRef<ProgramVertex> vtx(mVertex);
-    ObjectBaseRef<ProgramStore> store(mFragmentStore);
-    ObjectBaseRef<ProgramRaster> raster(mRaster);
-    ObjectBaseRef<Font> font(mFont);
+    PushState ps(this);  // named local: a temporary would restore state before s->run()
 
     uint32_t ret = s->run(this);
-
-    mFragment.set(frag);
-    mVertex.set(vtx);
-    mFragmentStore.set(store);
-    mRaster.set(raster);
-    mFont.set(font);
     return ret;
 }
 
@@ -441,13 +447,13 @@
      rsc->mScriptC.init(rsc);
      if (rsc->mIsGraphicsContext) {
          rsc->mStateRaster.init(rsc);
-         rsc->setRaster(NULL);
+         rsc->setProgramRaster(NULL);
          rsc->mStateVertex.init(rsc);
-         rsc->setVertex(NULL);
+         rsc->setProgramVertex(NULL);
          rsc->mStateFragment.init(rsc);
-         rsc->setFragment(NULL);
+         rsc->setProgramFragment(NULL);
          rsc->mStateFragmentStore.init(rsc);
-         rsc->setFragmentStore(NULL);
+         rsc->setProgramStore(NULL);
          rsc->mStateFont.init(rsc);
          rsc->setFont(NULL);
          rsc->mStateVertexArray.init(rsc);
@@ -753,7 +759,7 @@
     mRootScript.set(s);
 }
 
-void Context::setFragmentStore(ProgramStore *pfs) {
+void Context::setProgramStore(ProgramStore *pfs) {
     rsAssert(mIsGraphicsContext);
     if (pfs == NULL) {
         mFragmentStore.set(mStateFragmentStore.mDefault);
@@ -762,7 +768,7 @@
     }
 }
 
-void Context::setFragment(ProgramFragment *pf) {
+void Context::setProgramFragment(ProgramFragment *pf) {
     rsAssert(mIsGraphicsContext);
     if (pf == NULL) {
         mFragment.set(mStateFragment.mDefault);
@@ -771,7 +777,7 @@
     }
 }
 
-void Context::setRaster(ProgramRaster *pr) {
+void Context::setProgramRaster(ProgramRaster *pr) {
     rsAssert(mIsGraphicsContext);
     if (pr == NULL) {
         mRaster.set(mStateRaster.mDefault);
@@ -780,7 +786,7 @@
     }
 }
 
-void Context::setVertex(ProgramVertex *pv) {
+void Context::setProgramVertex(ProgramVertex *pv) {
     rsAssert(mIsGraphicsContext);
     if (pv == NULL) {
         mVertex.set(mStateVertex.mDefault);
@@ -951,22 +957,22 @@
 
 void rsi_ContextBindProgramStore(Context *rsc, RsProgramStore vpfs) {
     ProgramStore *pfs = static_cast<ProgramStore *>(vpfs);
-    rsc->setFragmentStore(pfs);
+    rsc->setProgramStore(pfs);
 }
 
 void rsi_ContextBindProgramFragment(Context *rsc, RsProgramFragment vpf) {
     ProgramFragment *pf = static_cast<ProgramFragment *>(vpf);
-    rsc->setFragment(pf);
+    rsc->setProgramFragment(pf);
 }
 
 void rsi_ContextBindProgramRaster(Context *rsc, RsProgramRaster vpr) {
     ProgramRaster *pr = static_cast<ProgramRaster *>(vpr);
-    rsc->setRaster(pr);
+    rsc->setProgramRaster(pr);
 }
 
 void rsi_ContextBindProgramVertex(Context *rsc, RsProgramVertex vpv) {
     ProgramVertex *pv = static_cast<ProgramVertex *>(vpv);
-    rsc->setVertex(pv);
+    rsc->setProgramVertex(pv);
 }
 
 void rsi_ContextBindFont(Context *rsc, RsFont vfont) {
diff --git a/libs/rs/rsContext.h b/libs/rs/rsContext.h
index 6945342d..c377c73 100644
--- a/libs/rs/rsContext.h
+++ b/libs/rs/rsContext.h
@@ -80,6 +80,30 @@
         Context * mContext;
         Script * mScript;
     };
+
+    // Scoped save/restore of the context's bound program state: the
+    // destructor re-binds the fragment, vertex, store and raster programs
+    // and the font held by this object. Use it as a named local, e.g.
+    // "Context::PushState ps(rsc);". An unnamed temporary is destroyed at
+    // the end of its own statement, so it would restore state immediately.
+    class PushState {
+    public:
+        explicit PushState(Context *);
+        ~PushState();
+
+    private:
+        // Not copyable: every copy's destructor would re-bind the state again.
+        PushState(const PushState &);
+        PushState &operator=(const PushState &);
+
+        ObjectBaseRef<ProgramFragment> mFragment;
+        ObjectBaseRef<ProgramVertex> mVertex;
+        ObjectBaseRef<ProgramStore> mStore;
+        ObjectBaseRef<ProgramRaster> mRaster;
+        ObjectBaseRef<Font> mFont;
+        Context *mRsc;
+    };
+
     ScriptTLSStruct *mTlsStruct;
     RsSurfaceConfig mUserSurfaceConfig;
 
@@ -101,18 +125,18 @@
 
     void swapBuffers();
     void setRootScript(Script *);
-    void setRaster(ProgramRaster *);
-    void setVertex(ProgramVertex *);
-    void setFragment(ProgramFragment *);
-    void setFragmentStore(ProgramStore *);
+    void setProgramRaster(ProgramRaster *);
+    void setProgramVertex(ProgramVertex *);
+    void setProgramFragment(ProgramFragment *);
+    void setProgramStore(ProgramStore *);
     void setFont(Font *);
 
     void updateSurface(void *sur);
 
-    const ProgramFragment * getFragment() {return mFragment.get();}
-    const ProgramStore * getFragmentStore() {return mFragmentStore.get();}
-    const ProgramRaster * getRaster() {return mRaster.get();}
-    const ProgramVertex * getVertex() {return mVertex.get();}
+    ProgramFragment * getProgramFragment() {return mFragment.get();}
+    ProgramStore * getProgramStore() {return mFragmentStore.get();}
+    ProgramRaster * getProgramRaster() {return mRaster.get();}
+    ProgramVertex * getProgramVertex() {return mVertex.get();}
     Font * getFont() {return mFont.get();}
 
     bool setupCheck();
diff --git a/libs/rs/rsFont.cpp b/libs/rs/rsFont.cpp
index 96e350d..e4d77b2 100644
--- a/libs/rs/rsFont.cpp
+++ b/libs/rs/rsFont.cpp
@@ -613,18 +613,12 @@
 }
 
 void FontState::issueDrawCommand() {
+    Context::PushState ps(mRSC);
 
-    ObjectBaseRef<const ProgramVertex> tmpV(mRSC->getVertex());
-    mRSC->setVertex(mRSC->getDefaultProgramVertex());
-
-    ObjectBaseRef<const ProgramRaster> tmpR(mRSC->getRaster());
-    mRSC->setRaster(mRSC->getDefaultProgramRaster());
-
-    ObjectBaseRef<const ProgramFragment> tmpF(mRSC->getFragment());
-    mRSC->setFragment(mFontShaderF.get());
-
-    ObjectBaseRef<const ProgramStore> tmpPS(mRSC->getFragmentStore());
-    mRSC->setFragmentStore(mFontProgramStore.get());
+    mRSC->setProgramVertex(mRSC->getDefaultProgramVertex());
+    mRSC->setProgramRaster(mRSC->getDefaultProgramRaster());
+    mRSC->setProgramFragment(mFontShaderF.get());
+    mRSC->setProgramStore(mFontProgramStore.get());
 
     if (mConstantsDirty) {
         mFontShaderFConstant->data(mRSC, &mConstants, sizeof(mConstants));
@@ -632,10 +626,6 @@
     }
 
     if (!mRSC->setupCheck()) {
-        mRSC->setVertex((ProgramVertex *)tmpV.get());
-        mRSC->setRaster((ProgramRaster *)tmpR.get());
-        mRSC->setFragment((ProgramFragment *)tmpF.get());
-        mRSC->setFragmentStore((ProgramStore *)tmpPS.get());
         return;
     }
 
@@ -651,12 +641,6 @@
     mIndexBuffer->uploadCheck(mRSC);
     glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, mIndexBuffer->getBufferObjectID());
     glDrawElements(GL_TRIANGLES, mCurrentQuadIndex * 6, GL_UNSIGNED_SHORT, (uint16_t *)(0));
-
-    // Reset the state
-    mRSC->setVertex((ProgramVertex *)tmpV.get());
-    mRSC->setRaster((ProgramRaster *)tmpR.get());
-    mRSC->setFragment((ProgramFragment *)tmpF.get());
-    mRSC->setFragmentStore((ProgramStore *)tmpPS.get());
 }
 
 void FontState::appendMeshQuad(float x1, float y1, float z1,
diff --git a/libs/rs/rsScriptC.cpp b/libs/rs/rsScriptC.cpp
index 072cc168..ec7780e 100644
--- a/libs/rs/rsScriptC.cpp
+++ b/libs/rs/rsScriptC.cpp
@@ -104,16 +104,16 @@
 
 void ScriptC::setupGLState(Context *rsc) {
     if (mEnviroment.mFragmentStore.get()) {
-        rsc->setFragmentStore(mEnviroment.mFragmentStore.get());
+        rsc->setProgramStore(mEnviroment.mFragmentStore.get());
     }
     if (mEnviroment.mFragment.get()) {
-        rsc->setFragment(mEnviroment.mFragment.get());
+        rsc->setProgramFragment(mEnviroment.mFragment.get());
     }
     if (mEnviroment.mVertex.get()) {
-        rsc->setVertex(mEnviroment.mVertex.get());
+        rsc->setProgramVertex(mEnviroment.mVertex.get());
     }
     if (mEnviroment.mRaster.get()) {
-        rsc->setRaster(mEnviroment.mRaster.get());
+        rsc->setProgramRaster(mEnviroment.mRaster.get());
     }
 }
 
@@ -232,6 +232,7 @@
                          const RsScriptCall *sc) {
     MTLaunchStruct mtls;
     memset(&mtls, 0, sizeof(mtls));
+    Context::PushState ps(rsc);
 
     if (ain) {
         mtls.dimX = ain->getType()->getDimX();
diff --git a/libs/rs/rsScriptC_LibGL.cpp b/libs/rs/rsScriptC_LibGL.cpp
index ef1475c..0f84e4b 100644
--- a/libs/rs/rsScriptC_LibGL.cpp
+++ b/libs/rs/rsScriptC_LibGL.cpp
@@ -92,17 +92,17 @@
 
 static void SC_vpLoadProjectionMatrix(const rsc_Matrix *m) {
     GET_TLS();
-    rsc->getVertex()->setProjectionMatrix(rsc, m);
+    rsc->getProgramVertex()->setProjectionMatrix(rsc, m);
 }
 
 static void SC_vpLoadModelMatrix(const rsc_Matrix *m) {
     GET_TLS();
-    rsc->getVertex()->setModelviewMatrix(rsc, m);
+    rsc->getProgramVertex()->setModelviewMatrix(rsc, m);
 }
 
 static void SC_vpLoadTextureMatrix(const rsc_Matrix *m) {
     GET_TLS();
-    rsc->getVertex()->setTextureMatrix(rsc, m);
+    rsc->getProgramVertex()->setTextureMatrix(rsc, m);
 }
 
 static void SC_pfConstantColor(RsProgramFragment vpf, float r, float g, float b, float a) {
@@ -114,7 +114,7 @@
 
 static void SC_vpGetProjectionMatrix(rsc_Matrix *m) {
     GET_TLS();
-    rsc->getVertex()->getProjectionMatrix(rsc, m);
+    rsc->getProgramVertex()->getProjectionMatrix(rsc, m);
 }
 
 //////////////////////////////////////////////////////////////////////////////
@@ -165,8 +165,8 @@
 
 static void SC_drawSpriteScreenspace(float x, float y, float z, float w, float h) {
     GET_TLS();
-    ObjectBaseRef<const ProgramVertex> tmp(rsc->getVertex());
-    rsc->setVertex(rsc->getDefaultProgramVertex());
+    ObjectBaseRef<ProgramVertex> tmp(rsc->getProgramVertex());  // re-bound after the quad is drawn
+    rsc->setProgramVertex(rsc->getDefaultProgramVertex());
     //rsc->setupCheck();
 
     //GLint crop[4] = {0, h, w, -h};
@@ -177,7 +177,7 @@
                 x+w, sh - y,     z,
                 x+w, sh - (y+h), z,
                 x,   sh - (y+h), z);
-    rsc->setVertex((ProgramVertex *)tmp.get());
+    rsc->setProgramVertex(tmp.get());
 }
 /*
 static void SC_drawSprite(float x, float y, float z, float w, float h)
@@ -271,7 +271,7 @@
 
 static void SC_color(float r, float g, float b, float a) {
     GET_TLS();
-    ProgramFragment *pf = (ProgramFragment *)rsc->getFragment();
+    ProgramFragment *pf = rsc->getProgramFragment();
     pf->setConstantColor(rsc, r, g, b, a);