import argparse import functools import importlib.util from pathlib import Path import re import time import orjson import os import gradio as gr import numpy as np import torch from transformers import AutoModel, AutoProcessor # Disable the broken cuDNN SDPA backend torch.backends.cuda.enable_cudnn_sdp(False) # Keep these enabled as fallbacks torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_math_sdp(True) MODEL_PATH = "OpenMOSS-Team/MOSS-TTS-v1.5" DEFAULT_ATTN_IMPLEMENTATION = "auto" DEFAULT_MAX_NEW_TOKENS = 4096 CONTINUATION_NOTICE = ( "Continuation mode is active. Make sure the reference audio transcript is prepended to the input text." ) MODE_CLONE = "Clone" MODE_CONTINUE = "Continuation" MODE_CONTINUE_CLONE = "Continuation + Clone" ZH_TOKENS_PER_CHAR = 3.098411951313033 EN_TOKENS_PER_CHAR = 0.8673376262755219 REFERENCE_AUDIO_DIR = Path(__file__).resolve().parent.parent / "assets" / "audio" EXAMPLE_TEXTS_JSONL_PATH = Path(__file__).resolve().parent.parent / "assets" / "text" / "moss_tts_example_texts.jsonl" LANGUAGE_TAG_AUTO = "Auto (omit)" LANGUAGE_TAG_CHOICES = [ LANGUAGE_TAG_AUTO, "Chinese", "Cantonese", "English", "Arabic", "Czech", "Danish", "Dutch", "Finnish", "French", "German", "Greek", "Hebrew", "Hindi", "Hungarian", "Italian", "Japanese", "Korean", "Macedonian", "Malay", "Persian (Farsi)", "Polish", "Portuguese", "Romanian", "Russian", "Spanish", "Swahili", "Swedish", "Tagalog", "Thai", "Turkish", "Vietnamese", ] def _parse_example_id(example_id: str) -> tuple[str, int] | None: matched = re.fullmatch(r"(zh|en)/(\d+)", (example_id or "").strip()) if matched is None: return None return matched.group(1), int(matched.group(2)) def _resolve_reference_audio_path(language: str, index: int) -> Path | None: stem_candidates = [f"reference_{language}_{index}"] for stem in stem_candidates: for ext in (".wav", ".mp3"): audio_path = REFERENCE_AUDIO_DIR / f"{stem}{ext}" if audio_path.exists(): return audio_path return None def build_example_rows() -> list[tuple[str, str, str]]: rows: list[tuple[str, str, str]] = [] with open(EXAMPLE_TEXTS_JSONL_PATH, "rb") as f: for line in f: if not line.strip(): continue sample = orjson.loads(line) parsed = _parse_example_id(sample.get("id", "")) if parsed is None: continue language, index = parsed text = str(sample.get("text", "")).strip() audio_path = _resolve_reference_audio_path(language, index) if audio_path is None: continue rows.append((sample['role'], str(audio_path), text)) return rows EXAMPLE_ROWS = build_example_rows() @functools.lru_cache(maxsize=1) def load_backend(model_path: str, device_str: str, attn_implementation: str): device = torch.device(device_str if torch.cuda.is_available() else "cpu") dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 resolved_attn_implementation = resolve_attn_implementation( requested=attn_implementation, device=device, dtype=dtype, ) processor = AutoProcessor.from_pretrained( model_path, trust_remote_code=True, ) if hasattr(processor, "audio_tokenizer"): processor.audio_tokenizer = processor.audio_tokenizer.to(device) model_kwargs = { "trust_remote_code": True, "torch_dtype": dtype, } if resolved_attn_implementation: model_kwargs["attn_implementation"] = resolved_attn_implementation model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device) model.eval() sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000)) return model, processor, device, sample_rate def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None: requested_norm = (requested or "").strip().lower() if requested_norm in {"none"}: return None if requested_norm not in {"", "auto"}: return requested # Prefer FlashAttention 2 when package + device conditions are met. if ( device.type == "cuda" and importlib.util.find_spec("flash_attn") is not None and dtype in {torch.float16, torch.bfloat16} ): major, _ = torch.cuda.get_device_capability(device) if major >= 8: return "flash_attention_2" # CUDA fallback: use PyTorch SDPA kernels. if device.type == "cuda": return "sdpa" # CPU fallback. return "eager" def detect_text_language(text: str) -> str: zh_chars = len(re.findall(r"[\u4e00-\u9fff]", text)) en_chars = len(re.findall(r"[A-Za-z]", text)) if zh_chars == 0 and en_chars == 0: return "en" return "zh" if zh_chars >= en_chars else "en" def supports_duration_control(mode_with_reference: str) -> bool: return mode_with_reference not in {MODE_CONTINUE, MODE_CONTINUE_CLONE} def estimate_duration_tokens(text: str) -> tuple[str, int, int, int]: normalized = text or "" effective_len = max(len(normalized), 1) language = detect_text_language(normalized) factor = ZH_TOKENS_PER_CHAR if language == "zh" else EN_TOKENS_PER_CHAR default_tokens = max(1, int(effective_len * factor)) min_tokens = max(1, int(default_tokens * 0.5)) max_tokens = max(min_tokens, int(default_tokens * 1.5)) return language, default_tokens, min_tokens, max_tokens def update_duration_controls( enabled: bool, text: str, current_tokens: float | int | None, mode_with_reference: str, ): if not supports_duration_control(mode_with_reference): return ( gr.update(visible=False), "Duration control is disabled for Continuation modes.", gr.update(value=False, interactive=False), ) checkbox_update = gr.update(interactive=True) if not enabled: return gr.update(visible=False), "Duration control is disabled.", checkbox_update language, default_tokens, min_tokens, max_tokens = estimate_duration_tokens(text) # Slider is initialized with value=1 as a placeholder; treat it as "unset" # so first-time estimation uses the computed default instead of clamping to min. if current_tokens is None or int(current_tokens) == 1: slider_value = default_tokens else: slider_value = int(current_tokens) slider_value = max(min_tokens, min(max_tokens, slider_value)) language_label = "Chinese" if language == "zh" else "English" hint = ( f"Duration control enabled | detected language: {language_label} | " f"default={default_tokens}, range=[{min_tokens}, {max_tokens}]" ) return ( gr.update( visible=True, minimum=min_tokens, maximum=max_tokens, value=slider_value, step=1, ), hint, checkbox_update, ) def normalize_language_tag(language_tag: str | None) -> str | None: language_tag = (language_tag or "").strip() if not language_tag or language_tag == LANGUAGE_TAG_AUTO: return None return language_tag def build_conversation( text: str, reference_audio: str | None, mode_with_reference: str, expected_tokens: int | None, language_tag: str | None, processor, ): text = (text or "").strip() if not text: raise ValueError("Please enter text to synthesize.") user_kwargs = {"text": text} normalized_language = normalize_language_tag(language_tag) if normalized_language is not None: user_kwargs["language"] = normalized_language if expected_tokens is not None: user_kwargs["tokens"] = int(expected_tokens) if not reference_audio: conversations = [[processor.build_user_message(**user_kwargs)]] return conversations, "generation", "Direct Generation" if mode_with_reference == MODE_CLONE: clone_kwargs = dict(user_kwargs) clone_kwargs["reference"] = [reference_audio] conversations = [[processor.build_user_message(**clone_kwargs)]] return conversations, "generation", MODE_CLONE if mode_with_reference == MODE_CONTINUE: conversations = [ [ processor.build_user_message(**user_kwargs), processor.build_assistant_message(audio_codes_list=[reference_audio]), ] ] return conversations, "continuation", MODE_CONTINUE continue_clone_kwargs = dict(user_kwargs) continue_clone_kwargs["reference"] = [reference_audio] conversations = [ [ processor.build_user_message(**continue_clone_kwargs), processor.build_assistant_message(audio_codes_list=[reference_audio]), ] ] return conversations, "continuation", MODE_CONTINUE_CLONE def render_mode_hint(reference_audio: str | None, mode_with_reference: str): if not reference_audio: return "Current mode: **Direct Generation** (no reference audio uploaded)" if mode_with_reference == MODE_CLONE: return "Current mode: **Clone** (speaker timbre will be cloned from the reference audio)" return f"Current mode: **{mode_with_reference}** \n> {CONTINUATION_NOTICE}" def apply_example_selection( mode_with_reference: str, duration_control_enabled: bool, duration_tokens: int, evt: gr.SelectData, ): if evt is None or evt.index is None: return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() if isinstance(evt.index, (tuple, list)): row_idx = int(evt.index[0]) else: row_idx = int(evt.index) if row_idx < 0 or row_idx >= len(EXAMPLE_ROWS): return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() _, audio_path, example_text = EXAMPLE_ROWS[row_idx] duration_slider_update, duration_hint, duration_checkbox_update = update_duration_controls( duration_control_enabled, example_text, duration_tokens, mode_with_reference, ) return ( audio_path, example_text, render_mode_hint(audio_path, mode_with_reference), duration_slider_update, duration_hint, duration_checkbox_update, ) def run_inference( text: str, reference_audio: str | None, mode_with_reference: str, duration_control_enabled: bool, duration_tokens: int, language_tag: str | None, temperature: float, top_p: float, top_k: int, repetition_penalty: float, model_path: str, device: str, attn_implementation: str, max_new_tokens: int, ): started_at = time.monotonic() model, processor, torch_device, sample_rate = load_backend( model_path=model_path, device_str=device, attn_implementation=attn_implementation, ) duration_enabled = bool(duration_control_enabled and supports_duration_control(mode_with_reference)) expected_tokens = int(duration_tokens) if duration_enabled else None conversations, mode, mode_name = build_conversation( text=text, reference_audio=reference_audio, mode_with_reference=mode_with_reference, expected_tokens=expected_tokens, language_tag=language_tag, processor=processor, ) batch = processor(conversations, mode=mode) input_ids = batch["input_ids"].to(torch_device) attention_mask = batch["attention_mask"].to(torch_device) with torch.no_grad(): outputs = model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=int(max_new_tokens), audio_temperature=float(temperature), audio_top_p=float(top_p), audio_top_k=int(top_k), audio_repetition_penalty=float(repetition_penalty), ) messages = processor.decode(outputs) if not messages or messages[0] is None: raise RuntimeError("The model did not return a decodable audio result.") audio = messages[0].audio_codes_list[0] if isinstance(audio, torch.Tensor): audio_np = audio.detach().float().cpu().numpy() else: audio_np = np.asarray(audio, dtype=np.float32) if audio_np.ndim > 1: audio_np = audio_np.reshape(-1) audio_np = audio_np.astype(np.float32, copy=False) elapsed = time.monotonic() - started_at normalized_language = normalize_language_tag(language_tag) status = ( f"Done | mode: {mode_name} | language={normalized_language or 'auto'} | " f"elapsed: {elapsed:.2f}s | " f"max_new_tokens={int(max_new_tokens)}, " f"expected_tokens={expected_tokens if expected_tokens is not None else 'off'}, " f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, " f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}" ) return (sample_rate, audio_np), status def build_demo(args: argparse.Namespace): custom_css = """ :root { --bg: #f6f7f8; --panel: #ffffff; --ink: #111418; --muted: #4d5562; --line: #e5e7eb; --accent: #0f766e; } .gradio-container { background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%); color: var(--ink); } .app-card { border: 1px solid var(--line); border-radius: 16px; background: var(--panel); padding: 14px; } .app-title { font-size: 22px; font-weight: 700; margin-bottom: 6px; letter-spacing: 0.2px; } .app-subtitle { color: var(--muted); font-size: 14px; margin-bottom: 8px; } #output_audio { padding-bottom: 12px; margin-bottom: 8px; overflow: hidden !important; } #output_audio > .wrap { overflow: hidden !important; } #output_audio audio { margin-bottom: 6px; } #run-btn { background: var(--accent); border: none; } """ with gr.Blocks(title="MOSS-TTS Demo", css=custom_css) as demo: gr.Markdown( """
MOSS-TTS v1.5
Direct Generation, Clone, Continuation, Continuation + Clone, language tags, and inline pause markers
""" ) with gr.Row(equal_height=False): with gr.Column(scale=3): text = gr.Textbox( label="Text", lines=9, placeholder="Enter text to synthesize. In continuation modes, prepend the reference audio transcript.", ) reference_audio = gr.Audio( label="Reference Audio (Optional)", type="filepath", ) mode_with_reference = gr.Radio( choices=[MODE_CLONE, MODE_CONTINUE, MODE_CONTINUE_CLONE], value=MODE_CLONE, label="Mode with Reference Audio", info="If no reference audio is uploaded, Direct Generation will be used automatically.", ) mode_hint = gr.Markdown(render_mode_hint(None, MODE_CLONE)) language_tag = gr.Dropdown( choices=LANGUAGE_TAG_CHOICES, value=LANGUAGE_TAG_AUTO, label="Language Tag", info="Optional for v1.5. Set this when the input language is known, especially outside Chinese and English.", ) duration_control_enabled = gr.Checkbox( value=False, label="Enable Duration Control (Expected Audio Tokens)", ) duration_tokens = gr.Slider( minimum=1, maximum=1, step=1, value=1, label="expected_tokens", visible=False, ) duration_hint = gr.Markdown("Duration control is disabled.") with gr.Accordion("Sampling Parameters (Audio)", open=True): temperature = gr.Slider( minimum=0.1, maximum=3.0, step=0.05, value=1.7, label="temperature", ) top_p = gr.Slider( minimum=0.1, maximum=1.0, step=0.01, value=0.8, label="top_p", ) top_k = gr.Slider( minimum=1, maximum=200, step=1, value=25, label="top_k", ) repetition_penalty = gr.Slider( minimum=0.8, maximum=2.0, step=0.05, value=1.0, label="repetition_penalty", ) max_new_tokens = gr.Slider( minimum=256, maximum=8192, step=128, value=DEFAULT_MAX_NEW_TOKENS, label="max_new_tokens", ) run_btn = gr.Button("Generate Speech", variant="primary", elem_id="run-btn") with gr.Column(scale=2): output_audio = gr.Audio(label="Output Audio", type="numpy", elem_id="output_audio") status = gr.Textbox(label="Status", lines=4, interactive=False) examples_table = gr.Dataframe( headers=["Reference Speech", "Example Text"], value=[[name, text] for name, _, text in EXAMPLE_ROWS], datatype=["str", "str"], row_count=(len(EXAMPLE_ROWS), "fixed"), col_count=(2, "fixed"), interactive=False, wrap=True, label="Examples (click a row to fill inputs)", ) reference_audio.change( fn=render_mode_hint, inputs=[reference_audio, mode_with_reference], outputs=[mode_hint], ) mode_with_reference.change( fn=render_mode_hint, inputs=[reference_audio, mode_with_reference], outputs=[mode_hint], ) duration_control_enabled.change( fn=update_duration_controls, inputs=[duration_control_enabled, text, duration_tokens, mode_with_reference], outputs=[duration_tokens, duration_hint, duration_control_enabled], ) text.change( fn=update_duration_controls, inputs=[duration_control_enabled, text, duration_tokens, mode_with_reference], outputs=[duration_tokens, duration_hint, duration_control_enabled], ) mode_with_reference.change( fn=update_duration_controls, inputs=[duration_control_enabled, text, duration_tokens, mode_with_reference], outputs=[duration_tokens, duration_hint, duration_control_enabled], ) examples_table.select( fn=apply_example_selection, inputs=[mode_with_reference, duration_control_enabled, duration_tokens], outputs=[ reference_audio, text, mode_hint, duration_tokens, duration_hint, duration_control_enabled, ], ) run_btn.click( fn=lambda text, reference_audio, mode_with_reference, duration_control_enabled, duration_tokens, language_tag, temperature, top_p, top_k, repetition_penalty, max_new_tokens: run_inference( text=text, reference_audio=reference_audio, mode_with_reference=mode_with_reference, duration_control_enabled=duration_control_enabled, duration_tokens=duration_tokens, language_tag=language_tag, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty, model_path=args.model_path, device=args.device, attn_implementation=args.attn_implementation, max_new_tokens=max_new_tokens, ), inputs=[ text, reference_audio, mode_with_reference, duration_control_enabled, duration_tokens, language_tag, temperature, top_p, top_k, repetition_penalty, max_new_tokens, ], outputs=[output_audio, status], ) return demo def main(): parser = argparse.ArgumentParser(description="MossTTS Gradio Demo") parser.add_argument("--model_path", type=str, default=MODEL_PATH) parser.add_argument("--device", type=str, default="cuda:0") parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION) parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7860) parser.add_argument("--share", action="store_true") args = parser.parse_args() runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu") runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32 args.attn_implementation = resolve_attn_implementation( requested=args.attn_implementation, device=runtime_device, dtype=runtime_dtype, ) or "none" print(f"[INFO] Using attn_implementation={args.attn_implementation}", flush=True) # Preload model/processor at startup to avoid first-request cold start latency. preload_started_at = time.monotonic() print( f"[Startup] Preloading backend: model={args.model_path}, device={args.device}, attn={args.attn_implementation}", flush=True, ) load_backend( model_path=args.model_path, device_str=args.device, attn_implementation=args.attn_implementation, ) print( f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s", flush=True, ) demo = build_demo(args) demo.queue(max_size=16, default_concurrency_limit=1).launch( server_name=args.host, server_port=args.port, share=args.share, ) if __name__ == "__main__": main()