lmkd: Add support for process death notifications

With pidfd polling support lmkd can detect process death without periodic
polling. Implement mechanism to detect kernel pidfd support using
pidfd_open syscall existence as an indicator. Implement the logic to use
pidfd to wait for process death.

Bug: 135608568
Test: lmkd_unit_test with and without pidfd kernel support
Change-Id: Ic6db7e50893534467f5130a7f998b66fb4451272
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
diff --git a/lmkd/lmkd.c b/lmkd/lmkd.c
index 18cd9f5..a5411d8 100644
--- a/lmkd/lmkd.c
+++ b/lmkd/lmkd.c
@@ -31,6 +31,7 @@
 #include <sys/mman.h>
 #include <sys/resource.h>
 #include <sys/socket.h>
+#include <sys/syscall.h>
 #include <sys/sysinfo.h>
 #include <sys/time.h>
 #include <sys/types.h>
@@ -139,6 +140,10 @@
 /* ro.lmk.psi_complete_stall_ms property defaults */
 #define DEF_COMPLETE_STALL 700
 
+static inline int sys_pidfd_open(pid_t pid, unsigned int flags) {
+    return syscall(__NR_pidfd_open, pid, flags);
+}
+
 /* default to old in-kernel interface if no memory pressure events */
 static bool use_inkernel_interface = true;
 static bool has_inkernel_module;
@@ -169,6 +174,11 @@
 
 static int level_oomadj[VMPRESS_LEVEL_COUNT];
 static int mpevfd[VMPRESS_LEVEL_COUNT] = { -1, -1, -1 };
+static bool pidfd_supported;
+static int last_kill_pid_or_fd = -1;
+static struct timespec last_kill_tm;
+
+/* lmkd configurable parameters */
 static bool debug_process_killing;
 static bool enable_pressure_upgrade;
 static int64_t upgrade_pressure;
@@ -197,6 +207,8 @@
     POLLING_DO_NOT_CHANGE,
     POLLING_START,
     POLLING_STOP,
+    POLLING_PAUSE,
+    POLLING_RESUME,
 };
 
 /*
@@ -207,6 +219,7 @@
  */
 struct polling_params {
     struct event_handler_info* poll_handler;
+    struct event_handler_info* paused_handler;
     struct timespec poll_start_tm;
     struct timespec last_poll_tm;
     int polling_interval_ms;
@@ -235,8 +248,11 @@
 /* vmpressure event handler data */
 static struct event_handler_info vmpressure_hinfo[VMPRESS_LEVEL_COUNT];
 
-/* 3 memory pressure levels, 1 ctrl listen socket, 2 ctrl data socket, 1 lmk events */
-#define MAX_EPOLL_EVENTS (2 + MAX_DATA_CONN + VMPRESS_LEVEL_COUNT)
+/*
+ * 1 ctrl listen socket, 2 ctrl data socket, 3 memory pressure levels,
+ * 1 lmk events + 1 fd to wait for process death
+ */
+#define MAX_EPOLL_EVENTS (1 + MAX_DATA_CONN + VMPRESS_LEVEL_COUNT + 1 + 1)
 static int epollfd;
 static int maxevents;
 
@@ -1647,11 +1663,112 @@
     closedir(d);
 }
 
