diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index da8813fd2789..86c30dcaaae2 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2452,6 +2452,8 @@ extern "C" { uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) bool strict_cpu; // strict cpu placement bool paused; // start in paused state + void (*thread_create_callback)(void); // callback invoked when thread is created + void (*thread_destroy_callback)(void); // callback invoked when thread is destroyed }; struct ggml_threadpool; // forward declaration, see ggml.c diff --git a/ggml/src/ggml-cpu/ggml-cpu-c.c b/ggml/src/ggml-cpu/ggml-cpu-c.c index f6bea3df34a0..fbefc0f45ec1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-c.c +++ b/ggml/src/ggml-cpu/ggml-cpu-c.c @@ -463,6 +463,9 @@ struct ggml_threadpool { int32_t prio; // Scheduling priority uint32_t poll; // Polling level (0 - no polling) + void (*thread_create_callback)(void); + void (*thread_destroy_callback)(void); + enum ggml_status ec; }; @@ -2959,6 +2962,10 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; struct ggml_threadpool * threadpool = state->threadpool; + if (threadpool->thread_create_callback) { + threadpool->thread_create_callback(); + } + ggml_thread_apply_priority(threadpool->prio); if (ggml_thread_cpumask_is_valid(state->cpumask)) { ggml_thread_apply_affinity(state->cpumask); @@ -2990,6 +2997,10 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { } } + if (threadpool->thread_destroy_callback) { + threadpool->thread_destroy_callback(); + } + return (thread_ret_t) 0; } @@ -3049,6 +3060,8 @@ static struct ggml_threadpool * ggml_threadpool_new_impl( threadpool->n_threads_cur = tpp->n_threads; threadpool->poll = tpp->poll; threadpool->prio = tpp->prio; + threadpool->thread_create_callback = tpp->thread_create_callback; + threadpool->thread_destroy_callback = tpp->thread_destroy_callback; threadpool->ec = GGML_STATUS_SUCCESS; }