import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
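
# Illustrative lookups (shown for reference, not executed):
#
#     >>> TO_LANGUAGE_CODE["mandarin"]
#     'zh'
#     >>> TO_LANGUAGE_CODE["english"]
#     'en'
#
# `get_tokenizer` below uses this table to normalize full language names and
# aliases into the two-letter codes that key `LANGUAGES`.
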

@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int, ...] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self) -> int:
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if (token := self.special_tokens.get(f"<|{language}|>")) is not None:
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int, ...]:
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str, ...]:
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int, ...]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int, ...]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens
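
# Sketch of the start-of-transcript sequence built in `__post_init__` above.
# The ids shown assume the standard multilingual vocabulary; other vocab files
# will differ:
#
#     >>> tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
#     >>> tokenizer.sot_sequence  # <|startoftranscript|><|en|><|transcribe|>
#     (50258, 50259, 50359)
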

@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    with open(vocab_path) as vocab_file:
        ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (line.split() for line in vocab_file if line)
        }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )
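
# Resulting id layout (sketch, for num_languages=99): the BPE ranks occupy
# [0, len(ranks)), followed in order by <|endoftext|>, <|startoftranscript|>,
# the 99 language tags, <|translate|>, <|transcribe|>, <|startoflm|>,
# <|startofprev|>, <|nospeech|>, <|notimestamps|>, and finally 1501 timestamp
# tokens <|0.00|> through <|30.00|> in 0.02-second steps.
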

@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
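

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's API: round-trip a string
    # through the multilingual tokenizer. Assumes assets/multilingual.tiktoken
    # is present next to this file.
    tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
    tokens = tokenizer.encode(" Hello world")
    print(tokens)
    print(tokenizer.decode(tokens))

    # `decode()` silently drops timestamp tokens; `decode_with_timestamps()`
    # renders them inline, e.g. "<|0.00|> Hello world<|1.08|>".
    stamped = [tokenizer.timestamp_begin, *tokens, tokenizer.timestamp_begin + 54]
    print(tokenizer.decode_with_timestamps(stamped))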