-static int last_killed_pid = -1;
+static bool is_kill_pending(void) {
+    char buf[24];
+
+    if (last_kill_pid_or_fd < 0) {
+        return false;
+    }
+
+    if (pidfd_supported) {
+        return true;
+    }
+
+    /* when pidfd is not supported base the decision on /proc/<pid> existence */
+    snprintf(buf, sizeof(buf), "/proc/%d/", last_kill_pid_or_fd);
+    if (access(buf, F_OK) == 0) {
+        return true;
+    }
+
+    return false;
+}
+
+static bool is_waiting_for_kill(void) {
+    return pidfd_supported && last_kill_pid_or_fd >= 0;
+}
+
+static void stop_wait_for_proc_kill(bool finished) {
+    struct epoll_event epev;
+
+    if (last_kill_pid_or_fd < 0) {
+        return;
+    }
+
+    if (debug_process_killing) {
+        struct timespec curr_tm;
+
+        if (clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm) != 0) {
+            /*
+             * curr_tm is used here merely to report kill duration, so this failure is not fatal.
+             * Log an error and continue.
+             */
+            ALOGE("Failed to get current time");
+        }
+
+        if (finished) {
+            ALOGI("Process got killed in %ldms",
+                get_time_diff_ms(&last_kill_tm, &curr_tm));
+        } else {
+            ALOGI("Stop waiting for process kill after %ldms",
+                get_time_diff_ms(&last_kill_tm, &curr_tm));
+        }
+    }
+
+    if (pidfd_supported) {
+        /* unregister fd */
+        if (epoll_ctl(epollfd, EPOLL_CTL_DEL, last_kill_pid_or_fd, &epev) != 0) {
+            ALOGE("epoll_ctl for last killed process failed; errno=%d", errno);
+            return;
+        }
+        maxevents--;
+        close(last_kill_pid_or_fd);
+    }
+
+    last_kill_pid_or_fd = -1;
+}
+
+static void kill_done_handler(int data __unused, uint32_t events __unused,
+                              struct polling_params *poll_params) {
+    stop_wait_for_proc_kill(true);
+    poll_params->update = POLLING_RESUME;
+}
+
+static void start_wait_for_proc_kill(int pid) {
+    static struct event_handler_info kill_done_hinfo = { 0, kill_done_handler };
+    struct epoll_event epev;
+
+    if (last_kill_pid_or_fd >= 0) {
+        /* Should not happen but if it does we should stop previous wait */
+        ALOGE("Attempt to wait for a kill while another wait is in progress");
+        stop_wait_for_proc_kill(false);
+    }
+
+    if (!pidfd_supported) {
+        /* If pidfd is not supported store PID of the process being killed */
+        last_kill_pid_or_fd = pid;
+        return;
+    }
+
+    last_kill_pid_or_fd = TEMP_FAILURE_RETRY(sys_pidfd_open(pid, 0));
+    if (last_kill_pid_or_fd < 0) {
+        ALOGE("pidfd_open for process pid %d failed; errno=%d", pid, errno);
+        return;
+    }
+
+    epev.events = EPOLLIN;
+    epev.data.ptr = (void *)&kill_done_hinfo;
+    if (epoll_ctl(epollfd, EPOLL_CTL_ADD, last_kill_pid_or_fd, &epev) != 0) {
+        ALOGE("epoll_ctl for last kill failed; errno=%d", errno);
+        close(last_kill_pid_or_fd);
+        last_kill_pid_or_fd = -1;
+        return;
+    }
+    maxevents++;
+}
 
 /* Kill one process specified by procp.  Returns the size of the process killed */
 static int kill_one_process(struct proc* procp, int min_oom_score, int kill_reason,
-                            const char *kill_desc, union meminfo *mi) {
+                            const char *kill_desc, union meminfo *mi, struct timespec *tm) {
     int pid = procp->pid;
     uid_t uid = procp->uid;
     int tgid;
@@ -1682,12 +1799,16 @@
 
     TRACE_KILL_START(pid);
 
+    /* Have to start waiting before sending SIGKILL to make sure pid is valid */
+    start_wait_for_proc_kill(pid);
+
     /* CAP_KILL required */
     r = kill(pid, SIGKILL);
 
     TRACE_KILL_END();
 
     if (r) {
+        stop_wait_for_proc_kill(false);
         ALOGE("kill(%d): errno=%d", pid, errno);
         /* Delete process record even when we fail to kill so that we don't get stuck on it */
         goto out;
@@ -1695,6 +1816,8 @@
 
     set_process_group_and_prio(pid, SP_FOREGROUND, ANDROID_PRIORITY_HIGHEST);
 
+    last_kill_tm = *tm;
+
     inc_killcnt(procp->oomadj);
 
     killinfo_log(procp, min_oom_score, tasksize, kill_reason, mi);
@@ -1707,8 +1830,6 @@
               uid, procp->oomadj, tasksize * page_k);
     }
 
