import argparse import functools import importlib.util import re import time from pathlib import Path from typing import Optional import gradio as gr import numpy as np import torch import torchaudio from transformers import AutoModel, AutoProcessor # Disable the broken cuDNN SDPA backend torch.backends.cuda.enable_cudnn_sdp(False) # Keep these enabled as fallbacks torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) torch.backends.cuda.enable_math_sdp(True) MODEL_PATH = "OpenMOSS-Team/MOSS-TTSD-v1.0" CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer" DEFAULT_ATTN_IMPLEMENTATION = "auto" DEFAULT_MAX_NEW_TOKENS = 2000 MIN_SPEAKERS = 1 MAX_SPEAKERS = 5 PRESET_REF_AUDIO_S1 = "assets/audio/reference_02_s1.wav" PRESET_REF_AUDIO_S2 = "assets/audio/reference_02_s2.wav" PRESET_PROMPT_TEXT_S1 = ( "[S1] In short, we embarked on a mission to make America great again for all Americans." ) PRESET_PROMPT_TEXT_S2 = ( "[S2] NVIDIA reinvented computing for the first time after 60 years. In fact, Erwin at IBM knows quite " "well that the computer has largely been the same since the 60s." ) PRESET_DIALOGUE_TEXT = ( "[S1] Listen, let's talk business. China. I'm hearing things.\n" "People are saying they're catching up. Fast. What's the real scoop?\n" "Their AI, is it a threat?\n" "[S2] Well, the pace of innovation there is extraordinary, honestly.\n" "They have the researchers, and they have the drive.\n" "[S1] Extraordinary? I don't like that. I want us to be extraordinary.\n" "Are they winning?\n" "[S2] I wouldn't say winning, but their progress is very promising.\n" "They are building massive clusters. They're very determined.\n" "[S1] Promising. There it is. I hate that word.\n" "When China is promising, it means we're losing.\n" "It's a disaster, Jensen. A total disaster." ) PRESET_EXAMPLES = [ { "name": "Quick Start | reference_02_s1/s2", "speaker_count": 2, "s1_audio": PRESET_REF_AUDIO_S1, "s1_prompt": PRESET_PROMPT_TEXT_S1, "s2_audio": PRESET_REF_AUDIO_S2, "s2_prompt": PRESET_PROMPT_TEXT_S2, "dialogue_text": PRESET_DIALOGUE_TEXT, } ] PRESET_DISPLAY_FIELDS = [ ("Speaker Count", "speaker_count"), ("S1 Reference Audio (Optional)", "s1_audio"), ("S1 Prompt Text (Required with reference audio)", "s1_prompt"), ("S2 Reference Audio (Optional)", "s2_audio"), ("S2 Prompt Text (Required with reference audio)", "s2_prompt"), ("Dialogue Text", "dialogue_text"), ] def _build_preset_table_rows(): rows = [] row_to_preset = [] for preset_idx, preset in enumerate(PRESET_EXAMPLES): for field_name, field_key in PRESET_DISPLAY_FIELDS: value = str(preset.get(field_key, "")) if field_key == "dialogue_text": value = value.replace("\n", " ").strip() if len(value) > 120: value = value[:120] + " ..." rows.append([field_name, value]) row_to_preset.append(preset_idx) return rows, row_to_preset PRESET_TABLE_ROWS, PRESET_TABLE_ROW_TO_PRESET = _build_preset_table_rows() def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None: requested_norm = (requested or "").strip().lower() if requested_norm in {"none"}: return None if requested_norm not in {"", "auto"}: return requested # Prefer FlashAttention 2 when package + device conditions are met. if ( device.type == "cuda" and importlib.util.find_spec("flash_attn") is not None and dtype in {torch.float16, torch.bfloat16} ): major, _ = torch.cuda.get_device_capability(device) if major >= 8: return "flash_attention_2" # CUDA fallback: use PyTorch SDPA kernels. if device.type == "cuda": return "sdpa" # CPU fallback. return "eager" @functools.lru_cache(maxsize=1) def load_backend(model_path: str, codec_path: str, device_str: str, attn_implementation: str): device = torch.device(device_str if torch.cuda.is_available() else "cpu") dtype = torch.bfloat16 if device.type == "cuda" else torch.float32 resolved_attn_implementation = resolve_attn_implementation( requested=attn_implementation, device=device, dtype=dtype, ) processor = AutoProcessor.from_pretrained( model_path, trust_remote_code=True, codec_path=codec_path, ) if hasattr(processor, "audio_tokenizer"): processor.audio_tokenizer = processor.audio_tokenizer.to(device) processor.audio_tokenizer.eval() model_kwargs = { "trust_remote_code": True, "torch_dtype": dtype, } if resolved_attn_implementation: model_kwargs["attn_implementation"] = resolved_attn_implementation model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device) model.eval() sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000)) return model, processor, device, sample_rate def _resample_wav(wav: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor: if int(orig_sr) == int(target_sr): return wav new_num_samples = int(round(wav.shape[-1] * float(target_sr) / float(orig_sr))) if new_num_samples <= 0: raise ValueError(f"Invalid resample length from {orig_sr}Hz to {target_sr}Hz.") return torch.nn.functional.interpolate( wav.unsqueeze(0), size=new_num_samples, mode="linear", align_corners=False, ).squeeze(0) def _load_audio(audio_path: str) -> tuple[torch.Tensor, int]: path = Path(audio_path).expanduser() if not path.exists(): raise FileNotFoundError(f"Reference audio not found: {path}") wav, sr = torchaudio.load(str(path)) if wav.numel() == 0: raise ValueError(f"Reference audio is empty: {path}") if wav.shape[0] > 1: wav = wav.mean(dim=0, keepdim=True) return wav, int(sr) def normalize_text(text: str) -> str: text = re.sub(r"\[(\d+)\]", r"[S\1]", text) remove_chars = "【】《》()『』「」" '"-_“”~~‘’' segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " ")) processed_parts = [] for seg in segments: seg = seg.strip() if not seg: continue matched = re.match(r"^(\[S\d+\])\s*(.*)", seg) tag, content = matched.groups() if matched else ("", seg) content = re.sub(f"[{re.escape(remove_chars)}]", "", content) content = re.sub(r"哈{2,}", "[笑]", content) content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE) content = content.replace("——", ",") content = content.replace("……", ",") content = content.replace("...", ",") content = content.replace("⸺", ",") content = content.replace("―", ",") content = content.replace("—", ",") content = content.replace("…", ",") internal_punct_map = str.maketrans( {";": ",", ";": ",", ":": ",", ":": ",", "、": ","} ) content = content.translate(internal_punct_map) content = content.strip() content = re.sub(r"([,。?!,.?!])[,。?!,.?!]+", r"\1", content) if len(content) > 1: last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1]) body = content[:-1].replace("。", ",") content = body + last_ch processed_parts.append({"tag": tag, "content": content}) if not processed_parts: return "" merged_lines = [] current_tag = processed_parts[0]["tag"] current_content = [processed_parts[0]["content"]] for part in processed_parts[1:]: if part["tag"] == current_tag and current_tag: current_content.append(part["content"]) else: merged_lines.append(f"{current_tag}{''.join(current_content)}".strip()) current_tag = part["tag"] current_content = [part["content"]] merged_lines.append(f"{current_tag}{''.join(current_content)}".strip()) return "".join(merged_lines).replace("‘", "'").replace("’", "'") def _validate_dialogue_text(dialogue_text: str, speaker_count: int) -> str: text = (dialogue_text or "").strip() if not text: raise ValueError("Please enter dialogue text.") tags = re.findall(r"\[S(\d+)\]", text) if not tags: raise ValueError("Dialogue must include speaker tags like [S1], [S2], ...") max_tag = max(int(t) for t in tags) if max_tag > speaker_count: raise ValueError( f"Dialogue contains [S{max_tag}], but speaker count is set to {speaker_count}." ) return text def update_speaker_panels(speaker_count: int): count = int(speaker_count) count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, count)) return [gr.update(visible=(idx < count)) for idx in range(MAX_SPEAKERS)] def apply_preset_selection(evt: gr.SelectData): if evt is None or evt.index is None: return ( gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *[gr.update() for _ in range(MAX_SPEAKERS)], ) if isinstance(evt.index, (tuple, list)): row_idx = int(evt.index[0]) else: row_idx = int(evt.index) if row_idx < 0 or row_idx >= len(PRESET_TABLE_ROW_TO_PRESET): return ( gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *[gr.update() for _ in range(MAX_SPEAKERS)], ) preset_idx = PRESET_TABLE_ROW_TO_PRESET[row_idx] if preset_idx < 0 or preset_idx >= len(PRESET_EXAMPLES): return ( gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), *[gr.update() for _ in range(MAX_SPEAKERS)], ) preset = PRESET_EXAMPLES[preset_idx] panel_updates = update_speaker_panels(int(preset["speaker_count"])) return ( gr.update(value=int(preset["speaker_count"])), gr.update(value=str(preset["s1_audio"])), gr.update(value=str(preset["s1_prompt"])), gr.update(value=str(preset["s2_audio"])), gr.update(value=str(preset["s2_prompt"])), gr.update(value=str(preset["dialogue_text"])), *panel_updates, ) def _merge_consecutive_speaker_tags(text: str) -> str: segments = re.split(r"(?=\[S\d+\])", text) if not segments: return text merged_parts = [] current_tag = None for seg in segments: seg = seg.strip() if not seg: continue matched = re.match(r"^(\[S\d+\])\s*(.*)", seg, re.DOTALL) if not matched: merged_parts.append(seg) continue tag, content = matched.groups() if tag == current_tag: merged_parts.append(content) else: current_tag = tag merged_parts.append(f"{tag}{content}") return "".join(merged_parts) def _normalize_prompt_text(prompt_text: str, speaker_id: int) -> str: text = (prompt_text or "").strip() if not text: raise ValueError(f"S{speaker_id} prompt text is empty.") expected_tag = f"[S{speaker_id}]" if not text.lstrip().startswith(expected_tag): text = f"{expected_tag} {text}" return text def _build_prefixed_text( dialogue_text: str, prompt_text_map: dict[int, str], cloned_speakers: list[int], ) -> str: prompt_prefix = "".join([prompt_text_map[speaker_id] for speaker_id in cloned_speakers]) return _merge_consecutive_speaker_tags(prompt_prefix + dialogue_text) def _encode_reference_audio_codes( processor, clone_wavs: list[torch.Tensor], cloned_speakers: list[int], speaker_count: int, sample_rate: int, ) -> list[Optional[torch.Tensor]]: encoded_list = processor.encode_audios_from_wav(clone_wavs, sampling_rate=sample_rate) reference_audio_codes: list[Optional[torch.Tensor]] = [None for _ in range(speaker_count)] for speaker_id, audio_codes in zip(cloned_speakers, encoded_list): reference_audio_codes[speaker_id - 1] = audio_codes return reference_audio_codes def build_conversation( dialogue_text: str, reference_audio_codes: list[Optional[torch.Tensor]], prompt_audio: torch.Tensor | None, processor, ): if prompt_audio is None: return [[processor.build_user_message(text=dialogue_text)]], "generation", "Generation" user_message = processor.build_user_message( text=dialogue_text, reference=reference_audio_codes, ) return ( [ [ user_message, processor.build_assistant_message(audio_codes_list=[prompt_audio]), ], ], "continuation", "voice_clone_and_continuation", ) def run_inference(speaker_count: int, *all_inputs): speaker_count = int(speaker_count) speaker_count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, speaker_count)) reference_audio_values = all_inputs[:MAX_SPEAKERS] prompt_text_values = all_inputs[MAX_SPEAKERS : 2 * MAX_SPEAKERS] dialogue_text = all_inputs[2 * MAX_SPEAKERS] text_normalize, sample_rate_normalize, temperature, top_p, top_k, repetition_penalty, max_new_tokens, model_path, codec_path, device, attn_implementation = all_inputs[ 2 * MAX_SPEAKERS + 1 : ] started_at = time.monotonic() model, processor, torch_device, sample_rate = load_backend( model_path=str(model_path), codec_path=str(codec_path), device_str=str(device), attn_implementation=str(attn_implementation), ) text_normalize = bool(text_normalize) sample_rate_normalize = bool(sample_rate_normalize) normalized_dialogue = str(dialogue_text or "").strip() if text_normalize: normalized_dialogue = normalize_text(normalized_dialogue) normalized_dialogue = _validate_dialogue_text(normalized_dialogue, speaker_count) cloned_speakers: list[int] = [] loaded_clone_wavs: list[tuple[torch.Tensor, int]] = [] prompt_text_map: dict[int, str] = {} for idx in range(speaker_count): ref_audio = reference_audio_values[idx] prompt_text = str(prompt_text_values[idx] or "").strip() has_reference = bool(ref_audio) has_prompt_text = bool(prompt_text) if has_reference != has_prompt_text: raise ValueError( f"S{idx + 1} must provide both reference audio and prompt text together." ) if has_reference: speaker_id = idx + 1 ref_audio_path = str(ref_audio) cloned_speakers.append(speaker_id) loaded_clone_wavs.append(_load_audio(ref_audio_path)) prompt_text_map[speaker_id] = _normalize_prompt_text(prompt_text, speaker_id) prompt_audio: Optional[torch.Tensor] = None reference_audio_codes: list[Optional[torch.Tensor]] = [] conversation_text = normalized_dialogue if cloned_speakers: conversation_text = _build_prefixed_text( dialogue_text=normalized_dialogue, prompt_text_map=prompt_text_map, cloned_speakers=cloned_speakers, ) if text_normalize: conversation_text = normalize_text(conversation_text) conversation_text = _validate_dialogue_text(conversation_text, speaker_count) if sample_rate_normalize: min_sr = min(sr for _, sr in loaded_clone_wavs) else: min_sr = None clone_wavs: list[torch.Tensor] = [] for wav, orig_sr in loaded_clone_wavs: processed_wav = wav current_sr = int(orig_sr) if min_sr is not None: processed_wav = _resample_wav(processed_wav, current_sr, int(min_sr)) current_sr = int(min_sr) processed_wav = _resample_wav(processed_wav, current_sr, sample_rate) clone_wavs.append(processed_wav) reference_audio_codes = _encode_reference_audio_codes( processor=processor, clone_wavs=clone_wavs, cloned_speakers=cloned_speakers, speaker_count=speaker_count, sample_rate=sample_rate, ) concat_prompt_wav = torch.cat(clone_wavs, dim=-1) prompt_audio = processor.encode_audios_from_wav([concat_prompt_wav], sampling_rate=sample_rate)[0] conversations, mode, mode_name = build_conversation( dialogue_text=conversation_text, reference_audio_codes=reference_audio_codes, prompt_audio=prompt_audio, processor=processor, ) batch = processor(conversations, mode=mode) input_ids = batch["input_ids"].to(torch_device) attention_mask = batch["attention_mask"].to(torch_device) with torch.no_grad(): outputs = model.generate( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=int(max_new_tokens), audio_temperature=float(temperature), audio_top_p=float(top_p), audio_top_k=int(top_k), audio_repetition_penalty=float(repetition_penalty), ) messages = processor.decode(outputs) if not messages or messages[0] is None: raise RuntimeError("The model did not return a decodable audio result.") audio = messages[0].audio_codes_list[0] if isinstance(audio, torch.Tensor): audio_np = audio.detach().float().cpu().numpy() else: audio_np = np.asarray(audio, dtype=np.float32) if audio_np.ndim > 1: audio_np = audio_np.reshape(-1) audio_np = audio_np.astype(np.float32, copy=False) clone_summary = "none" if not cloned_speakers else ",".join([f"S{i}" for i in cloned_speakers]) elapsed = time.monotonic() - started_at status = ( f"Done | mode={mode_name} | speakers={speaker_count} | cloned={clone_summary} | elapsed={elapsed:.2f}s | " f"text_normalize={text_normalize}, sample_rate_normalize={sample_rate_normalize} | " f"max_new_tokens={int(max_new_tokens)}, " f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, " f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}" ) return (sample_rate, audio_np), status def build_demo(args: argparse.Namespace): custom_css = """ :root { --bg: #f6f7f8; --panel: #ffffff; --ink: #111418; --muted: #4d5562; --line: #e5e7eb; --accent: #0f766e; } .gradio-container { background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%); color: var(--ink); } .app-card { border: 1px solid var(--line); border-radius: 16px; background: var(--panel); padding: 14px; } .app-title { font-size: 22px; font-weight: 700; margin-bottom: 6px; letter-spacing: 0.2px; } .app-subtitle { color: var(--muted); font-size: 14px; margin-bottom: 8px; } #output_panel { overflow: hidden !important; } #output_audio { padding-bottom: 24px; margin-bottom: 0; overflow: hidden !important; } #output_audio > .wrap, #output_audio .wrap, #output_audio .audio-container, #output_audio .block { overflow: hidden !important; } #output_audio .audio-container { padding-bottom: 10px; min-height: 96px; } #output_audio_spacer { height: 12px; } #output_status { margin-top: 0; } #run-btn { background: var(--accent); border: none; } """ with gr.Blocks(title="MOSS-TTSD Demo", css=custom_css) as demo: gr.Markdown( """