Merge "lmkd: Add support for psi monitors"
diff --git a/lmkd/Android.bp b/lmkd/Android.bp
index f9ed57c..4c5ca03 100644
--- a/lmkd/Android.bp
+++ b/lmkd/Android.bp
@@ -6,6 +6,7 @@
         "libcutils",
         "liblog",
         "libprocessgroup",
+        "libpsi",
     ],
     static_libs: [
         "libstatslogc",
diff --git a/lmkd/lmkd.c b/lmkd/lmkd.c
index ca78c38..6cb059e 100644
--- a/lmkd/lmkd.c
+++ b/lmkd/lmkd.c
@@ -44,6 +44,7 @@
 #include <log/log.h>
 #include <log/log_event_list.h>
 #include <log/log_time.h>
+#include <psi/psi.h>
 #include <system/thread_defs.h>
 
 #ifdef LMKD_LOG_STATS
@@ -93,6 +94,7 @@
 #define TARGET_UPDATE_MIN_INTERVAL_MS 1000
 
 #define NS_PER_MS (NS_PER_SEC / MS_PER_SEC)
+#define US_PER_MS (US_PER_SEC / MS_PER_SEC)
 
 /* Defined as ProcessList.SYSTEM_ADJ in ProcessList.java */
 #define SYSTEM_ADJ (-900)
@@ -100,6 +102,18 @@
 #define STRINGIFY(x) STRINGIFY_INTERNAL(x)
 #define STRINGIFY_INTERNAL(x) #x
 
+/*
+ * PSI monitor tracking window size.
+ * The PSI monitor generates events at most once per tracking window,
+ * therefore we poll memory state for the duration of
+ * PSI_WINDOW_SIZE_MS after an event is received.
+ */
+#define PSI_WINDOW_SIZE_MS 1000
+/* Polling period after initial PSI signal */
+#define PSI_POLL_PERIOD_MS 200
+/* Poll for the duration of one window after initial PSI signal */
+#define PSI_POLL_COUNT (PSI_WINDOW_SIZE_MS / PSI_POLL_PERIOD_MS)
+
 #define min(a, b) (((a) < (b)) ? (a) : (b))
 
 #define FAIL_REPORT_RLIMIT_MS 1000
@@ -127,6 +141,11 @@
     int64_t max_nr_free_pages;
 } low_pressure_mem = { -1, -1 };
 
+struct psi_threshold {
+    enum psi_stall_type stall_type;
+    int threshold_ms;
+};
+
 static int level_oomadj[VMPRESS_LEVEL_COUNT];
 static int mpevfd[VMPRESS_LEVEL_COUNT] = { -1, -1, -1 };
 static bool debug_process_killing;
@@ -139,6 +158,12 @@
 static bool use_minfree_levels;
 static bool per_app_memcg;
 static int swap_free_low_percentage;
+static bool use_psi_monitors = false;
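+/* PSI stall thresholds per pressure level, indexed from low to critical */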
+static struct psi_threshold psi_thresholds[VMPRESS_LEVEL_COUNT] = {
+    { PSI_SOME, 70 },    /* 70ms out of 1sec for partial stall */
+    { PSI_SOME, 100 },   /* 100ms out of 1sec for partial stall */
+    { PSI_FULL, 70 },    /* 70ms out of 1sec for complete stall */
+};
 
 static android_log_context ctx;
 
@@ -1524,17 +1549,23 @@
         .fd = -1,
     };
 
-    /*
-     * Check all event counters from low to critical
-     * and upgrade to the highest priority one. By reading
-     * eventfd we also reset the event counters.
-     */
-    for (lvl = VMPRESS_LEVEL_LOW; lvl < VMPRESS_LEVEL_COUNT; lvl++) {
-        if (mpevfd[lvl] != -1 &&
-            TEMP_FAILURE_RETRY(read(mpevfd[lvl],
-                               &evcount, sizeof(evcount))) > 0 &&
-            evcount > 0 && lvl > level) {
-            level = lvl;
+    if (debug_process_killing) {
+        ALOGI("%s memory pressure event is triggered", level_name[level]);
+    }
+
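+    /*
+     * When PSI monitors are in use, the pressure level is supplied via
+     * the handler data set up in init_mp_psi(), so the eventfd-based
+     * level upgrade below applies only to vmpressure.
+     */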
+    if (!use_psi_monitors) {
+        /*
+         * Check all event counters from low to critical
+         * and upgrade to the highest priority one. By reading
+         * eventfd we also reset the event counters.
+         */
+        for (lvl = VMPRESS_LEVEL_LOW; lvl < VMPRESS_LEVEL_COUNT; lvl++) {
+            if (mpevfd[lvl] != -1 &&
+                TEMP_FAILURE_RETRY(read(mpevfd[lvl],
+                                   &evcount, sizeof(evcount))) > 0 &&
+                evcount > 0 && lvl > level) {
+                level = lvl;
+            }
         }
     }
 
@@ -1722,6 +1753,54 @@
     }
 }
 
+static bool init_mp_psi(enum vmpressure_level level) {
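+    /*
+     * libpsi takes the stall threshold and tracking window in
+     * microseconds, hence the ms-to-us conversions below. For the low
+     * level this presumably becomes a kernel PSI trigger of the form
+     * "some 70000 1000000": notify when tasks are partially stalled on
+     * memory for a total of 70ms within any 1sec window.
+     */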
+    int fd = init_psi_monitor(psi_thresholds[level].stall_type,
+        psi_thresholds[level].threshold_ms * US_PER_MS,
+        PSI_WINDOW_SIZE_MS * US_PER_MS);
+
+    if (fd < 0) {
+        return false;
+    }
+
+    vmpressure_hinfo[level].handler = mp_event_common;
+    vmpressure_hinfo[level].data = level;
+    if (register_psi_monitor(epollfd, fd, &vmpressure_hinfo[level]) < 0) {
+        destroy_psi_monitor(fd);
+        return false;
+    }
+    maxevents++;
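+    /* Store the PSI monitor fd in the slot used for vmpressure eventfds */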
+    mpevfd[level] = fd;
+
+    return true;
+}
+
+static void destroy_mp_psi(enum vmpressure_level level) {
+    int fd = mpevfd[level];
+
+    if (unregister_psi_monitor(epollfd, fd) < 0) {
+        ALOGE("Failed to unregister psi monitor for %s memory pressure; errno=%d",
+            level_name[level], errno);
+    }
+    destroy_psi_monitor(fd);
+    mpevfd[level] = -1;
+}
+
+static bool init_psi_monitors(void) {
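+    /*
+     * If a monitor fails to initialize, destroy the ones already
+     * created so the caller can cleanly fall back to vmpressure.
+     */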
+    if (!init_mp_psi(VMPRESS_LEVEL_LOW)) {
+        return false;
+    }
+    if (!init_mp_psi(VMPRESS_LEVEL_MEDIUM)) {
+        destroy_mp_psi(VMPRESS_LEVEL_LOW);
+        return false;
+    }
+    if (!init_mp_psi(VMPRESS_LEVEL_CRITICAL)) {
+        destroy_mp_psi(VMPRESS_LEVEL_MEDIUM);
+        destroy_mp_psi(VMPRESS_LEVEL_LOW);
+        return false;
+    }
+    return true;
+}
+
 static bool init_mp_common(enum vmpressure_level level) {
     int mpfd;
     int evfd;
@@ -1837,12 +1916,22 @@
     if (use_inkernel_interface) {
         ALOGI("Using in-kernel low memory killer interface");
     } else {
-        if (!init_mp_common(VMPRESS_LEVEL_LOW) ||
+        /* Try to use PSI monitors first if the kernel supports them */
+        use_psi_monitors = property_get_bool("ro.lmk.use_psi", true) &&
+            init_psi_monitors();
+        /* Fall back to vmpressure */
+        if (!use_psi_monitors &&
+            (!init_mp_common(VMPRESS_LEVEL_LOW) ||
             !init_mp_common(VMPRESS_LEVEL_MEDIUM) ||
-            !init_mp_common(VMPRESS_LEVEL_CRITICAL)) {
+            !init_mp_common(VMPRESS_LEVEL_CRITICAL))) {
             ALOGE("Kernel does not support memory pressure events or in-kernel low memory killer");
             return -1;
         }
+        if (use_psi_monitors) {
+            ALOGI("Using psi monitors for memory pressure detection");
+        } else {
+            ALOGI("Using vmpressure for memory pressure detection");
+        }
     }
 
     for (i = 0; i <= ADJTOSLOT(OOM_SCORE_ADJ_MAX); i++) {
@@ -1857,14 +1946,39 @@
 
 static void mainloop(void) {
     struct event_handler_info* handler_info;
+    struct event_handler_info* poll_handler = NULL;
+    struct timespec last_report_tm, curr_tm;
     struct epoll_event *evt;
+    long delay = -1;
+    int polling = 0;
 
     while (1) {
         struct epoll_event events[maxevents];
         int nevents;
         int i;
 
-        nevents = epoll_wait(epollfd, events, maxevents, -1);
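+        /*
+         * While polling is positive, wake up every PSI_POLL_PERIOD_MS
+         * and re-invoke the handler that armed the polling to re-check
+         * memory state.
+         */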
+        if (polling) {
+            /*
+             * Calculate the next timeout: the remainder of the polling
+             * period, or a full period if one has already elapsed since
+             * the last report.
+             */
+            clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
+            delay = get_time_diff_ms(&last_report_tm, &curr_tm);
+            delay = (delay < PSI_POLL_PERIOD_MS) ?
+                PSI_POLL_PERIOD_MS - delay : PSI_POLL_PERIOD_MS;
+
+            /* Wait for events until the next polling timeout */
+            nevents = epoll_wait(epollfd, events, maxevents, delay);
+
+            clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
+            if (get_time_diff_ms(&last_report_tm, &curr_tm) >= PSI_POLL_PERIOD_MS) {
+                polling--;
+                poll_handler->handler(poll_handler->data, 0);
+                last_report_tm = curr_tm;
+            }
+        } else {
+            /* Wait for events with no timeout */
+            nevents = epoll_wait(epollfd, events, maxevents, -1);
+        }
 
         if (nevents == -1) {
             if (errno == EINTR)
@@ -1899,6 +2013,17 @@
             if (evt->data.ptr) {
                 handler_info = (struct event_handler_info*)evt->data.ptr;
                 handler_info->handler(handler_info->data, evt->events);
+
+                if (use_psi_monitors && handler_info->handler == mp_event_common) {
+                    /*
+                     * Poll for the duration of PSI_WINDOW_SIZE_MS after
+                     * the initial PSI event, because PSI events are
+                     * rate-limited to one per tracking window.
+                     */
+                    polling = PSI_POLL_COUNT;
+                    poll_handler = handler_info;
+                    clock_gettime(CLOCK_MONOTONIC_COARSE, &last_report_tm);
+                }
             }
         }
     }