#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <unistd.h>
#include <sched.h>
#include <signal.h>
#include <assert.h>
#include <libgen.h>
#include <pthread.h>
#include <bpf/bpf.h>
#include <sys/mman.h>
#include <sys/queue.h>
#include <sys/syscall.h>
#include <scx/common.h>
#include "scx_userland.h"
#include "scx_userland.bpf.skel.h"
/*
 * Usage text printed to stderr for -h or an unrecognized option. Used as a
 * printf format string: the single %s receives the program's basename.
 */
const char help_fmt[] =
"A minimal userland sched_ext scheduler.\n"
"\n"
"See the top-level comment in .bpf.c for more details.\n"
"\n"
"Try to reduce `sysctl kernel.pid_max` if this program triggers OOMs.\n"
"\n"
"Usage: %s [-b BATCH] [-v] [-h]\n"
"\n"
" -b BATCH The number of tasks to batch when dispatching (default: 8)\n"
" -v Print libbpf debug messages\n"
" -h Display this help and exit\n";
/* Fallback in case the system headers do not provide the SCHED_EXT policy. */
#ifndef SCHED_EXT
#define SCHED_EXT 7
#endif

/* Maximum number of tasks dispatched per dispatch_batch() call (-b). */
static __u32 batch_size = 8;
/* When set (-v), libbpf_print_fn() also forwards LIBBPF_DEBUG messages. */
static bool verbose;
/*
 * Stop flag polled by the main loop and the stats thread. It is written from
 * the SIGINT/SIGTERM handler, so it is a volatile sig_atomic_t.
 */
static volatile sig_atomic_t exit_req;
/* File descriptors of the BPF "enqueued" and "dispatched" maps. */
static int enqueued_fd, dispatched_fd;
static pthread_t stats_printer;
static struct scx_userland *skel;
static struct bpf_link *ops_link;
/* User-space statistics, always accessed through __atomic builtins. */
static __u64 nr_vruntime_enqueues, nr_vruntime_dispatches, nr_vruntime_failed;
/* Number of tasks currently linked into vruntime_head. */
static __u64 nr_curr_enqueued;

/* Per-PID scheduling state; an entry's index in tasks[] is its PID. */
struct enqueued_task {
	LIST_ENTRY(enqueued_task) entries;
	__u64 sum_exec_runtime;
	double vruntime;
};

LIST_HEAD(listhead, enqueued_task);
/* Queue of tasks waiting to be dispatched, kept sorted by ascending vruntime. */
static struct listhead vruntime_head = LIST_HEAD_INITIALIZER(vruntime_head);

/* Array of pid_max entries indexed by PID (see task_pid()); file-private. */
static struct enqueued_task *tasks;
static int pid_max;
/* vruntime of the most recently dispatched task; lower bound for vruntimes. */
static double min_vruntime;
/*
 * libbpf logging callback. Debug-level messages are dropped unless -v was
 * given; everything else is forwarded to stderr.
 */
static int libbpf_print_fn(enum libbpf_print_level level, const char *format, va_list args)
{
	bool suppress = (level == LIBBPF_DEBUG) && !verbose;

	return suppress ? 0 : vfprintf(stderr, format, args);
}
/*
 * SIGINT/SIGTERM handler: raise exit_req. The main loop and the stats
 * thread poll it and wind down.
 */
static void sigint_handler(int signo)
{
	(void)signo;	/* same action for every signal */
	exit_req = 1;
}
/*
 * Read the kernel's PID limit from /proc/sys/kernel/pid_max.
 *
 * Returns the value (always > 0) on success. Returns -1, after printing a
 * message to stderr, if the file cannot be opened or parsed, or if it holds a
 * non-positive value.
 */
static int get_pid_max(void)
{
	FILE *fp;
	int val;	/* not "pid_max": that would shadow the file-scope global */

	fp = fopen("/proc/sys/kernel/pid_max", "r");
	if (fp == NULL) {
		fprintf(stderr, "Error opening /proc/sys/kernel/pid_max\n");
		return -1;
	}
	if (fscanf(fp, "%d", &val) != 1) {
		fprintf(stderr, "Error reading from /proc/sys/kernel/pid_max\n");
		fclose(fp);
		return -1;
	}
	fclose(fp);

	/* The result sizes the tasks[] array, so it must be positive. */
	if (val <= 0) {
		fprintf(stderr, "Invalid value in /proc/sys/kernel/pid_max: %d\n", val);
		return -1;
	}
	return val;
}
/*
 * Size and allocate the zeroed per-PID tasks[] array from the kernel's
 * pid_max. Stores the limit in the global pid_max.
 *
 * Returns 0 on success. Returns the negative value from get_pid_max() if the
 * limit could not be read, or -ENOMEM if the allocation fails.
 */
