use std::sync::Arc; use chrono::DateTime; use chrono::SecondsFormat; use chrono::Utc; use codex_protocol::ThreadId; use codex_protocol::models::SandboxPermissions; use codex_utils_absolute_path::AbsolutePathBuf; use futures::future::BoxFuture; use serde::Serialize; use serde::Serializer; pub type HookFn = Arc Fn(&'a HookPayload) -> BoxFuture<'a, HookResult> + Send + Sync>; #[derive(Debug)] pub enum HookResult { /// Success: hook completed successfully. Success, /// FailedContinue: hook failed, but other subsequent hooks should still execute and the /// operation should continue. FailedContinue(Box), /// FailedAbort: hook failed, other subsequent hooks should not execute, and the operation /// should be aborted. FailedAbort(Box), } impl HookResult { pub fn should_abort_operation(&self) -> bool { matches!(self, Self::FailedAbort(_)) } } #[derive(Debug)] pub struct HookResponse { pub hook_name: String, pub result: HookResult, } #[derive(Clone)] pub struct Hook { pub name: String, pub func: HookFn, } impl Default for Hook { fn default() -> Self { Self { name: "default".to_string(), func: Arc::new(|_| Box::pin(async { HookResult::Success })), } } } impl Hook { pub async fn execute(&self, payload: &HookPayload) -> HookResponse { HookResponse { hook_name: self.name.clone(), result: (self.func)(payload).await, } } } #[derive(Debug, Serialize, Clone)] #[serde(rename_all = "snake_case")] pub struct HookPayload { pub session_id: ThreadId, pub cwd: AbsolutePathBuf, #[serde(skip_serializing_if = "Option::is_none")] pub client: Option, #[serde(serialize_with = "serialize_triggered_at")] pub triggered_at: DateTime, pub hook_event: HookEvent, } #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "snake_case")] pub struct HookEventAfterAgent { pub thread_id: ThreadId, pub turn_id: String, pub input_messages: Vec, pub last_assistant_message: Option, } #[derive(Debug, Clone, Serialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum HookToolKind { Function, Custom, LocalShell, Mcp, } #[derive(Debug, Clone, Serialize, PartialEq)] #[serde(rename_all = "snake_case")] pub struct HookToolInputLocalShell { pub command: Vec, pub workdir: Option, pub timeout_ms: Option, pub sandbox_permissions: Option, pub prefix_rule: Option>, pub justification: Option, } #[derive(Debug, Clone, Serialize, PartialEq)] #[serde(tag = "input_type", rename_all = "snake_case")] pub enum HookToolInput { Function { arguments: String, }, Custom { input: String, }, LocalShell { params: HookToolInputLocalShell, }, Mcp { server: String, tool: String, arguments: String, }, } #[derive(Debug, Clone, Serialize, PartialEq)] #[serde(rename_all = "snake_case")] pub struct HookEventAfterToolUse { pub turn_id: String, pub call_id: String, pub tool_name: String, pub tool_kind: HookToolKind, pub tool_input: HookToolInput, pub executed: bool, pub success: bool, pub duration_ms: u64, pub mutating: bool, pub sandbox: String, pub sandbox_policy: String, pub output_preview: String, } fn serialize_triggered_at(value: &DateTime, serializer: S) -> Result where S: Serializer, { serializer.serialize_str(&value.to_rfc3339_opts(SecondsFormat::Secs, true)) } #[derive(Debug, Clone, Serialize)] #[serde(tag = "event_type", rename_all = "snake_case")] pub enum HookEvent { AfterAgent { #[serde(flatten)] event: HookEventAfterAgent, }, AfterToolUse { #[serde(flatten)] event: HookEventAfterToolUse, }, } #[cfg(test)] mod tests { use chrono::TimeZone; use chrono::Utc; use codex_protocol::ThreadId; use codex_protocol::models::SandboxPermissions; use codex_utils_absolute_path::test_support::PathBufExt; use codex_utils_absolute_path::test_support::test_path_buf; use pretty_assertions::assert_eq; use serde_json::json; use super::HookEvent; use super::HookEventAfterAgent; use super::HookEventAfterToolUse; use super::HookPayload; use super::HookToolInput; use super::HookToolInputLocalShell; use super::HookToolKind; #[test] fn hook_payload_serializes_stable_wire_shape() { let session_id = ThreadId::new(); let thread_id = ThreadId::new(); let cwd = test_path_buf("/tmp").abs(); let payload = HookPayload { session_id, cwd: cwd.clone(), client: None, triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() .expect("valid timestamp"), hook_event: HookEvent::AfterAgent { event: HookEventAfterAgent { thread_id, turn_id: "turn-1".to_string(), input_messages: vec!["hello".to_string()], last_assistant_message: Some("hi".to_string()), }, }, }; let actual = serde_json::to_value(payload).expect("serialize hook payload"); let expected = json!({ "session_id": session_id.to_string(), "cwd": cwd.display().to_string(), "triggered_at": "2025-01-01T00:00:00Z", "hook_event": { "event_type": "after_agent", "thread_id": thread_id.to_string(), "turn_id": "turn-1", "input_messages": ["hello"], "last_assistant_message": "hi", }, }); assert_eq!(actual, expected); } #[test] fn after_tool_use_payload_serializes_stable_wire_shape() { let session_id = ThreadId::new(); let cwd = test_path_buf("/tmp").abs(); let payload = HookPayload { session_id, cwd: cwd.clone(), client: None, triggered_at: Utc .with_ymd_and_hms(2025, 1, 1, 0, 0, 0) .single() .expect("valid timestamp"), hook_event: HookEvent::AfterToolUse { event: HookEventAfterToolUse { turn_id: "turn-2".to_string(), call_id: "call-1".to_string(), tool_name: "local_shell".to_string(), tool_kind: HookToolKind::LocalShell, tool_input: HookToolInput::LocalShell { params: HookToolInputLocalShell { command: vec!["cargo".to_string(), "fmt".to_string()], workdir: Some("codex-rs".to_string()), timeout_ms: Some(60_000), sandbox_permissions: Some(SandboxPermissions::UseDefault), justification: None, prefix_rule: None, }, }, executed: true, success: true, duration_ms: 42, mutating: true, sandbox: "none".to_string(), sandbox_policy: "danger-full-access".to_string(), output_preview: "ok".to_string(), }, }, }; let actual = serde_json::to_value(payload).expect("serialize hook payload"); let expected = json!({ "session_id": session_id.to_string(), "cwd": cwd.display().to_string(), "triggered_at": "2025-01-01T00:00:00Z", "hook_event": { "event_type": "after_tool_use", "turn_id": "turn-2", "call_id": "call-1", "tool_name": "local_shell", "tool_kind": "local_shell", "tool_input": { "input_type": "local_shell", "params": { "command": ["cargo", "fmt"], "workdir": "codex-rs", "timeout_ms": 60000, "sandbox_permissions": "use_default", "justification": null, "prefix_rule": null, }, }, "executed": true, "success": true, "duration_ms": 42, "mutating": true, "sandbox": "none", "sandbox_policy": "danger-full-access", "output_preview": "ok", }, }); assert_eq!(actual, expected); } }