{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.torch_core import *\n", "from fastai.basic_data import *\n", "from fastai.datasets import *\n", "from fastai.text.learner import RNNLearner\n", "from fastai.text.transform import PAD, UNK, FLD, Tokenizer\n", "from concurrent.futures import ProcessPoolExecutor, as_completed" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def save_texts(fname:PathOrStr, texts:Collection[str]):\n", " with open(fname, 'w') as f:\n", " for t in texts: f.write(f'{t}\\n')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Vocab():\n", " \"Contain the correspondance between numbers and tokens and numericalize.\"\n", "\n", " def __init__(self, itos:Dict[int,str]):\n", " self.itos = itos\n", " self.stoi = collections.defaultdict(int,{v:k for k,v in enumerate(self.itos)})\n", "\n", " def numericalize(self, t:Collection[str]) -> List[int]:\n", " \"Convert a list of tokens `t` to their ids.\"\n", " return [self.stoi[w] for w in t]\n", "\n", " def textify(self, nums:Collection[int]) -> List[str]:\n", " \"Convert a list of `nums` to their tokens.\"\n", " return '_'.join([self.itos[i] for i in nums])\n", "\n", " @classmethod\n", " def create(cls, tokens:Tokens, max_vocab:int, min_freq:int) -> 'Vocab':\n", " \"Create a vocabulary from a set of tokens.\"\n", " freq = Counter(p for o in tokens for p in o)\n", " itos = [o for o,c in freq.most_common(max_vocab) if c > min_freq]\n", " itos.insert(0, PAD)\n", " if UNK in itos: itos.remove(UNK)\n", " itos.insert(0, UNK)\n", " return cls(itos)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TextBase(LabelDataset):\n", " def __init__(self, x:Collection[Any], labels:Collection[Union[int,float]]=None, classes:Collection[Any]=None):\n", " super().__init__(classes=classes)\n", " self.x = np.array(x)\n", " self.y = np.zeros(len(x)) if labels is None else np.array(labels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_exensions = ['.txt']\n", "\n", "class NumericalizedDataset(TextBase):\n", " \"To directly create a text datasets from `ids` and `labels`.\"\n", " def __init__(self, vocab:Vocab, ids:Collection[Collection[int]], labels:Collection[Union[int,float]]=None,\n", " classes:Collection[Any]=None):\n", " super().__init__(ids, labels, classes)\n", " self.vocab, self.vocab_size = vocab, len(vocab.itos)\n", " self.loss_func = F.cross_entropy if len(self.y.shape) <= 1 else F.binary_cross_entropy_with_logits\n", " \n", " def get_text_item(self, idx):\n", " return self.vocab.textify(self.x[idx]), self.classes[self.y[idx]]\n", " \n", " def save(self, path:Path, name:str):\n", " os.makedirs(path, exist_ok=True)\n", " np.save(path/f'{name}_ids.npy', self.x)\n", " np.save(path/f'{name}_lbl.npy', self.y)\n", " pickle.dump(self.vocab.itos, open(path/'itos.pkl', 'wb'))\n", " save_texts(path/'classes.txt', self.classes)\n", " \n", " @classmethod\n", " def load(cls, path:Path, name:str):\n", " vocab = Vocab(pickle.load(open(path/f'itos.pkl', 'rb')))\n", " x,y = np.load(path/f'{name}_ids.npy'), np.load(path/f'{name}_lbl.npy')\n", " classes = loadtxt_str(path/'classes.txt')\n", " return cls(vocab, x, y, classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TokenizedDataset(TextBase):\n", " \n", " def __init__(self, tokens:Collection[Collection[str]], labels:Collection[Union[int,float]]=None, \n", " classes:Collection[Any]=None):\n", " super().__init__(tokens, labels, classes)\n", " \n", " def save(self, path:Path, name:str):\n", " os.makedirs(path, exist_ok=True)\n", " np.save(path/f'name_tok.npy', self.x)\n", " np.save(path/f'name_lbl.npy', self.y)\n", " np.savetxt(path/'classes.txt', self.classes.as_type(str))\n", " \n", " def numericalize(self, vocab:Vocab=None, max_vocab:int=60000, min_freq:int=2):\n", " vocab = ifnone(vocab, Vocab.create(self.x, max_vocab, min_freq))\n", " ids = np.array([vocab.numericalize(t) for t in self.x])\n", " return NumericalizedDataset(vocab, ids, self.y, self.classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TextDataset(TextBase):\n", " \"Basic dataset for NLP tasks.\"\n", "\n", " def __init__(self, texts:Collection[str], labels:Collection[Union[int,float]]=None, \n", " classes:Collection[Any]=None):\n", " super().__init__(texts, labels, classes)\n", "\n", " @classmethod\n", " def from_df(cls, df:DataFrame, classes:Collection[Any]=None, n_labels:int=1, txt_cols:Collection[Union[int,str]]=None, \n", " label_cols:Collection[Union[int,str]]=None, mark_fields:bool=True) -> 'TextDataset':\n", " \"Create a `TextDataset` from the texts in a dataframe\"\n", " label_cols = ifnone(label_cols, list(range(n_labels)))\n", " if classes is None:\n", " if len(label_cols) == 0: classes = [0]\n", " elif len(label_cols) == 1: classes = df[0].unique()\n", " else: classes = label_cols\n", " lbl_type = np.float32 if len(label_cols) > 1 else np.int64\n", " lbls = df[label_cols].values.astype(lbl_type) if (len(label_cols) > 0) else [0] * len(df)\n", " txt_cols = ifnone(txt_cols, list(range(len(label_cols),len(df.columns))))\n", " texts = f'{FLD} {1} ' + df[txt_cols[0]].astype(str) if mark_fields else df[txt_cols[0]].astype(str)\n", " for i, col in enumerate(txt_cols[1:]): \n", " texts += (f' {FLD} {i+2} ' if mark_fields else ' ') + df[col].astype(str)\n", " return cls(texts.values, np.squeeze(lbls), classes)\n", "\n", " @staticmethod\n", " def _folder_files(folder:Path, label:str, extensions:Collection[str]=text_exensions)->Tuple[str,str]:\n", " \"From `folder` return texts in files and labels. The labels are all `label`.\"\n", " fnames = get_files(folder, extensions='.txt')\n", " texts = []\n", " for f in fnames:\n", " with open(f,'r') as f: texts.append(f.readlines())\n", " return texts,[label]*len(texts)\n", " \n", " @classmethod\n", " def from_folder(cls, path:PathOrStr, classes:Collection[Any]=None, \n", " extensions:Collection[str]=text_exensions) -> 'TextDataset':\n", " \"Create a `TextDataset` from the text files in a folder.\"\n", " path = Path(path)\n", " classes = ifnone(classes, [cls.name for cls in find_classes(path)])\n", " texts, labels = [], []\n", " for cl in classes:\n", " t,l = self._folder_files(path/cl, cl, extensions=extensions)\n", " fexts+=t; labels+=l\n", " keep[cl] = len(t)\n", " classes = [cl for cl in classes if keep[cl]]\n", " return cls(texts, labels, classes)\n", " \n", " @classmethod\n", " def from_one_folder(cls, path:PathOrStr, classes:Collection[Any], shuffle:bool=True, \n", " extensions:Collection[str]=text_exensions) -> 'TextDataset':\n", " \"Create a dataset from one folder, labelled `classes[0]` (used for the test set).\"\n", " path = Path(path)\n", " text,labels = self._folder_files(path, classes[0], extensions=extensions)\n", " return cls(texts, labels, classes)\n", " \n", " def tokenize(self, tokenizer:Tokenizer=None, chunksize:int=10000):\n", " tokenizer = ifnone(tokenizer, Tokenizer())\n", " tokens = []\n", " for i in progress_bar(range(0,len(self.x),chunksize), leave=False):\n", " tokens += tokenizer.process_all(self.x[i:i+chunksize])\n", " return TokenizedDataset(tokens, self.y, self.classes)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.IMDB_SAMPLE)\n", "df = pd.read_csv(path/'train.csv', header=None)\n", "ds = TextDataset.from_df(df, classes=['negative', 'positive'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = ds.tokenize().numericalize()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds.get_text_item(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds.save(path/'tmp', 'train')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds = NumericalizedDataset.load(path/'tmp', 'train')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_ds.get_text_item(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LanguageModelLoader():\n", " \"Create a dataloader with bptt slightly changing.\"\n", " def __init__(self, dataset:TextDataset, bs:int=64, bptt:int=70, backwards:bool=False):\n", " self.dataset,self.bs,self.bptt,self.backwards = dataset,bs,bptt,backwards\n", " self.data = self.batchify(np.concatenate(dataset.x))\n", " self.first,self.i,self.iter = True,0,0\n", " self.n = len(self.data)\n", " self.num_workers = 0\n", "\n", " def __iter__(self):\n", " if getattr(self.dataset, 'item', None) is not None:\n", " yield LongTensor(getattr(self.dataset, 'item')).unsqueeze(1),LongTensor([0])\n", " self.i,self.iter = 0,0\n", " while self.i < self.n-1 and self.iter int: return (self.n-1) // self.bptt\n", "\n", " def batchify(self, data:np.ndarray) -> LongTensor:\n", " \"Split the corpus `data` in batches.\"\n", " nb = data.shape[0] // self.bs\n", " data = np.array(data[:nb*self.bs]).reshape(self.bs, -1).T\n", " if self.backwards: data=data[::-1].copy()\n", " return LongTensor(data)\n", "\n", " def get_batch(self, i:int, seq_len:int) -> Tuple[LongTensor, LongTensor]:\n", " \"Create a batch at `i` of a given `seq_len`.\"\n", " seq_len = min(seq_len, len(self.data) - 1 - i)\n", " return self.data[i:i+seq_len], self.data[i+1:i+1+seq_len].contiguous().view(-1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SortSampler(Sampler):\n", " \"Go through the text data by order of length.\"\n", "\n", " def __init__(self, data_source:NPArrayList, key:KeyFunc): self.data_source,self.key = data_source,key\n", " def __len__(self) -> int: return len(self.data_source)\n", " def __iter__(self):\n", " return iter(sorted(range_of(self.data_source), key=self.key, reverse=True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SortishSampler(Sampler):\n", " \"Go through the text data by order of length with a bit of randomness.\"\n", "\n", " def __init__(self, data_source:NPArrayList, key:KeyFunc, bs:int):\n", " self.data_source,self.key,self.bs = data_source,key,bs\n", "\n", " def __len__(self) -> int: return len(self.data_source)\n", "\n", " def __iter__(self):\n", " idxs = np.random.permutation(len(self.data_source))\n", " sz = self.bs*50\n", " ck_idx = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]\n", " sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])\n", " sz = self.bs\n", " ck_idx = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]\n", " max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,\n", " ck_idx[0],ck_idx[max_ck] = ck_idx[max_ck],ck_idx[0] # then make sure it goes first.\n", " sort_idx = np.concatenate(np.random.permutation(ck_idx[1:]))\n", " sort_idx = np.concatenate((ck_idx[0], sort_idx))\n", " return iter(sort_idx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def pad_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True) -> Tuple[LongTensor, LongTensor]:\n", " \"Function that collect samples and adds padding.\"\n", " max_len = max([len(s[0]) for s in samples])\n", " res = torch.zeros(max_len, len(samples)).long() + pad_idx\n", " for i,s in enumerate(samples):\n", " if pad_first: res[-len(s[0]):,i] = LongTensor(s[0])\n", " else: res[:len(s[0]):,i] = LongTensor(s[0])\n", " return res, tensor([s[1] for s in samples])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _parse_kwargs(kwargs):\n", " txt_kwargs, kwargs = extract_kwargs(['n_labels', 'txt_cols', 'label_cols'], kwargs)\n", " tok_kwargs, kwargs = extract_kwargs(['chunksize'], kwargs)\n", " num_kwargs, kwargs = extract_kwargs(['max_vocab', 'min_freq'], kwargs)\n", " return txt_kwargs, tok_kwargs, num_kwargs, kwargs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TextDataBunch(DataBunch):\n", " \n", " def save(self, cache_name:str='tmp'):\n", " cache_path = self.path/cache_name\n", " pickle.dump(self.train_ds.vocab.itos, open(cache_path/f'itos.pkl', 'wb'))\n", " np.save(cache_path/f'train_ids.npy', self.train_ds.x)\n", " np.save(cache_path/f'train_lbl.npy', self.train_ds.y)\n", " np.save(cache_path/f'valid_ids.npy', self.valid_ds.x)\n", " np.save(cache_path/f'valid_lbl.npy', self.valid_ds.y)\n", " if self.test_dl is not None: np.save(cache_path/f'test_ids.npy', self.test_ds.x)\n", " save_texts(cache_path/'classes.txt', self.train_ds.classes)\n", " \n", " @classmethod\n", " def from_ids(cls, path:PathOrStr, vocab:Vocab, trn_ids:Collection[Collection[int]], val_ids:Collection[Collection[int]], \n", " tst_ids:Collection[Collection[int]]=None, trn_lbls:Collection[Union[int,float]]=None, \n", " val_lbls:Collection[Union[int,float]]=None, classes:Collection[Any]=None, **kwargs) -> DataBunch:\n", " \"Create a `TextDataBunch` from ids, labels and a dictionary.\"\n", " train_ds = NumericalizedDataset(vocab, trn_ids, trn_lbls, classes)\n", " datasets = [train_ds, NumericalizedDataset(vocab, val_ids, val_lbls, classes)]\n", " if tst_ids is not None: datasets.append(NumericalizedDataset(vocab, tst_ids, None, classes))\n", " return cls.create(datasets, path, **kwargs)\n", "\n", " @classmethod\n", " def load(cls, path:PathOrStr, **kwargs):\n", " cache_path = Path(path)/'tmp'\n", " vocab = Vocab(pickle.load(open(cache_path/f'itos.pkl', 'rb')))\n", " trn_ids,trn_lbls = np.load(cache_path/f'train_ids.npy'), np.load(cache_path/f'train_lbl.npy')\n", " val_ids,val_lbls = np.load(cache_path/f'valid_ids.npy'), np.load(cache_path/f'valid_lbl.npy')\n", " tst_ids = np.load(cache_path/f'test_ids.npy') if os.path.isfile(cache_path/f'test_ids.npy') else None\n", " classes = loadtxt_str(cache_path/'classes.txt')\n", " return cls.from_ids(path, vocab, trn_ids, val_ids, tst_ids, trn_lbls, val_lbls, classes, **kwargs)\n", "\n", " @classmethod\n", " def from_tokens(cls, path:PathOrStr, trn_tok:Collection[Collection[str]], trn_lbls:Collection[Union[int,float]],\n", " val_tok:Collection[Collection[str]], val_lbls:Collection[Union[int,float]], vocab:Vocab=None, \n", " tst_tok:Collection[Collection[str]]=None, classes:Collection[Any]=None, **kwargs) -> DataBunch:\n", " \"Create a `TextDataBunch` from tokens and labels.\"\n", " num_kwargs, kwargs = extract_kwargs(['max_vocab', 'min_freq'], kwargs)\n", " train_ds = TokenizedDataset(trn_tok, trn_lbls, classes).numericalize(vocab, **num_kwargs)\n", " datasets = [train_ds, TokenizedDataset(val_tok, val_lbls, classes).numericalize(vocab)]\n", " if test: datasets.append(TokenizedDataset(tst_tok, [0]*len(tst_tok), classes).numericalize(vocab))\n", " return cls.create(datasets, path, **kwargs)\n", " \n", " @classmethod\n", " def from_df(cls, path:PathOrStr, train_df:DataFrame, valid_df:DataFrame, test_df:Optional[DataFrame]=None, \n", " tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, **kwargs) -> DataBunch:\n", " \"Create a `TextDataBunch` from DataFrames.\"\n", " txt_kwargs, tok_kwargs, num_kwargs, kwargs = _parse_kwargs(kwargs)\n", " datasets = [(TextDataset.from_df(train_df, classes, **txt_kwargs)\n", " .tokenize(tokenizer, **tok_kwargs)\n", " .numericalize(vocab, **num_kwargs))]\n", " dfs = [valid_df] if test_df is None else [valid_df, test_df]\n", " for df in dfs:\n", " datasets.append((TextDataset.from_df(df, classes, **txt_kwargs)\n", " .tokenize(tokenizer, **tok_kwargs)\n", " .numericalize(datasets[0].vocab, **num_kwargs)))\n", " return cls.create(datasets, path, **kwargs)\n", "\n", " @classmethod\n", " def from_csv(cls, path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None,\n", " tokenizer:Tokenizer=None, vocab:Vocab=None, classes:Collection[str]=None, **kwargs) -> DataBunch:\n", " \"Create a `TextDataBunch` from texts in csv files.\"\n", " header = 'infer' if 'txt_cols' in kwargs else None\n", " train_df = pd.read_csv(os.path.join(path, train+'.csv'), header=header)\n", " valid_df = pd.read_csv(os.path.join(path, valid+'.csv'), header=header)\n", " test_df = None if test is None else pd.read_csv(os.path.join(path, test+'.csv'), header=header)\n", " return cls.from_df(path, train_df, valid_df, test_df, tokenizer, vocab, classes, **kwargs)\n", "\n", " @classmethod\n", " def from_folder(cls, path:PathOrStr, train:str='train', valid:str='valid', test:Optional[str]=None,\n", " tokenizer:Tokenizer=None, vocab:Vocab=None, **kwargs):\n", " \"Create a `TextDataBunch` from text files in folders.\"\n", " txt_kwargs, tok_kwargs, num_kwargs, kwargs = _parse_kwargs(kwargs)\n", " train_ds = (TextDataset.from_folder(train, classes, **txt_kwargs)\n", " .tokenize(tokenizer, **tok_kwargs)\n", " .numericalize(vocab, **num_kwargs))\n", " datasets = [train_ds, (TextDataset.from_folder(valid, train_ds.classes, **txt_kwargs)\n", " .tokenize(tokenizer, **tok_kwargs)\n", " .numericalize(train_ds.vocab, **num_kwargs))]\n", " if test:\n", " datasets.append((TextDataset.from_one_folder(valid, train_ds.classes, **txt_kwargs)\n", " .tokenize(tokenizer, **tok_kwargs)\n", " .numericalize(train_ds.vocab, **num_kwargs)))\n", " return cls.create(datasets, path, **kwargs)\n", "\n", " @classmethod\n", " def create(cls, datasets:Collection[TextDataset], path:PathOrStr, **kwargs) -> DataBunch:\n", " \"Call's `DataBunch.create` but changes the arguments so it'll work OK\"\n", " return DataBunch.create(*datasets, path=path, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TextLMDataBunch(TextDataBunch):\n", " \"Create a `TextDataBunch` suitable for training a language model.\"\n", " @classmethod\n", " def create(cls, datasets:Collection[TextDataset], path:PathOrStr, **kwargs) -> DataBunch:\n", " \"Create a `TextDataBunch` in `path` from the `datasets` for language modelling.\"\n", " dataloaders = [LanguageModelLoader(ds, **kwargs) for ds in datasets]\n", " return cls(*dataloaders, path=path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TextClasDataBunch(TextDataBunch):\n", " \"Create a `TextDataBunch` suitable for training an RNN classifier.\"\n", " @classmethod\n", " def create(cls, datasets:Collection[TextDataset], path:PathOrStr, bs=64, pad_idx=1, pad_first=True, **kwargs) -> DataBunch:\n", " \"Function that transform the `datasets` in a `DataBunch` for classification.\"\n", " collate_fn = partial(pad_collate, pad_idx=pad_idx, pad_first=pad_first)\n", " train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0].x[t]), bs=bs//2)\n", " train_dl = DataLoader(datasets[0], batch_size=bs//2, sampler=train_sampler, **kwargs)\n", " dataloaders = [train_dl]\n", " for ds in datasets[1:]:\n", " sampler = SortSampler(ds.x, key=lambda t: len(ds.x[t]))\n", " dataloaders.append(DataLoader(ds, batch_size=bs, sampler=sampler, **kwargs))\n", " return cls(*dataloaders, path=path, collate_fn=collate_fn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = TextLMDataBunch.from_csv(path, classes=['negative', 'positive'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.save()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = RNNLearner.language_model(data, pretrained_model=URLs.WT103)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example_text = \"I would like to know which word comes after this sentence\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LanguageLearner(RNNLearner):\n", " def predict(self, text:str, n_words:int=1, tokenizer:Tokenizer=None):\n", " \"Return the `n_words` that come after `text`.\"\n", " tokenizer = ifnone(tokenizer, Tokenizer())\n", " tokens = tokenizer.process_all([text])\n", " ds = self.data.valid_ds\n", " ids = ds.vocab.numericalize(tokens[0]) \n", " self.model.reset()\n", " for _ in progress_bar(range(n_words)):\n", " ds.set_item(ids)\n", " res = self.pred_batch()\n", " ids.append(res[-1].argmax())\n", " ds.clear_item()\n", " return self.data.train_ds.vocab.textify(ids)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = LanguageLearner.language_model(data, pretrained_model=URLs.WT103)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.predict(\"Jeremy Howard is\", 100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TextClassifierLearner(RNNLearner):\n", " def predict(self, text:str, tokenizer:Tokenizer=None):\n", " \"Return prect class, label and probabilities for `text`.\"\n", " tokenizer = ifnone(tokenizer, Tokenizer())\n", " tokens = tokenizer.process_all([text])\n", " ds = self.data.valid_ds\n", " ids = ds.vocab.numericalize(tokens[0]) \n", " self.model.reset()\n", " ds.set_item(ids)\n", " res = self.pred_batch()[0]\n", " ds.clear_item()\n", " pred_max = res.argmax()\n", " return self.data.train_ds.classes[pred_max],pred_max,res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.text.learner import get_rnn_classifier, rnn_classifier_split" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def text_classifier(data:DataBunch, bptt:int=70, max_len:int=70*20, emb_sz:int=400, nh:int=1150, nl:int=3,\n", " lin_ftrs:Collection[int]=None, ps:Collection[float]=None, pad_token:int=1,\n", " drop_mult:float=1., qrnn:bool=False, **kwargs) -> 'RNNLearner':\n", " \"Create a RNN classifier.\"\n", " dps = np.array([0.4,0.5,0.05,0.3,0.4]) * drop_mult\n", " if lin_ftrs is None: lin_ftrs = [50]\n", " if ps is None: ps = [0.1]\n", " ds = data.train_ds\n", " vocab_size, lbl = ds.vocab_size, ds.y[0]\n", " n_class = (len(ds.classes) if (not isinstance(lbl, Iterable) or (len(lbl) == 1)) else len(lbl))\n", " layers = [emb_sz*3] + lin_ftrs + [n_class]\n", " ps = [dps[4]] + ps\n", " model = get_rnn_classifier(bptt, max_len, n_class, vocab_size, emb_sz, nh, nl, pad_token,\n", " layers, ps, input_p=dps[0], weight_p=dps[1], embed_p=dps[2], hidden_p=dps[3], qrnn=qrnn)\n", " learn = TextClassifierLearner(data, model, bptt, split_func=rnn_classifier_split, **kwargs)\n", " return learn" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = TextClasDataBunch.load(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = text_classifier(data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "example_text = \"I really liked that movie, it was just the best I ever saw!\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.predict(example_text)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }