import asyncio
import base64
import itertools
import json
import logging
import time
from collections import deque
from dataclasses import dataclass
from typing import Any, AsyncIterator

from inference.core.types import (
    AudioChunk,
    ImageFrame,
    PluginConfig,
    ToolCall,
    ToolResult,
    VoiceLLMInputEvent,
    VoiceLLMOutputEvent,
    VoiceLLMSessionConfig,
)
from inference.plugins.qwen_endpoint import dashscope_realtime_ws_url
from inference.plugins.voice_llm.base import VoiceCheckError, VoiceLLMPlugin

logger = logging.getLogger(__name__)

# Largest image payload (in bytes) that `_valid_image` accepts for forwarding.
_MAX_IMAGE_BYTES = 500 * 1024

# Monotonic sequence appended to every event id. A millisecond timestamp alone
# collides when two events are built back-to-back (e.g. the two interrupt events).
_EVENT_SEQ = itertools.count(1)


class QwenOmniRealtimePlugin(VoiceLLMPlugin):
    """DashScope Qwen Omni realtime omni model plugin.

    Streams user audio/images/text to the DashScope realtime websocket and
    converts server events into ``VoiceLLMOutputEvent`` objects (audio chunks,
    transcripts, tool calls, barge-in notifications).
    """

    name = "omni.qwen_omni"

    def __init__(self) -> None:
        self.api_key = ""
        self.model = "qwen3.5-omni-flash-realtime"
        self.ws_url = ""
        self.voice = "Tina"
        self.system_prompt = ""
        self.input_sample_rate = 16000
        self.output_sample_rate = 24000
        self.vad_type = "semantic_vad"
        self.vad_threshold = 0.5
        self.vad_silence_duration_ms = 800
        # Optional sampling / search settings; None means "not sent to the server".
        self.enable_search: bool | None = None
        self.search_options: dict[str, Any] | None = None
        self.temperature: float | None = None
        self.top_p: float | None = None
        self.top_k: int | None = None
        self.max_tokens: int | None = None
        # Websocket of the currently running `converse_stream`, used by `interrupt`.
        self._active_ws: Any | None = None

    async def initialize(self, config: PluginConfig) -> None:
        """Load plugin settings from ``config.params``.

        Missing keys keep their current values; optional numeric/boolean
        settings that cannot be parsed become ``None`` (i.e. not sent).
        """
        params = config.params
        self.api_key = params.get("api_key", self.api_key)
        self.model = params.get("model", self.model)
        self.ws_url = dashscope_realtime_ws_url(self.model, "DASHSCOPE_OMNI_WS_URL")
        self.voice = params.get("voice", self.voice)
        self.system_prompt = params.get("system_prompt", self.system_prompt)
        self.input_sample_rate = int(
            params.get("input_sample_rate", self.input_sample_rate)
        )
        self.output_sample_rate = int(
            params.get("output_sample_rate", self.output_sample_rate)
        )
        self.vad_type = params.get("vad_type", self.vad_type)
        self.vad_threshold = float(params.get("vad_threshold", self.vad_threshold))
        self.vad_silence_duration_ms = int(
            params.get("vad_silence_duration_ms", self.vad_silence_duration_ms)
        )
        self.enable_search = self._optional_bool(params.get("enable_search"))
        search_options = params.get("search_options")
        if isinstance(search_options, dict):
            self.search_options = search_options
        self.temperature = self._optional_float(params.get("temperature"))
        self.top_p = self._optional_float(params.get("top_p"))
        self.top_k = self._optional_int(params.get("top_k"))
        self.max_tokens = self._optional_int(params.get("max_tokens"))

    async def check_voice(
        self,
        session_config: VoiceLLMSessionConfig | None = None,
    ) -> None:
        """Open a connection, apply the session config, and close again.

        Raises:
            VoiceCheckError: if the server answers the session update with an
                ``error`` event.
        """
        import websockets

        ws = await self._connect(websockets)
        try:
            await self._configure_session(ws, session_config or VoiceLLMSessionConfig())
        except RuntimeError as exc:
            raise VoiceCheckError(str(exc)) from exc
        finally:
            await ws.close()

    async def converse_stream(
        self,
        input_stream: AsyncIterator[VoiceLLMInputEvent],
        session_config: VoiceLLMSessionConfig | None = None,
    ) -> AsyncIterator[VoiceLLMOutputEvent]:
        """Run one realtime conversation and yield model output events.

        A sender task forwards ``input_stream`` to the websocket while a
        receiver task translates server events. Both push into a shared queue:
        ``None`` ends the stream and an ``Exception`` is re-raised here.
        """
        import websockets

        config = session_config or VoiceLLMSessionConfig()
        ws = await self._connect(websockets)
        self._active_ws = ws
        response_coordinator = _QwenResponseCoordinator()
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None] = (
            asyncio.Queue()
        )
        sender_task: asyncio.Task | None = None
        receiver_task: asyncio.Task | None = None
        try:
            await self._configure_session(ws, config)
            sender_task = asyncio.create_task(
                self._send_inputs(
                    ws,
                    input_stream,
                    config.session_id,
                    output_queue,
                    response_coordinator,
                    drain_responses_on_close=True,
                )
            )
            receiver_task = asyncio.create_task(
                self._receive_events(
                    ws,
                    config.session_id,
                    output_queue,
                    response_coordinator,
                    defer_response=config.defer_response,
                )
            )
            while True:
                item = await output_queue.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            for task in (sender_task, receiver_task):
                if task and not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            if self._active_ws is ws:
                self._active_ws = None
            await ws.close()

    async def interrupt(self) -> None:
        """Send cancel/clear events on the active conversation, best-effort.

        Send failures are logged at debug level and otherwise ignored.
        """
        ws = self._active_ws
        if ws is None:
            return
        for event_type in ("response.cancel", "input_audio_buffer.clear"):
            try:
                await self._send_json(
                    ws,
                    {
                        "type": event_type,
                        "event_id": self._event_id("qwen_omni", "interrupt"),
                    },
                )
            except Exception:
                logger.debug("Failed to send Qwen Omni interrupt event", exc_info=True)

    async def _connect(self, websockets: Any):
        """Open the DashScope websocket with a bearer-token header.

        Tries the ``additional_headers`` keyword first and falls back to
        ``extra_headers`` when ``connect`` raises ``TypeError`` (presumably an
        older websockets release — the keyword differs between versions).
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            return await websockets.connect(
                self.ws_url,
                additional_headers=headers,
            )
        except TypeError:
            return await websockets.connect(
                self.ws_url,
                extra_headers=headers,
            )

    async def _configure_session(
        self,
        ws: Any,
        session_config: VoiceLLMSessionConfig,
    ) -> None:
        """Send ``session.update`` and wait for the server to acknowledge it.

        Returns on ``session.created``/``session.updated``; other events are
        skipped. Raises ``RuntimeError`` on an ``error`` event. There is no
        timeout here — the wait lasts as long as ``ws.recv()`` does.
        """
        await self._send_json(
            ws,
            {
                "type": "session.update",
                "event_id": self._event_id(session_config.session_id, "session"),
                "session": self._session_payload(session_config),
            },
        )
        while True:
            event = self._decode_message(await ws.recv())
            event_type = event.get("type", "")
            if event_type in {"session.created", "session.updated"}:
                return
            if event_type == "error":
                raise RuntimeError(self._error_message(event))

    async def _send_inputs(
        self,
        ws: Any,
        input_stream: AsyncIterator[VoiceLLMInputEvent],
        session_id: str,
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None],
        response_coordinator: "_QwenResponseCoordinator | None" = None,
        drain_responses_on_close: bool = False,
    ) -> None:
        """Forward every input event to the websocket.

        Tool results, response instructions and text are queued on the
        response coordinator (which serialises ``response.create`` requests);
        audio is appended directly. Images are held back and only sent after
        an audio append. When ``drain_responses_on_close`` is set, waits up to
        60 s for queued responses before closing. Errors are pushed onto
        ``output_queue`` instead of being raised; the websocket is always
        closed on exit.
        """
        response_coordinator = response_coordinator or _QwenResponseCoordinator()
        response_sender_task = asyncio.create_task(
            self._send_deferred_responses(ws, response_coordinator, output_queue)
        )
        try:
            pending_image: ImageFrame | None = None
            has_sent_audio = False
            expects_deferred_response = False
            async for event in input_stream:
                if event.tool_result:
                    await self._send_tool_result(
                        ws,
                        session_id,
                        event.tool_result,
                        response_coordinator,
                    )
                    if not event.tool_result.suppress_response:
                        expects_deferred_response = True
                    continue
                if event.response_instructions is not None:
                    expects_deferred_response = True
                    await self._send_response_instructions(
                        ws,
                        session_id,
                        event.response_instructions,
                        response_coordinator,
                    )
                    continue
                if event.text:
                    expects_deferred_response = True
                    await self._send_text(ws, session_id, event.text, response_coordinator)
                    continue
                if event.audio:
                    has_sent_audio = True
                    await self._send_json(
                        ws,
                        {
                            "type": "input_audio_buffer.append",
                            "event_id": self._event_id(session_id, "audio"),
                            "audio": base64.b64encode(event.audio).decode("ascii"),
                        },
                    )
                    # Keep image strictly after an audio append. This avoids provider-side
                    # ordering violations when a new turn starts and an image arrives first.
                    if pending_image is not None:
                        await self._send_image(ws, session_id, pending_image)
                        pending_image = None
                if event.image is not None:
                    if not self._valid_image(event.image):
                        continue
                    # Always buffer the latest valid frame and flush it only after
                    # the next audio chunk is appended.
                    pending_image = event.image
            if pending_image is not None and has_sent_audio:
                # If stream ends after image input, flush once to avoid dropping
                # the latest frame while still guaranteeing audio-first ordering.
                await self._send_image(ws, session_id, pending_image)
            if drain_responses_on_close and expects_deferred_response:
                await response_coordinator.wait_all_responses_done(timeout=60.0)
        except Exception as exc:
            await output_queue.put(exc)
        finally:
            await response_coordinator.close()
            if not response_sender_task.done():
                try:
                    await response_sender_task
                except asyncio.CancelledError:
                    pass
            try:
                await ws.close()
            except Exception:
                pass

    async def _send_image(self, ws: Any, session_id: str, image: ImageFrame) -> None:
        """Append one base64-encoded image frame to the input image buffer."""
        await self._send_json(
            ws,
            {
                "type": "input_image_buffer.append",
                "event_id": self._event_id(session_id, "image"),
                "image": base64.b64encode(image.data).decode("ascii"),
            },
        )

    async def _send_text(
        self,
        ws: Any,
        session_id: str,
        text: str,
        response_coordinator: "_QwenResponseCoordinator",
    ) -> None:
        """Queue a user text message plus a text+audio ``response.create``."""
        await response_coordinator.enqueue(
            _QwenDeferredResponse(
                item_payload={
                    "type": "conversation.item.create",
                    "event_id": self._event_id(session_id, "text"),
                    "item": {
                        "type": "message",
                        "role": "user",
                        "content": [
                            {
                                "type": "input_text",
                                "text": text,
                            }
                        ],
                    },
                },
                response_payload={
                    "type": "response.create",
                    "event_id": self._event_id(session_id, "text_response"),
                    "response": {"modalities": ["text", "audio"]},
                },
            )
        )

    async def _send_response_instructions(
        self,
        ws: Any,
        session_id: str,
        instructions: str,
        response_coordinator: "_QwenResponseCoordinator",
    ) -> None:
        """Queue a ``response.create``, adding per-response instructions if non-blank."""
        response: dict[str, Any] = {"modalities": ["text", "audio"]}
        instructions = str(instructions or "").strip()
        if instructions:
            response["instructions"] = instructions
        await response_coordinator.enqueue(
            _QwenDeferredResponse(
                response_payload={
                    "type": "response.create",
                    "event_id": self._event_id(session_id, "response"),
                    "response": response,
                }
            )
        )

    async def _send_tool_result(
        self,
        ws: Any,
        session_id: str,
        result: ToolResult,
        response_coordinator: "_QwenResponseCoordinator",
    ) -> None:
        """Send a ``function_call_output`` item immediately.

        Unless ``result.suppress_response`` is set, also queues a follow-up
        ``response.create`` through the coordinator.
        """
        await self._send_json(
            ws,
            {
                "type": "conversation.item.create",
                "event_id": self._event_id(session_id, "tool_result"),
                "item": {
                    "type": "function_call_output",
                    "call_id": result.id,
                    "output": json.dumps(result.result, ensure_ascii=False),
                },
            },
        )
        if result.suppress_response:
            return
        await response_coordinator.enqueue(
            _QwenDeferredResponse(
                response_payload={
                    "type": "response.create",
                    "event_id": self._event_id(session_id, "tool_response"),
                }
            )
        )

    async def _send_deferred_responses(
        self,
        ws: Any,
        response_coordinator: "_QwenResponseCoordinator",
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None],
    ) -> None:
        """Background loop: send queued requests one by one until the coordinator closes.

        Send errors are pushed to ``output_queue``; cancellation propagates.
        """
        try:
            while True:
                request = await response_coordinator.next_request()
                if request is None:
                    return
                await self._send_deferred_response(ws, response_coordinator, request)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            await output_queue.put(exc)

    async def _send_deferred_response(
        self,
        ws: Any,
        response_coordinator: "_QwenResponseCoordinator",
        request: "_QwenDeferredResponse",
    ) -> None:
        """Wait until no response is in flight, then send one queued request.

        The conversation item (if any) is sent only once, even when the
        request is retried after an "active response" error. The request is
        dropped if the coordinator closes while waiting.
        """
        if not await response_coordinator.wait_idle():
            return
        if request.item_payload is not None and not request.item_sent:
            await self._send_json(ws, request.item_payload)
            request.item_sent = True
        await response_coordinator.begin_client_response(request)
        try:
            await self._send_json(ws, request.response_payload)
        except Exception:
            # Un-mark the in-flight response so the coordinator does not stay busy.
            await response_coordinator.release_client_response(request)
            raise

    @staticmethod
    def _valid_image(image: ImageFrame) -> bool:
        """Return True for a non-empty JPEG of at most ``_MAX_IMAGE_BYTES`` bytes.

        The MIME type (if given) must be ``image/jpeg``/``image/jpg`` and the
        data must start with the JPEG SOI marker bytes ``FF D8 FF``.
        """
        mime_type = (image.mime_type or "").lower()
        if mime_type and mime_type not in {"image/jpeg", "image/jpg"}:
            return False
        data = image.data or b""
        if len(data) == 0 or len(data) > _MAX_IMAGE_BYTES:
            return False
        return len(data) >= 3 and data[0] == 0xFF and data[1] == 0xD8 and data[2] == 0xFF

    async def _receive_events(
        self,
        ws: Any,
        session_id: str,
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None],
        response_coordinator: "_QwenResponseCoordinator | None" = None,
        defer_response: bool = False,
    ) -> None:
        """Translate server events into ``VoiceLLMOutputEvent``s on ``output_queue``.

        Also keeps the response coordinator's idle/busy state in sync.
        Always finishes by putting ``None`` on the queue. An exception is
        forwarded only while the websocket does not report itself as closed.
        """
        response_coordinator = response_coordinator or _QwenResponseCoordinator()
        turn_state = _QwenTurnState(session_id=session_id or "qwen_omni")
        # Streaming function-call argument fragments, keyed by call id.
        tool_arg_parts: dict[str, str] = {}
        # Call ids already emitted; the same call may be reported by both the
        # arguments.done and the output_item.done events.
        emitted_tool_calls: set[str] = set()
        try:
            async for message in ws:
                event = self._decode_message(message)
                self._log_server_event(session_id, event)
                event_type = event.get("type", "")

                if event_type == "error":
                    error_message = self._error_message(event)
                    if self._is_active_response_error(error_message):
                        logger.info(
                            "qwen_omni deferred response delayed by active response session=%s",
                            session_id or "qwen_omni",
                        )
                        # The coordinator re-queues the rejected request.
                        await response_coordinator.mark_active_response_error()
                        continue
                    raise RuntimeError(error_message)

                if event_type in {"session.created", "session.updated"}:
                    continue

                if event_type == "response.function_call_arguments.delta":
                    call_id = str(event.get("call_id") or event.get("item_id") or "")
                    if call_id:
                        tool_arg_parts[call_id] = tool_arg_parts.get(call_id, "") + str(
                            event.get("delta", "") or ""
                        )
                    continue

                if event_type == "response.function_call_arguments.done":
                    call = self._tool_call_from_event(event, tool_arg_parts)
                    if call and call.id not in emitted_tool_calls:
                        emitted_tool_calls.add(call.id)
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                tool_calls=[call],
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.output_item.done":
                    item = event.get("item")
                    if isinstance(item, dict) and item.get("type") == "function_call":
                        call = self._tool_call_from_event(item, tool_arg_parts)
                        if call and call.id not in emitted_tool_calls:
                            emitted_tool_calls.add(call.id)
                            await output_queue.put(
                                VoiceLLMOutputEvent(
                                    tool_calls=[call],
                                    question_id=turn_state.question_id,
                                    reply_id=turn_state.reply_id,
                                )
                            )
                    continue

                if event_type == "input_audio_buffer.speech_started":
                    # Without deferral the server is expected to create a response
                    # on its own, so treat the coordinator as busy from now on.
                    if not defer_response:
                        await response_coordinator.mark_response_started()
                    turn_state.start_next_turn()
                    await output_queue.put(
                        VoiceLLMOutputEvent(
                            barge_in=True,
                            question_id=turn_state.question_id,
                            reply_id=turn_state.reply_id,
                        )
                    )
                    continue

                if event_type == "response.created":
                    await response_coordinator.mark_response_started()
                    response = event.get("response")
                    if isinstance(response, dict):
                        response_id = str(response.get("id", "") or "")
                        # Never blank an existing reply id when the server omits it.
                        if response_id:
                            if not turn_state.question_id:
                                turn_state.start_next_turn()
                            turn_state.reply_id = response_id
                    continue

                if event_type == "conversation.item.input_audio_transcription.completed":
                    turn_state.ensure_turn()
                    transcript = str(event.get("transcript", "") or "").strip()
                    if transcript:
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                user_transcript=transcript,
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.audio_transcript.delta":
                    turn_state.ensure_turn()
                    delta = str(event.get("delta", "") or "")
                    if delta:
                        turn_state.assistant_text += delta
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                transcript=delta,
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.audio_transcript.done":
                    # Prefer the server's complete transcript over the accumulated deltas.
                    transcript = str(event.get("transcript", "") or "")
                    if transcript:
                        turn_state.assistant_text = transcript
                    continue

                if event_type == "response.audio.delta":
                    turn_state.ensure_turn()
                    delta = str(event.get("delta", "") or "")
                    if not delta:
                        continue
                    audio_payload = base64.b64decode(delta)
                    if audio_payload:
                        turn_state.has_audio = True
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                audio=AudioChunk(
                                    data=audio_payload,
                                    sample_rate=self.output_sample_rate,
                                    channels=1,
                                    format="pcm_s16le",
                                ),
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.done":
                    if turn_state.has_content:
                        # Final marker event: empty final audio chunk (only if audio
                        # was produced) plus the full assistant transcript.
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                audio=AudioChunk(
                                    data=b"",
                                    sample_rate=self.output_sample_rate,
                                    channels=1,
                                    format="pcm_s16le",
                                    is_final=True,
                                )
                                if turn_state.has_audio
                                else None,
                                transcript=turn_state.assistant_text,
                                is_final=True,
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    turn_state.reset()
                    # Argument fragments belong to the finished response; drop them
                    # so the buffer does not grow for the lifetime of the session.
                    tool_arg_parts.clear()
                    await response_coordinator.mark_response_done()
                    continue
        except Exception as exc:
            if not getattr(ws, "closed", False):
                await output_queue.put(exc)
        finally:
            await output_queue.put(None)

    def _session_payload(self, session_config: VoiceLLMSessionConfig) -> dict[str, Any]:
        """Build the ``session`` object for ``session.update``.

        Optional sampling settings are included only when not None. Search
        settings are omitted whenever tools are configured.
        """
        payload: dict[str, Any] = {
            "modalities": ["text", "audio"],
            "voice": session_config.voice or self.voice,
            "input_audio_format": "pcm",
            "output_audio_format": "pcm",
            "instructions": self._instructions(session_config),
            "turn_detection": {
                "type": self.vad_type,
                "threshold": self.vad_threshold,
                "silence_duration_ms": self.vad_silence_duration_ms,
            },
        }
        if session_config.defer_response:
            payload["turn_detection"]["create_response"] = False
        has_tools = bool(session_config.tools)
        optional_values: dict[str, Any] = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
        }
        if not has_tools:
            optional_values["enable_search"] = self.enable_search
            optional_values["search_options"] = self.search_options
        for key, value in optional_values.items():
            if value is not None:
                payload[key] = value
        if has_tools:
            payload["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters or {"type": "object", "properties": {}},
                    },
                }
                for tool in session_config.tools
            ]
        return payload

    def _instructions(self, session_config: VoiceLLMSessionConfig) -> str:
        """Compose the system instructions from name, prompt, style and dialog context.

        Labels are written in Chinese (name, speaking style, a header asking
        the model to stay consistent with recent context, user/assistant
        roles). Blank parts are dropped; the rest are newline-joined.
        """
        parts: list[str] = []
        if session_config.bot_name:
            parts.append(f"名字:{session_config.bot_name}")
        parts.append(session_config.system_prompt or self.system_prompt)
        if session_config.speaking_style:
            parts.append(f"说话风格:{session_config.speaking_style}")
        if session_config.dialog_context:
            parts.append("以下是最近的对话上下文,请在回答时保持连续性:")
            for item in session_config.dialog_context:
                role = "用户" if item.role == "user" else "助手"
                parts.append(f"{role}:{item.text}")
        return "\n".join(part for part in parts if part.strip())

    @staticmethod
    def _tool_call_from_event(event: dict[str, Any], arg_parts: dict[str, str]) -> ToolCall | None:
        """Build a ``ToolCall`` from a function-call event or item.

        Falls back to the accumulated streamed arguments when the event has no
        ``arguments`` field. Returns None if the call id or name is missing.
        """
        call_id = str(event.get("call_id") or event.get("id") or event.get("item_id") or "")
        name = str(event.get("name") or "")
        raw_args = event.get("arguments")
        if raw_args is None and call_id:
            raw_args = arg_parts.get(call_id, "")
        if not call_id or not name:
            return None
        return ToolCall(
            id=call_id,
            name=name,
            arguments=QwenOmniRealtimePlugin._parse_tool_arguments(raw_args),
        )

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode tool arguments into a dict; anything unparsable or non-object becomes ``{}``."""
        if isinstance(raw, dict):
            return raw
        if raw is None:
            return {}
        try:
            parsed = json.loads(str(raw))
        except json.JSONDecodeError:
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _clip_text(value: Any, limit: int = 180) -> str:
        """Stringify ``value`` and truncate it to ``limit`` chars plus ``...``."""
        text = str(value or "")
        if len(text) <= limit:
            return text
        return text[:limit] + "..."

    @classmethod
    def _server_event_log_fields(cls, event: dict[str, Any]) -> dict[str, Any]:
        """Extract a compact, clipped subset of a server event for logging.

        Audio deltas are logged as their base64 length only.
        """
        event_type = str(event.get("type") or "")
        fields: dict[str, Any] = {}
        for key in ("response_id", "item_id", "call_id", "name", "output_index"):
            if key in event and event.get(key) not in (None, ""):
                fields[key] = event.get(key)
        if event_type == "response.audio.delta":
            fields["audio_delta_b64_len"] = len(str(event.get("delta") or ""))
        elif event_type in {"response.audio_transcript.delta", "response.function_call_arguments.delta"}:
            fields["delta"] = cls._clip_text(event.get("delta"))
        if "transcript" in event:
            fields["transcript"] = cls._clip_text(event.get("transcript"))
        if "arguments" in event:
            fields["arguments"] = cls._clip_text(event.get("arguments"))
        item = event.get("item")
        if isinstance(item, dict):
            item_fields = {
                key: item.get(key)
                for key in ("type", "id", "call_id", "name")
                if item.get(key) not in (None, "")
            }
            if "arguments" in item:
                item_fields["arguments"] = cls._clip_text(item.get("arguments"))
            fields["item"] = item_fields
        response = event.get("response")
        if isinstance(response, dict):
            fields["response"] = {
                key: response.get(key)
                for key in ("id", "status")
                if response.get(key) not in (None, "")
            }
        error = event.get("error")
        if error:
            fields["error"] = cls._clip_text(error)
        return fields

    @classmethod
    def _server_event_log_level(cls, event: dict[str, Any]) -> int:
        """ERROR for error events, INFO for lifecycle events, DEBUG for everything else."""
        event_type = str(event.get("type") or "")
        if event_type == "error":
            return logging.ERROR
        if event_type in {
            "session.created",
            "session.updated",
            "input_audio_buffer.speech_started",
            "conversation.item.input_audio_transcription.completed",
            "response.created",
            "response.audio_transcript.done",
            "response.function_call_arguments.done",
            "response.done",
        }:
            return logging.INFO
        return logging.DEBUG

    @classmethod
    def _log_server_event(cls, session_id: str, event: dict[str, Any]) -> None:
        """Log a server event; field extraction is skipped when the level is disabled."""
        event_type = str(event.get("type") or "unknown")
        level = cls._server_event_log_level(event)
        if not logger.isEnabledFor(level):
            return
        fields = cls._server_event_log_fields(event)
        logger.log(
            level,
            "qwen_omni model event session=%s type=%s fields=%s",
            session_id or "qwen_omni",
            event_type,
            json.dumps(fields, ensure_ascii=False, sort_keys=True),
        )

    @staticmethod
    async def _send_json(ws: Any, payload: dict[str, Any]) -> None:
        """Serialise ``payload`` as JSON (non-ASCII kept verbatim) and send it."""
        await ws.send(json.dumps(payload, ensure_ascii=False))

    @staticmethod
    def _decode_message(message: str | bytes) -> dict[str, Any]:
        """Parse a websocket message (UTF-8 bytes or str) as JSON."""
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        return json.loads(message)

    @staticmethod
    def _event_id(session_id: str, suffix: str) -> str:
        """Return an event id ``<session>_<suffix>_<epoch ms>_<seq>``.

        The process-wide sequence number keeps ids unique even when several
        events are created within the same millisecond.
        """
        base = session_id or "qwen_omni"
        return f"{base}_{suffix}_{int(time.time() * 1000)}_{next(_EVENT_SEQ)}"

    @staticmethod
    def _error_message(event: dict[str, Any]) -> str:
        """Extract a readable message from an ``error`` event, falling back to the raw event."""
        error = event.get("error")
        if isinstance(error, dict):
            message = error.get("message") or error.get("msg") or error.get("code")
            if message:
                return str(message)
        if isinstance(error, str):
            return error
        return f"Qwen Omni error: {event}"

    @staticmethod
    def _is_active_response_error(message: str) -> bool:
        """True when the server rejected ``response.create`` because one is already running."""
        return "Conversation already has an active response" in message

    @staticmethod
    def _optional_bool(value: Any) -> bool | None:
        """Parse a boolean config value; unrecognised values yield None.

        Accepts bools, the integers 0/1, and the strings true/false, 1/0,
        yes/no (case-insensitive, surrounding whitespace ignored).
        """
        if value is None:
            return None
        if isinstance(value, bool):
            return value
        if isinstance(value, int) and value in (0, 1):
            return bool(value)
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"true", "1", "yes"}:
                return True
            if normalized in {"false", "0", "no"}:
                return False
        return None

    @staticmethod
    def _optional_float(value: Any) -> float | None:
        """``float(value)``, or None when missing or not convertible."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _optional_int(value: Any) -> int | None:
        """``int(value)``, or None when missing or not convertible."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    async def shutdown(self) -> None:
        """Close the active conversation websocket, if any."""
        ws = self._active_ws
        # Clear first so a failing close() does not leave a stale reference behind.
        self._active_ws = None
        if ws is not None:
            await ws.close()