static int init_tasks(void)
{
	pid_max = get_pid_max();
	if (pid_max < 0)
		return pid_max;

	tasks = calloc(pid_max, sizeof(*tasks));
	if (tasks != NULL)
		return 0;

	fprintf(stderr, "Error allocating tasks array\n");
	return -ENOMEM;
}
static __u32 task_pid(const struct enqueued_task *task)
{
return ((uintptr_t)task - (uintptr_t)tasks) / sizeof(*task);
}
/*
 * Push a PID into the BPF "dispatched" map and count the outcome in either
 * nr_vruntime_dispatches or nr_vruntime_failed.
 *
 * Returns the bpf_map_update_elem() result (0 on success).
 */
static int dispatch_task(__s32 pid)
{
	int err = bpf_map_update_elem(dispatched_fd, NULL, &pid, 0);
	__u64 *counter = err ? &nr_vruntime_failed : &nr_vruntime_dispatches;

	__atomic_add_fetch(counter, 1, __ATOMIC_RELAXED);
	return err;
}
/*
 * Return the tasks[] slot for a PID, or NULL if the PID is outside
 * [0, pid_max). The PID comes from the BPF side as a signed value, so the
 * negative case has to be rejected too; otherwise tasks[pid] would read
 * before the start of the array.
 */
static struct enqueued_task *get_enqueued_task(__s32 pid)
{
	if (pid < 0 || pid >= pid_max)
		return NULL;
	return &tasks[pid];
}
/*
 * Scale a runtime delta by the task's weight: delta / (weight / 100). A
 * weight of 100 leaves the delta unchanged; larger weights shrink it.
 */
static double calc_vruntime_delta(__u64 weight, __u64 delta)
{
	return (double)delta / ((double)weight / 100.0);
}
/*
 * Refresh a task's user-space state from the sample reported by BPF.
 *
 * Charges the runtime consumed since the last sample to vruntime (weighted
 * by bpf_task->weight). Clamps the result so it is never below min_vruntime,
 * then remembers the new sum_exec_runtime for the next delta.
 */
static void update_enqueued(struct enqueued_task *enqueued, const struct scx_userland_enqueued_task *bpf_task)
{
	__u64 delta = 0;

	/*
	 * Only charge runtime when the counter moved forward. tasks[] slots are
	 * keyed by PID and never cleared between tasks, so a smaller value than
	 * the one recorded presumably means the PID was reused by a new task.
	 * An unsigned subtraction there would wrap to an enormous delta and
	 * push that task to the back of the queue.
	 */
	if (bpf_task->sum_exec_runtime >= enqueued->sum_exec_runtime)
		delta = bpf_task->sum_exec_runtime - enqueued->sum_exec_runtime;

	enqueued->vruntime += calc_vruntime_delta(bpf_task->weight, delta);
	/* min_vruntime is the vruntime of the last dispatched task. */
	if (min_vruntime > enqueued->vruntime)
		enqueued->vruntime = min_vruntime;
	enqueued->sum_exec_runtime = bpf_task->sum_exec_runtime;
}
/*
 * Accept a task reported by BPF into the user-space run queue.
 *
 * Updates the task's vruntime, bumps the enqueue counters, and links it into
 * vruntime_head in front of the first entry whose vruntime is >= its own.
 * A task therefore goes ahead of existing entries with an equal vruntime.
 *
 * Returns 0 on success, or ENOENT (positive) when get_enqueued_task() has no
 * slot for the PID.
 */
