import asyncio import json import logging import grpc from inference.core.registry import PluginRegistry from inference.core.types import ( AudioChunk, ImageFrame, VoiceLLMDialogContextItem, VoiceLLMInputEvent, VoiceLLMSessionConfig, ) from inference.generated import common_pb2, voice_llm_pb2, voice_llm_pb2_grpc from inference.plugins.voice_llm.base import VoiceCheckError, VoiceLLMPlugin logger = logging.getLogger(__name__) def _audio_chunk_to_pb(ac: AudioChunk) -> common_pb2.AudioChunk: return common_pb2.AudioChunk( data=ac.data, sample_rate=ac.sample_rate, channels=ac.channels, format=ac.format or "", is_final=ac.is_final, timestamp_ms=ac.timestamp_ms, ) def _session_config_from_pb(cfg: voice_llm_pb2.VoiceLLMConfig) -> VoiceLLMSessionConfig: # Generated *_pb2.py is gitignored; tolerate stale stubs missing dialog_context (field 8). raw_ctx = getattr(cfg, "dialog_context", None) or [] return VoiceLLMSessionConfig( session_id=cfg.session_id, provider=getattr(cfg, "provider", ""), character_id=getattr(cfg, "character_id", ""), character_dir=getattr(cfg, "character_dir", ""), system_prompt=cfg.system_prompt, voice=cfg.voice, bot_name=cfg.bot_name, speaking_style=cfg.speaking_style, welcome_message=cfg.welcome_message, dialog_context=[ VoiceLLMDialogContextItem( role=item.role, text=item.text, timestamp=item.timestamp, ) for item in raw_ctx ], ) class VoiceLLMGRPCService(voice_llm_pb2_grpc.VoiceLLMServiceServicer): def __init__(self, registry: PluginRegistry) -> None: self.registry = registry def _get_plugin(self, provider: str = "") -> VoiceLLMPlugin: provider = provider.strip() if provider: try: plugin = self.registry.get(f"persona.{provider}") except KeyError: try: plugin = self.registry.get(f"omni.{provider}") except KeyError: try: plugin = self.registry.get(f"voice_llm.{provider}") except KeyError: plugin = None else: plugin = self.registry.get_by_category("persona") if plugin is None: plugin = self.registry.get_by_category("omni") if plugin is None: plugin = self.registry.get_by_category("voice_llm") if plugin is None: suffix = f" for provider {provider!r}" if provider else "" raise RuntimeError(f"No omni model plugin initialized{suffix}") return plugin @staticmethod def _input_event_from_pb(msg: voice_llm_pb2.VoiceLLMInput) -> VoiceLLMInputEvent | None: which = msg.WhichOneof("input") if which == "audio": return VoiceLLMInputEvent(audio=msg.audio.data) if which == "text": return VoiceLLMInputEvent(text=msg.text) if which == "image": return VoiceLLMInputEvent( image=ImageFrame( data=msg.image.data, mime_type=msg.image.mime_type, width=msg.image.width, height=msg.image.height, source=msg.image.source, timestamp_ms=msg.image.timestamp_ms, frame_seq=msg.image.frame_seq, ) ) return None async def Converse(self, request_iterator, context): """Stream user audio/text to an omni model (e.g. Doubao); yield audio + transcripts only. Avatar video is produced by AvatarService.GenerateStream; the Go orchestrator composes omni model output with that stream. """ # Phase 1: read the config message and first input event. session_config: VoiceLLMSessionConfig | None = None first_input: VoiceLLMInputEvent | None = None async for msg in request_iterator: which = msg.WhichOneof("input") if which == "config": session_config = _session_config_from_pb(msg.config) logger.debug( "Omni session config: voice=%r bot_name=%r system_prompt=%r welcome=%r", session_config.voice, session_config.bot_name, session_config.system_prompt[:50] if session_config.system_prompt else "", session_config.welcome_message[:50] if session_config.welcome_message else "", ) continue first_input = self._input_event_from_pb(msg) break if session_config is None: session_config = VoiceLLMSessionConfig() if first_input is not None: session_config.input_mode = "text" if first_input.text else "keep_alive" plugin = self._get_plugin(session_config.provider) # Phase 2: stream remaining messages as unified input events. async def input_stream(): if first_input is not None: yield first_input async for msg in request_iterator: event = self._input_event_from_pb(msg) if event is not None: yield event async for event in plugin.converse_stream(input_stream(), session_config=session_config): output = voice_llm_pb2.VoiceLLMOutput(is_final=event.is_final) if event.audio: output.audio.CopyFrom(_audio_chunk_to_pb(event.audio)) if event.transcript: output.transcript = event.transcript if event.user_transcript: output.user_transcript = event.user_transcript if event.question_id: output.question_id = event.question_id if event.reply_id: output.reply_id = event.reply_id if event.barge_in: output.barge_in = True if event.task_event: output.task_event_json = json.dumps(event.task_event, ensure_ascii=False) yield output async def CheckVoice(self, request, context): try: session_config = _session_config_from_pb(request.config) plugin = self._get_plugin(session_config.provider) except Exception as exc: context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(exc)) return voice_llm_pb2.CheckVoiceResponse(ok=False) try: await asyncio.wait_for( plugin.check_voice(session_config=session_config), timeout=4.5, ) return voice_llm_pb2.CheckVoiceResponse(ok=True) except VoiceCheckError as exc: return voice_llm_pb2.CheckVoiceResponse( ok=False, provider_error=str(exc), ) except asyncio.TimeoutError: context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED) context.set_details("voice check timed out") return voice_llm_pb2.CheckVoiceResponse(ok=False) except Exception as exc: context.set_code(grpc.StatusCode.INTERNAL) context.set_details(str(exc)) return voice_llm_pb2.CheckVoiceResponse(ok=False) async def Interrupt(self, request, context): plugins = [ *self.registry.get_all_by_category("omni"), *self.registry.get_all_by_category("voice_llm"), ] if not plugins: plugin = self._get_plugin() plugins = [plugin] await asyncio.gather(*(plugin.interrupt() for plugin in plugins)) return voice_llm_pb2.InterruptResponse(success=True)