@dataclass
class _QwenDeferredResponse:
    """A queued ``response.create`` request, optionally preceded by a conversation item."""

    response_payload: dict[str, Any]
    item_payload: dict[str, Any] | None = None
    # Set once the item has been sent so a retried request does not resend it.
    item_sent: bool = False


class _QwenResponseCoordinator:
    """Serialises client-initiated responses so only one is in flight at a time.

    ``_pending`` (guarded by ``_pending_condition``) holds queued requests;
    ``_idle``/``_current_response`` (guarded by ``_state_condition``) track
    whether a response — client- or server-initiated — is currently running.
    """

    def __init__(self) -> None:
        self._pending: deque[_QwenDeferredResponse] = deque()
        self._pending_condition = asyncio.Condition()
        self._state_condition = asyncio.Condition()
        self._idle = True
        self._closed = False
        self._current_response: _QwenDeferredResponse | None = None

    async def enqueue(self, request: _QwenDeferredResponse) -> None:
        """Append a request to the queue; ignored after ``close``."""
        async with self._pending_condition:
            if self._closed:
                return
            self._pending.append(request)
            self._pending_condition.notify()

    async def _prepend(self, request: _QwenDeferredResponse) -> None:
        """Put a request at the front of the queue (used for retries); ignored after ``close``."""
        async with self._pending_condition:
            if self._closed:
                return
            self._pending.appendleft(request)
            self._pending_condition.notify()

    async def next_request(self) -> _QwenDeferredResponse | None:
        """Wait for the next request; returns None once closed and drained."""
        async with self._pending_condition:
            while not self._pending and not self._closed:
                await self._pending_condition.wait()
            if self._pending:
                return self._pending.popleft()
            return None

    async def wait_idle(self) -> bool:
        """Wait until no response is in flight.

        Returns True when the caller may send, False when the coordinator was
        closed (``close`` also forces ``_idle`` to True, so the closed flag
        must be checked explicitly).
        """
        async with self._state_condition:
            while not self._idle and not self._closed:
                await self._state_condition.wait()
            return self._idle and not self._closed

    async def begin_client_response(self, request: _QwenDeferredResponse) -> None:
        """Mark ``request`` as the in-flight client response."""
        async with self._state_condition:
            self._idle = False
            self._current_response = request
            self._state_condition.notify_all()

    async def release_client_response(self, request: _QwenDeferredResponse) -> None:
        """Undo ``begin_client_response`` if ``request`` is still the current one."""
        async with self._state_condition:
            if self._current_response is request:
                self._current_response = None
                self._idle = True
                self._state_condition.notify_all()

    async def mark_response_started(self) -> None:
        """Record that some response (client or server initiated) is running."""
        async with self._state_condition:
            self._idle = False
            self._state_condition.notify_all()

    async def mark_response_done(self) -> None:
        """Record that the running response finished and wake any waiters."""
        async with self._state_condition:
            self._idle = True
            self._current_response = None
            self._state_condition.notify_all()
        async with self._pending_condition:
            self._pending_condition.notify_all()

    async def mark_active_response_error(self) -> None:
        """Handle an "active response" rejection.

        Re-queues the current client request at the front (without resending
        its item) and keeps the coordinator busy until the next ``response.done``.
        """
        retry: _QwenDeferredResponse | None = None
        async with self._state_condition:
            if self._current_response is not None:
                retry = self._current_response
                retry.item_sent = True
                self._current_response = None
            self._idle = False
            self._state_condition.notify_all()
        if retry is not None:
            await self._prepend(retry)

    async def close(self) -> None:
        """Stop accepting requests and release everyone waiting on the coordinator."""
        async with self._pending_condition:
            self._closed = True
            self._pending_condition.notify_all()
        async with self._state_condition:
            self._idle = True
            self._state_condition.notify_all()

    async def wait_all_responses_done(self, timeout: float) -> None:
        """Poll (every 10 ms) until the queue is empty and idle, or closed.

        Logs a warning instead of raising when ``timeout`` seconds elapse.
        """

        async def _wait() -> None:
            while True:
                async with self._pending_condition:
                    has_pending = bool(self._pending)
                    closed = self._closed
                async with self._state_condition:
                    idle = self._idle and self._current_response is None
                if closed or (not has_pending and idle):
                    return
                await asyncio.sleep(0.01)

        try:
            await asyncio.wait_for(_wait(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("qwen_omni timed out waiting for deferred response completion")


class _QwenTurnState:
    """Per-turn bookkeeping: question/reply ids, assistant text, and audio flag."""

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id
        self.turn_index = 0
        self.question_id = ""
        self.reply_id = ""
        self.assistant_text = ""
        self.has_audio = False

    @property
    def has_content(self) -> bool:
        """True once the turn produced audio or assistant text."""
        return self.has_audio or bool(self.assistant_text)

    def ensure_turn(self) -> None:
        """Start a new turn only if none is active."""
        if not self.question_id:
            self.start_next_turn()

    def start_next_turn(self) -> None:
        """Advance the turn counter and assign fresh ``_q<n>``/``_r<n>`` ids."""
        self.turn_index += 1
        self.question_id = f"{self.session_id}_q{self.turn_index}"
        self.reply_id = f"{self.session_id}_r{self.turn_index}"
        self.assistant_text = ""
        self.has_audio = False

    def reset(self) -> None:
        """Clear the current turn while keeping the turn counter."""
        self.question_id = ""
        self.reply_id = ""
        self.assistant_text = ""
        self.has_audio = False