static int vruntime_enqueue(const struct scx_userland_enqueued_task *bpf_task)
{
	struct enqueued_task *task, *pos, *last = NULL;

	task = get_enqueued_task(bpf_task->pid);
	if (task == NULL)
		return ENOENT;

	update_enqueued(task, bpf_task);
	__atomic_add_fetch(&nr_vruntime_enqueues, 1, __ATOMIC_RELAXED);
	__atomic_add_fetch(&nr_curr_enqueued, 1, __ATOMIC_RELAXED);

	LIST_FOREACH(pos, &vruntime_head, entries) {
		if (task->vruntime <= pos->vruntime) {
			LIST_INSERT_BEFORE(pos, task, entries);
			return 0;
		}
		last = pos;
	}

	/* Either the queue was empty or every entry sorts before this task. */
	if (last == NULL)
		LIST_INSERT_HEAD(&vruntime_head, task, entries);
	else
		LIST_INSERT_AFTER(last, task, entries);
	return 0;
}
/*
 * Move every task that BPF has queued in the "enqueued" map into the
 * user-space vruntime queue.
 *
 * When the map yields no more entries, resets skel->bss->nr_queued to 0 and
 * publishes the current user-space queue length in skel->bss->nr_scheduled.
 * If a task cannot be enqueued, prints an error and sets exit_req.
 */
static void drain_enqueued_map(void)
{
	for (;;) {
		struct scx_userland_enqueued_task task;
		int err;

		/* Any nonzero lookup result (e.g. empty map) ends the drain. */
		if (bpf_map_lookup_and_delete_elem(enqueued_fd, NULL, &task)) {
			skel->bss->nr_queued = 0;
			/* Same atomic accessor as every other nr_curr_enqueued use. */
			skel->bss->nr_scheduled =
				__atomic_load_n(&nr_curr_enqueued, __ATOMIC_RELAXED);
			return;
		}

		err = vruntime_enqueue(&task);
		if (err) {
			fprintf(stderr, "Failed to enqueue task %d: %s\n",
				task.pid, strerror(err));
			exit_req = 1;
			return;
		}
	}
}
/*
 * Dispatch up to batch_size tasks to BPF, lowest vruntime first.
 *
 * Stops early when the queue runs empty or a dispatch fails. A task whose
 * dispatch failed is put back at the head of the queue. Afterwards the
 * remaining queue length is published in skel->bss->nr_scheduled.
 */
static void dispatch_batch(void)
{
	__u32 n;

	for (n = 0; n < batch_size; n++) {
		struct enqueued_task *head = LIST_FIRST(&vruntime_head);
		__s32 pid;

		if (head == NULL)
			break;

		/* The head holds the smallest vruntime currently queued. */
		min_vruntime = head->vruntime;
		pid = task_pid(head);
		LIST_REMOVE(head, entries);

		if (dispatch_task(pid) != 0) {
			LIST_INSERT_HEAD(&vruntime_head, head, entries);
			break;
		}
		__atomic_sub_fetch(&nr_curr_enqueued, 1, __ATOMIC_RELAXED);
	}

	skel->bss->nr_scheduled = __atomic_load_n(&nr_curr_enqueued, __ATOMIC_RELAXED);
}
static void *run_stats_printer(void *arg)
{
while (!exit_req) {
__u64 nr_failed_enqueues, nr_kernel_enqueues, nr_user_enqueues, total;
nr_failed_enqueues = skel->bss->nr_failed_enqueues;
nr_kernel_enqueues = skel->bss->nr_kernel_enqueues;
nr_user_enqueues = skel->bss->nr_user_enqueues;
total = nr_failed_enqueues + nr_kernel_enqueues + nr_user_enqueues;
printf("o-----------------------o\n");
printf("| BPF ENQUEUES |\n");
printf("|-----------------------|\n");
printf("| kern: %10llu |\n", nr_kernel_enqueues);
printf("| user: %10llu |\n", nr_user_enqueues);
printf("| failed: %10llu |\n", nr_failed_enqueues);
printf("| -------------------- |\n");
printf("| total: %10llu |\n", total);
printf("| |\n");
printf("|-----------------------|\n");
printf("| VRUNTIME / USER |\n");
printf("|-----------------------|\n");
printf("| enq: %10llu |\n", __atomic_load_n(&nr_vruntime_enqueues, __ATOMIC_RELAXED));
printf("| disp: %10llu |\n", __atomic_load_n(&nr_vruntime_dispatches, __ATOMIC_RELAXED));
printf("| failed: %10llu |\n", __atomic_load_n(&nr_vruntime_failed, __ATOMIC_RELAXED));
printf("o-----------------------o\n");
printf("\n\n");
fflush(stdout);
sleep(1);
}
return NULL;
}
/*
 * Start the statistics thread (run_stats_printer), storing its handle in
 * stats_printer. Returns pthread_create()'s result: 0 or an error number.
 */