-    last_killed_pid = pid;
-
     stats_write_lmk_kill_occurred(LMK_KILL_OCCURRED, uid, taskname,
             procp->oomadj, min_oom_score, tasksize, mem_st);
 
@@ -1728,7 +1849,7 @@
  * Returns size of the killed process.
  */
 static int find_and_kill_process(int min_score_adj, int kill_reason, const char *kill_desc,
-                                 union meminfo *mi) {
+                                 union meminfo *mi, struct timespec *tm) {
     int i;
     int killed_size = 0;
     bool lmk_state_change_start = false;
@@ -1743,7 +1864,7 @@
             if (!procp)
                 break;
 
-            killed_size = kill_one_process(procp, min_score_adj, kill_reason, kill_desc, mi);
+            killed_size = kill_one_process(procp, min_score_adj, kill_reason, kill_desc, mi, tm);
             if (killed_size >= 0) {
                 if (!lmk_state_change_start) {
                     lmk_state_change_start = true;
@@ -1822,23 +1943,6 @@
         level - 1 : level);
 }
 
-static bool is_kill_pending(void) {
-    char buf[24];
-
-    if (last_killed_pid < 0) {
-        return false;
-    }
-
-    snprintf(buf, sizeof(buf), "/proc/%d/", last_killed_pid);
-    if (access(buf, F_OK) == 0) {
-        return true;
-    }
-
-    // reset last killed PID because there's nothing pending
-    last_killed_pid = -1;
-    return false;
-}
-
 enum zone_watermark {
     WMARK_MIN = 0,
     WMARK_LOW,
@@ -1934,9 +2038,13 @@
 
     /* Skip while still killing a process */
     if (is_kill_pending()) {
-        /* TODO: replace this quick polling with pidfd polling if kernel supports */
         goto no_kill;
     }
+    /*
+     * Process is dead, stop waiting. This has no effect if pidfds are supported and
+     * death notification already caused waiting to stop.
+     */
+    stop_wait_for_proc_kill(true);
 
     if (clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm) != 0) {
         ALOGE("Failed to get current time");
@@ -2067,7 +2175,8 @@
 
     /* Kill a process if necessary */
     if (kill_reason != NONE) {
-        int pages_freed = find_and_kill_process(min_score_adj, kill_reason, kill_desc, &mi);
+        int pages_freed = find_and_kill_process(min_score_adj, kill_reason, kill_desc, &mi,
+                                                &curr_tm);
         if (pages_freed > 0) {
             killing = true;
             if (cut_thrashing_limit) {
@@ -2081,6 +2190,13 @@
     }
 
 no_kill:
+    /* Do not poll if kernel supports pidfd waiting */
+    if (is_waiting_for_kill()) {
+        /* Pause polling if we are waiting for process death notification */
+        poll_params->update = POLLING_PAUSE;
+        return;
+    }
+
     /*
      * Start polling after initial PSI event;
      * extend polling while device is in direct reclaim or process is being killed;
@@ -2110,7 +2226,6 @@
     union meminfo mi;
     struct zoneinfo zi;
     struct timespec curr_tm;
-    static struct timespec last_kill_tm;
     static unsigned long kill_skip_count = 0;
     enum vmpressure_level level = (enum vmpressure_level)data;
     long other_free = 0, other_file = 0;
@@ -2159,15 +2274,26 @@
         return;
     }
 
-    if (kill_timeout_ms) {
-        // If we're within the timeout, see if there's pending reclaim work
-        // from the last killed process. If there is (as evidenced by
-        // /proc/<pid> continuing to exist), skip killing for now.
-        if ((get_time_diff_ms(&last_kill_tm, &curr_tm) < kill_timeout_ms) &&
-            (low_ram_device || is_kill_pending())) {
+    if (kill_timeout_ms && get_time_diff_ms(&last_kill_tm, &curr_tm) < kill_timeout_ms) {
+        /*
+         * If we're within the no-kill timeout, see if there's pending reclaim work
+         * from the last killed process. If so, skip killing for now.
+         */
+        if (is_kill_pending()) {
             kill_skip_count++;
             return;
         }
+        /*
+         * Process is dead, stop waiting. This has no effect if pidfds are supported and
+         * death notification already caused waiting to stop.
+         */
+        stop_wait_for_proc_kill(true);
+    } else {
+        /*
+         * Killing took longer than no-kill timeout. Stop waiting for the last process
+         * to die because we are ready to kill again.
+         */
+        stop_wait_for_proc_kill(false);
     }
 
     if (kill_skip_count > 0) {
@@ -2266,7 +2392,7 @@
 do_kill:
     if (low_ram_device) {
         /* For Go devices kill only one task */
-        if (find_and_kill_process(level_oomadj[level], -1, NULL, &mi) == 0) {
+        if (find_and_kill_process(level_oomadj[level], -1, NULL, &mi, &curr_tm) == 0) {
             if (debug_process_killing) {
                 ALOGI("Nothing to kill");
             }
@@ -2289,7 +2415,7 @@
             min_score_adj = level_oomadj[level];
         }
 
-        pages_freed = find_and_kill_process(min_score_adj, -1, NULL, &mi);
+        pages_freed = find_and_kill_process(min_score_adj, -1, NULL, &mi, &curr_tm);
 
         if (pages_freed == 0) {
             /* Rate limit kill reports when nothing was reclaimed */
@@ -2297,9 +2423,6 @@
                 report_skip_count++;
                 return;
             }
-        } else {
-            /* If we killed anything, update the last killed timestamp. */
-            last_kill_tm = curr_tm;
         }
 
         /* Log whenever we kill or when report rate limit allows */
@@ -2322,6 +2445,10 @@
 
         last_report_tm = curr_tm;
     }
+    if (is_waiting_for_kill()) {
+        /* pause polling if we are waiting for process death notification */
+        poll_params->update = POLLING_PAUSE;
+    }
 }
 
 static bool init_mp_psi(enum vmpressure_level level, bool use_new_strategy) {
@@ -2473,6 +2600,7 @@
         .fd = -1,
     };
     struct epoll_event epev;
+    int pidfd;
     int i;
     int ret;
 
@@ -2563,9 +2691,61 @@
         ALOGE("Failed to read %s: %s", file_data.filename, strerror(errno));
     }
 
+    /* check if kernel supports pidfd_open syscall */
+    pidfd = TEMP_FAILURE_RETRY(sys_pidfd_open(getpid(), 0));
+    if (pidfd < 0) {
+        pidfd_supported = (errno != ENOSYS);
+    } else {
+        pidfd_supported = true;
+        close(pidfd);
+    }
+    ALOGI("Process polling is %s", pidfd_supported ? "supported" : "not supported" );
+
     return 0;
 }
 
+static void call_handler(struct event_handler_info* handler_info,
+                         struct polling_params *poll_params, uint32_t events) {
+    struct timespec curr_tm;
+
+    handler_info->handler(handler_info->data, events, poll_params);
+    clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
+    poll_params->last_poll_tm = curr_tm;
+
+    switch (poll_params->update) {
+    case POLLING_START:
+        /*
+         * Poll for the duration of PSI_WINDOW_SIZE_MS after the
+         * initial PSI event because psi events are rate-limited
+         * at one per sec.
+         */
+        poll_params->poll_start_tm = curr_tm;
+        if (poll_params->poll_handler != handler_info) {
+            poll_params->poll_handler = handler_info;
+        }
+        break;
+    case POLLING_STOP:
+        poll_params->poll_handler = NULL;
+        break;
+    case POLLING_PAUSE:
+        poll_params->paused_handler = handler_info;
+        poll_params->poll_handler = NULL;
+        break;
+    case POLLING_RESUME:
+        poll_params->poll_start_tm = curr_tm;
+        poll_params->poll_handler = poll_params->paused_handler;
+        break;
+    case POLLING_DO_NOT_CHANGE:
+        if (get_time_diff_ms(&poll_params->poll_start_tm, &curr_tm) > PSI_WINDOW_SIZE_MS) {
+            /* Polled for the duration of PSI window, time to stop */
+            poll_params->poll_handler = NULL;
+        }
+        /* WARNING: skipping the rest of the function */
+        return;
+    }
+    poll_params->update = POLLING_DO_NOT_CHANGE;
+}
+
 static void mainloop(void) {
     struct event_handler_info* handler_info;
     struct polling_params poll_params;
@@ -2582,41 +2762,33 @@
         int i;
 
         if (poll_params.poll_handler) {
-            /* Calculate next timeout */
-            clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
-            delay = get_time_diff_ms(&poll_params.last_poll_tm, &curr_tm);
-            delay = (delay < poll_params.polling_interval_ms) ?
-                poll_params.polling_interval_ms - delay : poll_params.polling_interval_ms;
-
-            /* Wait for events until the next polling timeout */
-            nevents = epoll_wait(epollfd, events, maxevents, delay);
+            bool poll_now;
 
             clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
-            if (get_time_diff_ms(&poll_params.last_poll_tm, &curr_tm) >=
-                poll_params.polling_interval_ms) {
-                /* Set input params for the call */
-                poll_params.poll_handler->handler(poll_params.poll_handler->data, 0, &poll_params);
-                poll_params.last_poll_tm = curr_tm;
+            if (poll_params.poll_handler == poll_params.paused_handler) {
+                /*
+                 * Just transitioned into POLLING_RESUME. Reset paused_handler
+                 * and poll immediately
+                 */
+                poll_params.paused_handler = NULL;
+                poll_now = true;
+                nevents = 0;
+            } else {
+                /* Calculate next timeout */
+                delay = get_time_diff_ms(&poll_params.last_poll_tm, &curr_tm);
+                delay = (delay < poll_params.polling_interval_ms) ?
+                    poll_params.polling_interval_ms - delay : poll_params.polling_interval_ms;
 
-                if (poll_params.update != POLLING_DO_NOT_CHANGE) {
-                    switch (poll_params.update) {
-                    case POLLING_START:
-                        poll_params.poll_start_tm = curr_tm;
-                        break;
-                    case POLLING_STOP:
-                        poll_params.poll_handler = NULL;
-                        break;
-                    default:
-                        break;
-                    }
-                    poll_params.update = POLLING_DO_NOT_CHANGE;
-                } else {
-                    if (get_time_diff_ms(&poll_params.poll_start_tm, &curr_tm) >
-                        PSI_WINDOW_SIZE_MS) {
-                        /* Polled for the duration of PSI window, time to stop */
-                        poll_params.poll_handler = NULL;
-                    }
-                }
+                /* Wait for events until the next polling timeout */
+                nevents = epoll_wait(epollfd, events, maxevents, delay);
+
+                /* Update current time after wait */
+                clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
+                poll_now = (get_time_diff_ms(&poll_params.last_poll_tm, &curr_tm) >=
+                    poll_params.polling_interval_ms);
+            }
+            if (poll_now) {
+                call_handler(poll_params.poll_handler, &poll_params, 0);
             }
         } else {
             /* Wait for events with no timeout */
@@ -2656,29 +2828,7 @@
             }
             if (evt->data.ptr) {
                 handler_info = (struct event_handler_info*)evt->data.ptr;
-                /* Set input params for the call */
-                handler_info->handler(handler_info->data, evt->events, &poll_params);
-
-                if (poll_params.update != POLLING_DO_NOT_CHANGE) {
-                    switch (poll_params.update) {
-                    case POLLING_START:
-                        /*
-                         * Poll for the duration of PSI_WINDOW_SIZE_MS after the
-                         * initial PSI event because psi events are rate-limited
-                         * at one per sec.
-                         */
-                        clock_gettime(CLOCK_MONOTONIC_COARSE, &curr_tm);
-                        poll_params.poll_start_tm = poll_params.last_poll_tm = curr_tm;
-                        poll_params.poll_handler = handler_info;
-                        break;
-                    case POLLING_STOP:
-                        poll_params.poll_handler = NULL;
-                        break;
-                    default:
-                        break;
-                    }
-                    poll_params.update = POLLING_DO_NOT_CHANGE;
-                }
+                call_handler(handler_info, &poll_params, evt->events);
             }
         }
     }