diff --git a/tests/cts/hostside/AndroidTest.xml b/tests/cts/hostside/AndroidTest.xml
index 7cc0dd1..b7fefaf 100644
--- a/tests/cts/hostside/AndroidTest.xml
+++ b/tests/cts/hostside/AndroidTest.xml
@@ -20,6 +20,7 @@
     <option name="config-descriptor:metadata" key="parameter" value="not_multi_abi" />
     <option name="config-descriptor:metadata" key="parameter" value="secondary_user" />
 
+    <target_preparer class="com.android.compatibility.common.tradefed.targetprep.LocationCheck" />
     <target_preparer class="com.android.cts.net.NetworkPolicyTestsPreparer" />
 
     <target_preparer class="com.android.tradefed.targetprep.RunCommandTargetPreparer">
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java
index 21212af..29ba68c 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/AbstractRestrictBackgroundNetworkTestCase.java
@@ -126,12 +126,10 @@
     protected Context mContext;
     protected Instrumentation mInstrumentation;
     protected ConnectivityManager mCm;
-    protected WifiManager mWfm;
     protected int mUid;
     private int mMyUid;
     private MyServiceClient mServiceClient;
     private String mDeviceIdleConstantsSetting;
-    private boolean mIsLocationOn;
 
     @Rule
     public final RuleChain mRuleChain = RuleChain.outerRule(new RequiredPropertiesRule())
@@ -144,16 +142,11 @@
         mInstrumentation = getInstrumentation();
         mContext = getContext();
         mCm = getConnectivityManager();
-        mWfm = getWifiManager();
         mUid = getUid(TEST_APP2_PKG);
         mMyUid = getUid(mContext.getPackageName());
         mServiceClient = new MyServiceClient(mContext);
         mServiceClient.bind();
         mDeviceIdleConstantsSetting = "device_idle_constants";
-        mIsLocationOn = isLocationOn();
-        if (!mIsLocationOn) {
-            enableLocation();
-        }
         executeShellCommand("cmd netpolicy start-watching " + mUid);
         setAppIdle(false);
 
@@ -164,33 +157,9 @@
 
     protected void tearDown() throws Exception {
         executeShellCommand("cmd netpolicy stop-watching");
-        if (!mIsLocationOn) {
-            disableLocation();
-        }
         mServiceClient.unbind();
     }
 
-    private void enableLocation() throws Exception {
-        Settings.Secure.putInt(mContext.getContentResolver(), Settings.Secure.LOCATION_MODE,
-                Settings.Secure.LOCATION_MODE_SENSORS_ONLY);
-        assertEquals(Settings.Secure.LOCATION_MODE_SENSORS_ONLY,
-                Settings.Secure.getInt(mContext.getContentResolver(),
-                        Settings.Secure.LOCATION_MODE));
-    }
-
-    private void disableLocation() throws Exception {
-        Settings.Secure.putInt(mContext.getContentResolver(), Settings.Secure.LOCATION_MODE,
-                Settings.Secure.LOCATION_MODE_OFF);
-        assertEquals(Settings.Secure.LOCATION_MODE_OFF,
-                Settings.Secure.getInt(mContext.getContentResolver(),
-                        Settings.Secure.LOCATION_MODE));
-    }
-
-    private boolean isLocationOn() throws Exception {
-        return Settings.Secure.getInt(mContext.getContentResolver(),
-                Settings.Secure.LOCATION_MODE) != Settings.Secure.LOCATION_MODE_OFF;
-    }
-
     protected int getUid(String packageName) throws Exception {
         return mContext.getPackageManager().getPackageUid(packageName, 0);
     }
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
index ca2864c..3807d79 100644
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/NetworkPolicyTestUtils.java
@@ -27,17 +27,20 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import android.app.ActivityManager;
 import android.app.Instrumentation;
 import android.content.Context;
+import android.location.LocationManager;
 import android.net.ConnectivityManager;
 import android.net.ConnectivityManager.NetworkCallback;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.wifi.WifiManager;
+import android.os.Process;
 import android.text.TextUtils;
 import android.util.Log;
 import android.util.Pair;
@@ -113,6 +116,20 @@
         return am.isLowRamDevice();
     }
 
+    public static boolean isLocationEnabled() {
+        final LocationManager lm = (LocationManager) getContext().getSystemService(
+                Context.LOCATION_SERVICE);
+        return lm.isLocationEnabled();
+    }
+
+    public static void setLocationEnabled(boolean enabled) {
+        final LocationManager lm = (LocationManager) getContext().getSystemService(
+                Context.LOCATION_SERVICE);
+        lm.setLocationEnabledForUser(enabled, Process.myUserHandle());
+        assertEquals("Couldn't change location enabled state", lm.isLocationEnabled(), enabled);
+        Log.d(TAG, "Changed location enabled state to " + enabled);
+    }
+
     public static boolean isActiveNetworkMetered(boolean metered) {
         return getConnectivityManager().isActiveNetworkMetered() == metered;
     }
@@ -128,9 +145,21 @@
         if (isActiveNetworkMetered(metered)) {
             return null;
         }
-        final String ssid = unquoteSSID(getWifiManager().getConnectionInfo().getSSID());
-        setWifiMeteredStatus(ssid, metered);
-        return Pair.create(ssid, !metered);
+        final boolean isLocationEnabled = isLocationEnabled();
+        try {
+            if (!isLocationEnabled) {
+                setLocationEnabled(true);
+            }
+            final String ssid = unquoteSSID(getWifiManager().getConnectionInfo().getSSID());
+            assertNotEquals(WifiManager.UNKNOWN_SSID, ssid);
+            setWifiMeteredStatus(ssid, metered);
+            return Pair.create(ssid, !metered);
+        } finally {
+            // Reset the location enabled state
+            if (!isLocationEnabled) {
+                setLocationEnabled(false);
+            }
+        }
     }
 
     public static void resetMeteredNetwork(String ssid, boolean metered) throws Exception {