static int spawn_stats_thread(void)
{
	int rc = pthread_create(&stats_printer, NULL, run_stats_printer, NULL);

	return rc;
}
static void pre_bootstrap(int argc, char **argv)
{
int err;
__u32 opt;
struct sched_param sched_param = {
.sched_priority = sched_get_priority_max(SCHED_EXT),
};
err = init_tasks();
if (err)
exit(err);
libbpf_set_print(libbpf_print_fn);
signal(SIGINT, sigint_handler);
signal(SIGTERM, sigint_handler);
err = syscall(__NR_sched_setscheduler, getpid(), SCHED_EXT, &sched_param);
SCX_BUG_ON(err, "Failed to set scheduler to SCHED_EXT");
while ((opt = getopt(argc, argv, "b:vh")) != -1) {
switch (opt) {
case 'b':
batch_size = strtoul(optarg, NULL, 0);
break;
case 'v':
verbose = true;
break;
default:
fprintf(stderr, help_fmt, basename(argv[0]));
exit(opt != 'h');
}
}
err = mlockall(MCL_CURRENT | MCL_FUTURE);
SCX_BUG_ON(err, "Failed to prefault and lock address space");
}
/*
 * Start one scheduler instance. Called for the first run and after every
 * restart requested by the BPF side.
 *
 * Resets all per-run user-space state, then opens the BPF skeleton, fills in
 * its read-only configuration and loads it. It looks up the map fds, starts
 * the stats thread, and finally attaches the struct_ops scheduler. Fatal
 * errors terminate the process via SCX_BUG_ON().
 */
static void bootstrap(char *comm)
{
	int nr_cpus;

	(void)comm;	/* currently unused */

	/* Per-run state: a restart must not inherit anything from the last run. */
	exit_req = 0;
	min_vruntime = 0.0;
	__atomic_store_n(&nr_vruntime_enqueues, 0, __ATOMIC_RELAXED);
	__atomic_store_n(&nr_vruntime_dispatches, 0, __ATOMIC_RELAXED);
	__atomic_store_n(&nr_vruntime_failed, 0, __ATOMIC_RELAXED);
	__atomic_store_n(&nr_curr_enqueued, 0, __ATOMIC_RELAXED);
	memset(tasks, 0, pid_max * sizeof(*tasks));
	LIST_INIT(&vruntime_head);

	skel = SCX_OPS_OPEN(userland_ops, scx_userland);

	/*
	 * Validate before storing. A negative error code must not be hidden by
	 * the rodata field's type (possibly unsigned). This check must also stay
	 * active under NDEBUG, unlike assert().
	 */
	nr_cpus = libbpf_num_possible_cpus();
	SCX_BUG_ON(nr_cpus <= 0, "Failed to determine the number of possible CPUs");
	skel->rodata->num_possible_cpus = nr_cpus;
	skel->rodata->usersched_pid = getpid();
	assert(skel->rodata->usersched_pid > 0);

	SCX_OPS_LOAD(skel, userland_ops, scx_userland, uei);

	/* bpf_map__fd() reports failure with a negative value; 0 is a valid fd. */
	enqueued_fd = bpf_map__fd(skel->maps.enqueued);
	dispatched_fd = bpf_map__fd(skel->maps.dispatched);
	SCX_BUG_ON(enqueued_fd < 0 || dispatched_fd < 0, "Failed to get BPF map fds");

	SCX_BUG_ON(spawn_stats_thread(), "Failed to spawn stats thread");
	ops_link = SCX_OPS_ATTACH(skel, userland_ops, scx_userland);
}
/*
 * Scheduler main loop: until exit_req is set, pull newly enqueued tasks from
 * BPF, dispatch a batch back, then yield the CPU.
 */
static void sched_main_loop(void)
{
	for (; !exit_req; sched_yield()) {
		drain_enqueued_map();
		dispatch_batch();
	}
}
/*
 * Entry point. After one-time setup, runs scheduler instances back to back.
 * Each instance is torn down (link destroyed, stats thread joined, exit info
 * reported, skeleton freed). Another one starts only when the reported exit
 * code asks for a restart.
 */
int main(int argc, char **argv)
{
	__u64 exit_code;

	pre_bootstrap(argc, argv);

	do {
		bootstrap(argv[0]);
		sched_main_loop();

		/* Make sure the stats thread stops too before joining it. */
		exit_req = 1;
		bpf_link__destroy(ops_link);
		pthread_join(stats_printer, NULL);
		exit_code = UEI_REPORT(skel, uei);
		scx_userland__destroy(skel);
	} while (UEI_ECODE_RESTART(exit_code));

	return 0;
}