{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from nb_007 import *\n", "\n", "import pandas as pd, re, html, os\n", "\n", "import warnings\n", "warnings.filterwarnings(\"ignore\", message=\"numpy.dtype size changed\")\n", "warnings.filterwarnings(\"ignore\", message=\"numpy.ufunc size changed\")\n", "\n", "import spacy\n", "from spacy.symbols import ORTH\n", "from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# IMDB" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reading the texts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('data/aclImdb/')\n", "CLAS_PATH = PATH/'clas'\n", "LM_PATH = PATH/'lm'\n", "os.makedirs(CLAS_PATH, exist_ok=True)\n", "os.makedirs(LM_PATH, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "BOS,FLD,UNK,PAD = 'xxbos','xxfld','xxunk','xxpad'\n", "TK_UP,TK_REP,TK_WREP = 'xxup','xxrep','xxwrep'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "CLASSES = ['neg', 'pos', 'unsup']\n", "\n", "def get_texts(path):\n", " texts,labels = [],[]\n", " for idx,label in enumerate(CLASSES):\n", " for fname in (path/label).glob('*.*'):\n", " texts.append(fname.open('r', encoding='utf8').read())\n", " labels.append(idx)\n", " return np.array(texts),np.array(labels)\n", "\n", "train_texts,train_labels = get_texts(PATH/'train')\n", "valid_texts,valid_labels = get_texts(PATH/'test')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_idx = np.random.permutation(len(train_texts))\n", "valid_idx = np.random.permutation(len(valid_texts))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_texts,train_labels = train_texts[train_idx],train_labels[train_idx]\n", "valid_texts,valid_labels = valid_texts[valid_idx],valid_labels[valid_idx]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df = pd.DataFrame({'text':train_texts, 'labels':train_labels}, columns=['labels','text'])\n", "valid_df = pd.DataFrame({'text':valid_texts, 'labels':valid_labels}, columns=['labels','text'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We put aside the unsup labels for the classification but keep them to finetune the language model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df[train_df['labels']!=2].to_csv(CLAS_PATH/'train.csv', header=False, index=False)\n", "valid_df.to_csv(CLAS_PATH/'valid.csv', header=False, index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_texts = np.concatenate([train_texts,valid_texts])\n", "idx = np.random.permutation(len(all_texts))\n", "cut = int(0.1 * len(idx))\n", "train_df = pd.DataFrame({'text':all_texts[idx[cut:]], 'labels':[0] * (len(all_texts)-cut)}, columns=['labels','text'])\n", "valid_df = pd.DataFrame({'text':all_texts[idx[:cut]], 'labels':[0] * cut}, columns=['labels','text'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df.to_csv(LM_PATH/'train.csv', header=False, index=False)\n", "valid_df.to_csv(LM_PATH/'valid.csv', header=False, index=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tokenization + Numericalization" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def partition(a:Collection, sz:int) -> List[Collection]: \n", " \"Splits iterables a in equal parts of size sz\"\n", " return [a[i:i+sz] for i in range(0, len(a), sz)]\n", "\n", "def partition_by_cores(a:Collection, n_cpus:int) -> List[Collection]:\n", " \"Split data equally among CPU cores\"\n", " return partition(a, len(a)//n_cpus + 1)\n", "\n", "def num_cpus() -> int:\n", " \"Return the number of CPUs in the system\"\n", " try:\n", " return len(os.sched_getaffinity(0))\n", " except AttributeError:\n", " return os.cpu_count()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class BaseTokenizer():\n", " \"Basic class for a tokenizer function.\"\n", " def __init__(self, lang:str):\n", " self.lang = lang\n", " \n", " def tokenizer(self, t:spacy.tokens.doc.Doc) -> List[str]: raise NotImplementedError\n", " def add_special_cases(self, toks:Collection[str]): raise NotImplementedError \n", "\n", "#export\n", "class SpacyTokenizer(BaseTokenizer):\n", " \"Little wrapper around a `spacy` tokenizer\"\n", " \n", " def __init__(self, lang:str):\n", " self.tok = spacy.load(lang)\n", " \n", " def tokenizer(self, t:spacy.tokens.doc.Doc) -> List[str]:\n", " return [t.text for t in self.tok.tokenizer(t)]\n", " \n", " def add_special_cases(self, toks:Collection[str]):\n", " for w in toks:\n", " self.tok.tokenizer.add_special_case(w, [{ORTH: w}])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_df = pd.read_csv(LM_PATH/'train.csv', header=None, chunksize=10)\n", "trn_df = next(train_df)\n", "test_tok = SpacyTokenizer('en')\n", "test_txt = trn_df.iloc[0][1]\n", "test_tok.tokenizer(test_txt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def sub_br(t:str) -> str:\n", " \"Replaces the
by \\n\"\n", " re_br = re.compile(r'<\\s*br\\s*/?>', re.IGNORECASE)\n", " return re_br.sub(\"\\n\", t)\n", "\n", "def spec_add_spaces(t:str) -> str:\n", " \"Add spaces between special characters\"\n", " return re.sub(r'([/#])', r' \\1 ', t)\n", "\n", "def rm_useless_spaces(t:str) -> str:\n", " \"Remove multiple spaces\"\n", " return re.sub(' {2,}', ' ', t)\n", "\n", "def replace_rep(t:str) -> str:\n", " \"Replace repetitions at the character level\"\n", " def _replace_rep(m:Collection[str]) -> str:\n", " c,cc = m.groups()\n", " return f' {TK_REP} {len(cc)+1} {c} '\n", " re_rep = re.compile(r'(\\S)(\\1{3,})')\n", " return re_rep.sub(_replace_rep, t)\n", " \n", "def replace_wrep(t:str) -> str:\n", " \"Replace word repetitions\"\n", " def _replace_wrep(m:Collection[str]) -> str:\n", " c,cc = m.groups()\n", " return f' {TK_WREP} {len(cc.split())+1} {c} '\n", " re_wrep = re.compile(r'(\\b\\w+\\W+)(\\1{3,})')\n", " return re_wrep.sub(_replace_wrep, t)\n", "\n", "def deal_caps(t:str) -> str:\n", " \"Replace words in all caps\"\n", " res = []\n", " for s in re.findall(r'\\w+|\\W+', t):\n", " res += ([f' {TK_UP} ',s.lower()] if (s.isupper() and (len(s)>2)) else [s.lower()])\n", " return ''.join(res)\n", "\n", "def fixup(x:str) -> str:\n", " \"List of replacements from html strings\"\n", " re1 = re.compile(r' +')\n", " x = x.replace('#39;', \"'\").replace('amp;', '&').replace('#146;', \"'\").replace(\n", " 'nbsp;', ' ').replace('#36;', '$').replace('\\\\n', \"\\n\").replace('quot;', \"'\").replace(\n", " '
', \"\\n\").replace('\\\\\"', '\"').replace('',UNK).replace(' @.@ ','.').replace(\n", " ' @-@ ','-').replace('\\\\', ' \\\\ ')\n", " return re1.sub(' ', html.unescape(x))\n", " \n", "default_rules = [fixup, replace_rep, replace_wrep, deal_caps, spec_add_spaces, rm_useless_spaces, sub_br]\n", "default_spec_tok = [BOS, FLD, UNK, PAD]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "\n", "ListRules = Collection[Callable[[str],str]]\n", "\n", "class Tokenizer():\n", " \"Puts together rules, a tokenizer function and a language to process text with multiprocessing\"\n", " def __init__(self, tok_fn:Callable=SpacyTokenizer, lang:str='en', rules:ListRules=None, \n", " special_cases:Collection[str]=None, n_cpus:int=None):\n", " self.tok_fn,self.lang,self.special_cases = tok_fn,lang,special_cases\n", " self.rules = rules if rules else default_rules\n", " self.special_cases = special_cases if special_cases else default_spec_tok\n", " self.n_cpus = n_cpus or num_cpus()//2\n", " \n", " def __repr__(self) -> str:\n", " res = f'Tokenizer {self.tok_fn.__name__} in {self.lang} with the following rules:\\n'\n", " for rule in self.rules: res += f' - {rule.__name__}\\n'\n", " return res\n", " \n", " def proc_text(self, t:str, tok:BaseTokenizer) -> List[str]:\n", " \"Processes one text\"\n", " for rule in self.rules: t = rule(t)\n", " return tok.tokenizer(t)\n", " \n", " def process_all_1(self, texts:Collection[str]) -> List[List[str]]:\n", " \"Processes a list of texts in one process\"\n", " tok = self.tok_fn(self.lang)\n", " if self.special_cases: tok.add_special_cases(self.special_cases)\n", " return [self.proc_text(t, tok) for t in texts]\n", "\n", " def process_all(self, texts:Collection[str]) -> List[List[str]]:\n", " \"Processes a list of texts in several processes\"\n", " if self.n_cpus <= 1: return self.process_all_1(texts)\n", " with ProcessPoolExecutor(self.n_cpus) as e:\n", " return sum(e.map(self.process_all_1, partition_by_cores(texts, self.n_cpus)), [])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer = Tokenizer(rules=default_rules, special_cases=[BOS, FLD, 'xxunk', 'xxpad'], n_cpus=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert sub_br('end

begins again')=='end \\n\\n begins again'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert spec_add_spaces('\\#%') == '\\\\ # %'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert rm_useless_spaces('this is') == 'this is'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert replace_rep('ffffffive .') == ' xxrep 6 f ive .'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert replace_wrep('five five five five .') == ' xxwrep 4 five .'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert deal_caps('ANGRY') == ' xxup angry'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_chunk_length(csv_name:PathOrStr, chunksize:int) -> int:\n", " \"Reads the number of chunks in a pandas `DataFrame`\"\n", " dfs = pd.read_csv(csv_name, header=None, chunksize=chunksize)\n", " l = 0\n", " for _ in dfs: l+=1\n", " return l\n", "\n", "def get_total_length(csv_name:PathOrStr, chunksize:int) -> int:\n", " \"Reads the the total length of a pandas `DataFrame`\"\n", " dfs = pd.read_csv(csv_name, header=None, chunksize=chunksize)\n", " l = 0\n", " for df in dfs: l+=len(df)\n", " return l" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def maybe_copy(old_fnames:Collection[PathOrStr], new_fnames:Collection[PathOrStr]):\n", " \"Copies the `old_fnames` to `new_fnames` location if new_fnames don't exist or are less recent.\"\n", " os.makedirs(os.path.dirname(new_fnames[0]), exist_ok=True)\n", " for old_fname,new_fname in zip(old_fnames, new_fnames): \n", " if not os.path.isfile(new_fname) or os.path.getmtime(new_fname) < os.path.getmtime(old_fname):\n", " shutil.copyfile(old_fname, new_fname)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "import hashlib\n", "\n", "Tokens = Collection[Collection[str]]\n", "\n", "class Vocab():\n", " \"Contains the correspondance between numbers and tokens and numericalizes\"\n", " \n", " def __init__(self, path:PathOrStr):\n", " self.itos = pickle.load(open(path/'itos.pkl', 'rb'))\n", " self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})\n", " \n", " def numericalize(self, t:Collection[str]) -> List[int]:\n", " \"Converts a list of tokens to their ids\"\n", " return [self.stoi[w] for w in t]\n", " \n", " def textify(self, nums:Collection[int]) -> List[str]:\n", " \"Converts a list of ids to their tokens\"\n", " return ' '.join([self.itos[i] for i in nums])\n", " \n", " @classmethod\n", " def create(cls, path:PathOrStr, tokens:Tokens, max_vocab:int, min_freq:int) -> 'Vocab':\n", " \"Create a vocabulary from a set of tokens.\"\n", " freq = Counter(p for o in tokens for p in o)\n", " itos = [o for o,c in freq.most_common(max_vocab) if c > min_freq]\n", " itos.insert(0, PAD)\n", " if UNK in itos: itos.remove(UNK)\n", " itos.insert(0, UNK)\n", " pickle.dump(itos, open(path/'itos.pkl', 'wb'))\n", " h = hashlib.sha1(np.array(itos))\n", " with open(path/'numericalize.log','w') as f: f.write(h.hexdigest())\n", " return cls(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "TextMtd = IntEnum('TextMtd', 'CSV TOK IDS')\n", "\n", "import shutil\n", "\n", "class TextDataset():\n", " \"Basic dataset for NLP tasks.\"\n", " \n", " def __init__(self, path:PathOrStr, tokenizer:Tokenizer, vocab:Vocab=None, max_vocab:int=60000, chunksize:int=10000, \n", " name:str='train', min_freq:int=2, n_labels:int=1, create_mtd:TextMtd=TextMtd.CSV, classes:Classes=None):\n", " self.path,self.tokenizer,self.max_vocab,self.min_freq = Path(path/'tmp'),tokenizer,max_vocab,min_freq\n", " self.chunksize,self.name,self.n_labels,self.create_mtd = chunksize,name,n_labels,create_mtd\n", " self.vocab=vocab\n", " os.makedirs(self.path, exist_ok=True)\n", " if not self.check_toks(): self.tokenize()\n", " if not self.check_ids(): self.numericalize()\n", " \n", " if self.vocab is None: self.vocab = Vocab(self.path)\n", " self.ids = np.load(self.path/f'{self.name}_ids.npy')\n", " if os.path.isfile(self.path/f'{self.name}_lbl.npy'):\n", " self.labels = np.load(self.path/f'{self.name}_lbl.npy')\n", " else: self.labels = np.zeros((len(self.ids),), dtype=np.int64)\n", " self.classes = classes if classes else np.unique(self.labels)\n", " \n", " def __getitem__(self, idx:int) -> Tuple[int,int]: return self.ids[idx],self.labels[idx]\n", " def __len__(self) -> int: return len(self.ids)\n", " \n", " def general_check(self, pre_files:Collection[PathOrStr], post_files:Collection[PathOrStr]):\n", " \"Checks that post_files exist and were modified after all the prefiles.\"\n", " if not np.all([os.path.isfile(fname) for fname in post_files]): return False\n", " for pre_file in pre_files:\n", " if os.path.getmtime(pre_file) > os.path.getmtime(post_files[0]): return False\n", " return True\n", " \n", " def check_ids(self) -> bool:\n", " \"Checks if a new numericalization is needed.\"\n", " if self.create_mtd >= TextMtd.IDS: return True\n", " if not self.general_check([self.tok_files[0],self.id_files[1]], self.id_files): return False\n", " itos = pickle.load(open(self.id_files[1], 'rb'))\n", " h = hashlib.sha1(np.array(itos))\n", " with open(self.id_files[2]) as f:\n", " if h.hexdigest() != f.read() or len(itos) > self.max_vocab + 2: return False\n", " toks,ids = np.load(self.tok_files[0]),np.load(self.id_files[0])\n", " if len(toks) != len(ids): return False\n", " return True\n", " \n", " def check_toks(self) -> bool:\n", " \"Checks if a new tokenization is needed.\"\n", " if self.create_mtd >= TextMtd.TOK: return True\n", " if not self.general_check([self.csv_file], self.tok_files): return False\n", " with open(self.tok_files[1]) as f:\n", " if repr(self.tokenizer) != f.read(): return False\n", " return True\n", " \n", " def tokenize(self):\n", " \"Tokenizes the texts in the csv file\"\n", " print(f'Tokenizing {self.name}. This might take a while so you should grab a coffee.')\n", " curr_len = get_chunk_length(self.csv_file, self.chunksize)\n", " dfs = pd.read_csv(self.csv_file, header=None, chunksize=self.chunksize)\n", " tokens,labels = [],[]\n", " for _ in progress_bar(range(curr_len), leave=False):\n", " df = next(dfs)\n", " lbls = df.iloc[:,range(self.n_labels)].values.astype(np.int64)\n", " texts = f'\\n{BOS} {FLD} 1 ' + df[self.n_labels].astype(str)\n", " for i in range(self.n_labels+1, len(df.columns)): \n", " texts += f' {FLD} {i-n_lbls} ' + df[i].astype(str)\n", " toks = self.tokenizer.process_all(texts)\n", " tokens += toks\n", " labels += list(np.squeeze(lbls))\n", " np.save(self.tok_files[0], np.array(tokens))\n", " np.save(self.path/f'{self.name}_lbl.npy', np.array(labels))\n", " with open(self.tok_files[1],'w') as f: f.write(repr(self.tokenizer))\n", " \n", " def numericalize(self):\n", " \"Numericalizes the tokens in the token file\"\n", " print(f'Numericalizing {self.name}.')\n", " toks = np.load(self.tok_files[0])\n", " if self.vocab is None: self.vocab = Vocab.create(self.path, toks, self.max_vocab, self.min_freq)\n", " ids = np.array([self.vocab.numericalize(t) for t in toks])\n", " np.save(self.id_files[0], ids)\n", " \n", " def clear(self):\n", " \"Removes temporary files\"\n", " files = [self.path/f'{self.name}_{suff}.npy' for suff in ['ids','tok','lbl']]\n", " files.append(self.path/f'{self.name}.csv')\n", " for file in files:\n", " if os.path.isfile(file): os.remove(file)\n", " \n", " @property\n", " def csv_file(self) -> Path: return self.path/f'{self.name}.csv'\n", " @property\n", " def tok_files(self) -> List[Path]: return [self.path/f'{self.name}_tok.npy', self.path/'tokenize.log']\n", " @property\n", " def id_files(self) -> List[Path]:\n", " return [self.path/f'{self.name}_ids.npy', self.path/'itos.pkl', self.path/'numericalize.log']\n", " \n", " @classmethod\n", " def from_ids(cls, folder:PathOrStr, name:str, id_suff:str='_ids', lbl_suff:str='_lbl', \n", " itos:str='itos.pkl', **kwargs) -> 'TextDataset':\n", " \"Creates a dataset from an id, a dictionary and label file.\"\n", " orig = [Path(folder/file) for file in [f'{name}{id_suff}.npy', f'{name}{lbl_suff}.npy', itos]]\n", " dest = [Path(folder)/'tmp'/file for file in [f'{name}_ids.npy', f'{name}_lbl.npy', 'itos.pkl']]\n", " maybe_copy(orig, dest)\n", " return cls(folder, None, name=name, create_mtd=TextMtd.IDS, **kwargs)\n", " \n", " @classmethod\n", " def from_tokens(cls, folder:PathOrStr, name:str, tok_suff:str='_tok', lbl_suff:str='_lbl', \n", " **kwargs) -> 'TextDataset':\n", " \"Creates a dataset from a token and label file.\"\n", " orig = [Path(folder/file) for file in [f'{name}{tok_suff}.npy', f'{name}{lbl_suff}.npy']]\n", " dest = [Path(folder)/'tmp'/file for file in [f'{name}_tok.npy', f'{name}_lbl.npy']]\n", " maybe_copy(orig, dest)\n", " return cls(folder, None, name=name, create_mtd=TextMtd.TOK, **kwargs)\n", " \n", " @classmethod\n", " def from_csv(cls, folder:PathOrStr, tokenizer:Tokenizer, name:str, **kwargs) -> 'TextDataset':\n", " \"Creates a dataset from texts in a csv file.\"\n", " orig = [Path(folder)/f'{name}.csv']\n", " dest = [Path(folder)/'tmp'/f'{name}.csv']\n", " maybe_copy(orig, dest)\n", " return cls(folder, tokenizer, name=name, **kwargs)\n", " \n", " @classmethod\n", " def from_folder(cls, folder:PathOrStr, tokenizer:Tokenizer, name:str, classes:Classes=None, \n", " shuffle:bool=True, **kwargs) -> 'TextDataset':\n", " \"Creates a dataset from a folder\"\n", " path = Path(folder)/'tmp'\n", " os.makedirs(path, exist_ok=True)\n", " if classes is None: classes = [cls.name for cls in find_classes(Path(folder)/name)]\n", " texts,labels = [],[]\n", " for idx,label in enumerate(classes):\n", " for fname in (Path(folder)/name/label).glob('*.*'):\n", " texts.append(fname.open('r', encoding='utf8').read())\n", " labels.append(idx)\n", " texts,labels = np.array(texts),np.array(labels)\n", " if shuffle:\n", " idx = np.random.permutation(len(texts))\n", " texts,labels = texts[idx],labels[idx]\n", " df = pd.DataFrame({'text':texts, 'labels':labels}, columns=['labels','text'])\n", " if os.path.isfile(path/f'{name}.csv'):\n", " if get_total_length(path/f'{name}.csv', 10000) != len(df):\n", " df.to_csv(path/f'{name}.csv', index=False, header=False)\n", " else: df.to_csv(path/f'{name}.csv', index=False, header=False)\n", " return cls(folder, tokenizer, name=name, classes=classes, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def extract_kwargs(names:Collection[str], kwargs:KWArgs):\n", " \"Extracts the keys in names from the kwargs.\"\n", " new_kwargs = {}\n", " for arg_name in names:\n", " if arg_name in kwargs:\n", " arg_val = kwargs.pop(arg_name)\n", " new_kwargs[arg_name] = arg_val\n", " return new_kwargs, kwargs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class LanguageModelLoader():\n", " \"Creates a dataloader with bptt slightly changing.\"\n", " def __init__(self, dataset:TextDataset, bs:int=64, bptt:int=70, backwards:bool=False):\n", " self.dataset,self.bs,self.bptt,self.backwards = dataset,bs,bptt,backwards\n", " self.data = self.batchify(np.concatenate(dataset.ids))\n", " self.first,self.i,self.iter = True,0,0\n", " self.n = len(self.data)\n", "\n", " def __iter__(self):\n", " self.i,self.iter = 0,0\n", " while self.i < self.n-1 and self.iter int: return (self.n-1) // self.bptt\n", "\n", " def batchify(self, data:np.ndarray) -> LongTensor:\n", " \"Splits the corpus in batches.\"\n", " nb = data.shape[0] // self.bs\n", " data = np.array(data[:nb*self.bs]).reshape(self.bs, -1).T\n", " if self.backwards: data=data[::-1]\n", " return LongTensor(data)\n", "\n", " def get_batch(self, i:int, seq_len:int) -> Tuple[LongTensor, LongTensor]:\n", " \"Creates a batch of a given seq_len\"\n", " seq_len = min(seq_len, len(self.data) - 1 - i)\n", " return self.data[i:i+seq_len], self.data[i+1:i+1+seq_len].contiguous().view(-1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def standard_data(datasets:Collection[DatasetBase], path:PathOrStr, **kwargs) -> DataBunch:\n", " \"Simply creates a `DataBunch` from the `datasets`\"\n", " return DataBunch.create(*datasets, path=path, **kwargs)\n", "\n", "def lm_data(datasets:Collection[TextDataset], path:PathOrStr, **kwargs) -> DataBunch:\n", " \"Creates a `DataBunch` from the `datasets` for language modelling\"\n", " dataloaders = [LanguageModelLoader(ds, **kwargs) for ds in datasets]\n", " return DataBunch(*dataloaders, path=path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "DataFunc = Callable[[Collection[DatasetBase], PathOrStr, KWArgs], DataBunch]\n", "\n", "def data_from_textids(path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None, \n", " data_func:DataFunc=standard_data, itos:str='itos.pkl', **kwargs) -> DataBunch:\n", " \"Creates a `DataBunch` from ids, labels and a dictionary.\"\n", " path=Path(path)\n", " txt_kwargs, kwargs = extract_kwargs(['max_vocab', 'chunksize', 'min_freq', 'n_labels', 'id_suff', 'lbl_suff'], kwargs)\n", " train_ds = TextDataset.from_ids(path, train, itos=itos, **txt_kwargs)\n", " datasets = [train_ds, TextDataset.from_ids(path, valid, itos=itos, **txt_kwargs)]\n", " if test: datasets.append(TextDataset.from_ids(path, test, itos=itos, **txt_kwargs))\n", " return data_func(datasets, path, **kwargs)\n", "\n", "def data_from_texttokens(path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None, \n", " data_func:DataFunc=standard_data, vocab:Vocab=None, **kwargs) -> DataBunch:\n", " \"Creates a `DataBunch` from tokens and labels.\"\n", " path=Path(path)\n", " txt_kwargs, kwargs = extract_kwargs(['max_vocab', 'chunksize', 'min_freq', 'n_labels', 'tok_suff', 'lbl_suff'], kwargs)\n", " train_ds = TextDataset.from_tokens(path, train, vocab=vocab, **txt_kwargs)\n", " datasets = [train_ds, TextDataset.from_tokens(path, valid, vocab=train_ds.vocab, **txt_kwargs)]\n", " if test: datasets.append(TextDataset.from_tokens(path, test, vocab=train_ds.vocab, **txt_kwargs))\n", " return data_func(datasets, path, **kwargs)\n", "\n", "def data_from_textcsv(path:PathOrStr, tokenizer:Tokenizer, train:str='train', valid:str='valid', test:Optional[str]=None, \n", " data_func:DataFunc=standard_data, vocab:Vocab=None, **kwargs) -> DataBunch:\n", " \"Creates a `DataBunch` from texts in csv files.\"\n", " path=Path(path)\n", " txt_kwargs, kwargs = extract_kwargs(['max_vocab', 'chunksize', 'min_freq', 'n_labels'], kwargs)\n", " train_ds = TextDataset.from_csv(path, tokenizer, train, vocab=vocab, **txt_kwargs)\n", " datasets = [train_ds, TextDataset.from_csv(path, tokenizer, valid, vocab=train_ds.vocab, **txt_kwargs)]\n", " if test: datasets.append(TextDataset.from_csv(path, tokenizer, test, vocab=train_ds.vocab, **txt_kwargs))\n", " return data_func(datasets, path, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer = Tokenizer(rules=default_rules, special_cases=[BOS, FLD, 'xxunk', 'xxpad'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = data_from_textcsv(LM_PATH, tokenizer, chunksize=1000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = data_from_texttokens(LM_PATH, chunksize=1000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = data_from_textids(LM_PATH, chunksize=1000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#data[0].clear()\n", "#data[1].clear()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def make_sample(dir_name='tst_folders'):\n", " os.makedirs(PATH/dir_name, exist_ok=True)\n", " PATH1 = PATH/dir_name\n", " for name,name1 in zip(['train', 'valid'],['train', 'test']):\n", " os.makedirs(PATH1/name, exist_ok=True)\n", " for clas in ['neg', 'pos']:\n", " os.makedirs(PATH1/name/clas, exist_ok=True)\n", " fnames = list((PATH/name1/clas).glob('*.txt'))\n", " for i in range(2000):\n", " shutil.copy(fnames[i], PATH1/name/clas/fnames[i].name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#make_sample()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = TextDataset.from_folder(PATH/'tst_folders', tokenizer, 'train', chunksize=1000)\n", "valid_ds = TextDataset.from_folder(PATH/'tst_folders', tokenizer, 'valid', chunksize=1000, vocab=train_ds.vocab)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def data_from_textfolder(path:PathOrStr, tokenizer:Tokenizer, train:str='train', valid:str='valid', test:Optional[str]=None, \n", " shuffle:bool=True, data_func:DataFunc=standard_data, vocab:Vocab=None, **kwargs):\n", " \"Creates a `DataBunch` from text files in folders.\"\n", " path=Path(path)\n", " txt_kwargs, kwargs = extract_kwargs(['max_vocab', 'chunksize', 'min_freq', 'n_labels'], kwargs)\n", " train_ds = TextDataset.from_folder(path, tokenizer, train, shuffle=shuffle, vocab=vocab, **txt_kwargs)\n", " datasets = [train_ds, TextDataset.from_folder(path, tokenizer, valid, classes=train_ds.classes, \n", " shuffle=shuffle, vocab=train_ds.vocab, **txt_kwargs)]\n", " if test: datasets.append(TextDataset.from_folder(path, tokenizer, test, classes=train_ds.classes, \n", " shuffle=shuffle, vocab=train_ds.vocab, **txt_kwargs))\n", " return data_func(datasets, path, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = data_from_textfolder(PATH/'tst_folders', tokenizer, chunksize=1000)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }