{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.torch_basics import *\n", "from fastai2.data.all import *\n", "from fastai2.text.core import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp text.data\n", "#default_cls_lvl 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Text data\n", "\n", "> Functions and transforms to help gather text data in a `Datasets`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Backwards\n", "\n", "A model trained on reversed text, ensembled with a forward model, can provide higher accuracy. All that is needed is a `type_tfm` that reverses the text as it is brought in:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def reverse_text(x): return x.flip(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "t = tensor([0,1,2])\n", "r = reverse_text(t)\n", "test_eq(r, tensor([2,1,0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Numericalizing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Numericalization is the step in which we convert tokens to integers. The first step is to build a mapping from tokens to indices, called a vocab."
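] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal plain-Python sketch of the idea (not fastai code; the names `_toy_vocab`, `_toy_o2i` and `_toy_ids` are made up for illustration), the vocab is a list of tokens and its inverse is a token-to-index dictionary. Unknown tokens fall back to index 0, which is the convention `Numericalize` follows below with a `defaultdict(int, ...)`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import defaultdict\n", "\n", "# Hypothetical toy vocab; index 0 plays the role of the unknown token\n", "_toy_vocab = ['xxunk', 'xxpad', 'the', 'movie', 'was', 'great']\n", "# Inverse mapping token -> index; tokens missing from the vocab default to 0\n", "_toy_o2i = defaultdict(int, {tok: i for i, tok in enumerate(_toy_vocab)})\n", "_toy_ids = [_toy_o2i[tok] for tok in 'the movie was awful'.split()]\n", "assert _toy_ids == [2, 3, 4, 0]\n", "# Decoding goes back through the vocab list\n", "assert [_toy_vocab[i] for i in _toy_ids] == ['the', 'movie', 'was', 'xxunk']"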
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def make_vocab(count, min_freq=3, max_vocab=60000, special_toks=None):\n", "    \"Create a vocab of `max_vocab` size from `Counter` `count` with items present at least `min_freq` times\"\n", "    vocab = [o for o,c in count.most_common(max_vocab) if c >= min_freq]\n", "    special_toks = ifnone(special_toks, defaults.text_spec_tok)\n", "    for o in reversed(special_toks): #Make sure all special tokens are in the vocab\n", "        if o in vocab: vocab.remove(o)\n", "        vocab.insert(0, o)\n", "    vocab = vocab[:max_vocab]\n", "    return vocab + ['xxfake' for i in range(0, 8-len(vocab)%8)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If there are more than `max_vocab` tokens, the most frequent ones are kept.\n", "\n", "> Note: For performance when using mixed precision, the vocabulary size is always made a multiple of 8 by appending `xxfake` tokens." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "count = Counter(['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c', 'd'])\n", "test_eq(set([x for x in make_vocab(count) if not x.startswith('xxfake')]), \n", " set(defaults.text_spec_tok + 'a'.split()))\n", "test_eq(len(make_vocab(count))%8, 0)\n", "test_eq(set([x for x in make_vocab(count, min_freq=1) if not x.startswith('xxfake')]), \n", " set(defaults.text_spec_tok + 'a b c d'.split()))\n", "test_eq(set([x for x in make_vocab(count,max_vocab=12, min_freq=1) if not x.startswith('xxfake')]), \n", " set(defaults.text_spec_tok + 'a b c'.split()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TensorText(TensorBase): pass\n", "class LMTensorText(TensorText): pass\n", "\n", "TensorText.__doc__ = \"Semantic type for a tensor representing text\"\n", "LMTensorText.__doc__ = \"Semantic type for a tensor representing text in language modeling\"" ] }, { 
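"cell_type": "markdown", "metadata": {}, "source": [ "This is a quick check of the arithmetic behind the note on multiples of 8. It is a sketch: `_n_fake` is a hypothetical helper that mirrors the `range(0, 8-len(vocab)%8)` expression in `make_vocab`. The number of `xxfake` tokens appended is `8 - n % 8`. That always gives a multiple of 8, but it adds 8 tokens rather than 0 when `n` is already a multiple of 8:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _n_fake(n):\n", "    # Number of 'xxfake' tokens make_vocab appends to a vocab of n real tokens\n", "    return 8 - n % 8\n", "\n", "for _n in [9, 10, 16, 60000]:\n", "    assert (_n + _n_fake(_n)) % 8 == 0\n", "assert _n_fake(10) == 6\n", "# Already a multiple of 8: still padded with 8 fake tokens\n", "assert _n_fake(16) == 8" ] }, {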
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Numericalize(Transform):\n", "    \"Reversible transform of tokenized texts to numericalized ids\"\n", "    def __init__(self, vocab=None, min_freq=3, max_vocab=60000, special_toks=None, pad_tok=None):\n", "        store_attr(self, 'vocab,min_freq,max_vocab,special_toks,pad_tok')\n", "        self.o2i = None if vocab is None else defaultdict(int, {v:k for k,v in enumerate(vocab)})\n", "\n", "    def setups(self, dsets):\n", "        if dsets is None: return\n", "        if self.vocab is None:\n", "            count = dsets.counter if getattr(dsets, 'counter', None) is not None else Counter(p for o in dsets for p in o)\n", "            if self.special_toks is None and hasattr(dsets, 'special_toks'):\n", "                self.special_toks = dsets.special_toks\n", "            self.vocab = make_vocab(count, min_freq=self.min_freq, max_vocab=self.max_vocab, special_toks=self.special_toks)\n", "            self.o2i = defaultdict(int, {v:k for k,v in enumerate(self.vocab) if v != 'xxfake'})\n", "\n", "    def encodes(self, o): return TensorText(tensor([self.o2i[o_] for o_ in o]))\n", "    def decodes(self, o): return L(self.vocab[o_] for o_ in o if self.vocab[o_] != self.pad_tok)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If no `vocab` is passed, one is created at setup from the data, using `make_vocab` with `min_freq` and `max_vocab`."
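] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The vocab construction done in `setups` can be sketched in plain Python. This is a simplified illustration, not the library code: it uses two made-up special tokens, ignores `max_vocab` and the padding to a multiple of 8, and the `_setup_*` names are hypothetical. Tokens are counted, those seen fewer than `min_freq` times are dropped, and anything outside the vocab is encoded as 0:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter, defaultdict\n", "\n", "# Hypothetical tokenized corpus and special tokens, for illustration only\n", "_setup_texts = [['this', 'is', 'a', 'test'], ['this', 'is', 'another', 'test']]\n", "_setup_specials = ['xxunk', 'xxpad']\n", "_setup_counts = Counter(tok for txt in _setup_texts for tok in txt)\n", "# Special tokens first, then tokens seen at least min_freq=2 times, most frequent first\n", "_setup_vocab = _setup_specials + [t for t, c in _setup_counts.most_common() if c >= 2 and t not in _setup_specials]\n", "_setup_o2i = defaultdict(int, {t: i for i, t in enumerate(_setup_vocab)})\n", "assert _setup_vocab == ['xxunk', 'xxpad', 'this', 'is', 'test']\n", "# 'another' was seen only once, so it is encoded as 0 (the unknown token)\n", "assert [_setup_o2i[t] for t in _setup_texts[1]] == [2, 3, 0, 4]"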
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "start = 'This is an example of text'\n", "num = Numericalize(min_freq=1)\n", "num.setup(L(start.split(), 'this is another text'.split()))\n", "test_eq(set([x for x in num.vocab if not x.startswith('xxfake')]), \n", " set(defaults.text_spec_tok + 'This is an example of text this another'.split()))\n", "test_eq(len(num.vocab)%8, 0)\n", "t = num(start.split())\n", "\n", "test_eq(t, tensor([11, 9, 12, 13, 14, 10]))\n", "test_eq(num.decode(t), start.split())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "num = Numericalize(min_freq=2)\n", "num.setup(L('This is an example of text'.split(), 'this is another text'.split()))\n", "test_eq(set([x for x in num.vocab if not x.startswith('xxfake')]), \n", " set(defaults.text_spec_tok + 'is text'.split()))\n", "test_eq(len(num.vocab)%8, 0)\n", "t = num(start.split())\n", "test_eq(t, tensor([0, 9, 0, 0, 0, 10]))\n", "test_eq(num.decode(t), f'{UNK} is {UNK} {UNK} {UNK} text'.split())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "df = pd.DataFrame({'texts': ['This is an example of text', 'this is another text']})\n", "tl = TfmdLists(df, [attrgetter('text'), Tokenizer.from_df('texts'), Numericalize(min_freq=2)])\n", "test_eq(tl, [tensor([2, 8, 9, 10, 0, 0, 0, 11]), tensor([2, 9, 10, 0, 11])])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LMDataLoader -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _maybe_first(o): return o[0] if isinstance(o, tuple) else o" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_tokenizer(ds):\n", " tok = getattr(ds, 'tokenizer', None)\n", " if
isinstance(tok, Tokenizer): return tok\n", " if isinstance(tok, (list,L)):\n", " for t in tok:\n", " if isinstance(t, Tokenizer): return t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_lengths(ds):\n", " tok = _get_tokenizer(ds)\n", " if tok is None: return\n", " return tok.get_lengths(ds.items)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "#TODO: add backward\n", "@log_args(but_as=TfmdDL.__init__)\n", "@delegates()\n", "class LMDataLoader(TfmdDL):\n", " \"A `DataLoader` suitable for language modeling\"\n", " def __init__(self, dataset, lens=None, cache=2, bs=64, seq_len=72, num_workers=0, **kwargs):\n", " self.items = ReindexCollection(dataset, cache=cache, tfm=_maybe_first)\n", " self.seq_len = seq_len\n", " if lens is None: lens = _get_lengths(dataset)\n", " if lens is None: lens = [len(o) for o in self.items]\n", " self.lens = ReindexCollection(lens, idxs=self.items.idxs)\n", " # The \"-1\" is to allow for final label, we throw away the end that's less than bs\n", " corpus = round_multiple(sum(lens)-1, bs, round_down=True)\n", " self.bl = corpus//bs #bl stands for batch length\n", " self.n_batches = self.bl//(seq_len) + int(self.bl%seq_len!=0)\n", " self.last_len = self.bl - (self.n_batches-1)*seq_len\n", " self.make_chunks()\n", " super().__init__(dataset=dataset, bs=bs, num_workers=num_workers, **kwargs)\n", " self.n = self.n_batches*bs\n", "\n", " def make_chunks(self): self.chunks = Chunks(self.items, self.lens)\n", " def shuffle_fn(self,idxs):\n", " self.items.shuffle()\n", " self.make_chunks()\n", " return idxs\n", "\n", " def create_item(self, seq):\n", " if seq>=self.n: raise IndexError\n", " sl = self.last_len if seq//self.bs==self.n_batches-1 else self.seq_len\n", " st = (seq%self.bs)*self.bl + (seq//self.bs)*self.seq_len\n", " txt = self.chunks[st : st+sl+1]\n", " return LMTensorText(txt[:-1]),txt[1:]\n", "\n", " 
@delegates(TfmdDL.new)\n", " def new(self, dataset=None, seq_len=None, **kwargs):\n", " lens = self.lens.coll if dataset is None else None\n", " seq_len = self.seq_len if seq_len is None else seq_len\n", " return super().new(dataset=dataset, lens=lens, seq_len=seq_len, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class LMDataLoader

\n", "\n", "> LMDataLoader(**`dataset`**, **`lens`**=*`None`*, **`cache`**=*`2`*, **`bs`**=*`64`*, **`seq_len`**=*`72`*, **`num_workers`**=*`0`*, **`shuffle`**=*`False`*, **`verbose`**=*`False`*, **`do_setup`**=*`True`*, **`pin_memory`**=*`False`*, **`timeout`**=*`0`*, **`batch_size`**=*`None`*, **`drop_last`**=*`False`*, **`indexed`**=*`None`*, **`n`**=*`None`*, **`device`**=*`None`*, **`wif`**=*`None`*, **`before_iter`**=*`None`*, **`after_item`**=*`None`*, **`before_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_iter`**=*`None`*, **`create_batches`**=*`None`*, **`create_item`**=*`None`*, **`create_batch`**=*`None`*, **`retain`**=*`None`*, **`get_idxs`**=*`None`*, **`sample`**=*`None`*, **`shuffle_fn`**=*`None`*, **`do_batch`**=*`None`*) :: [`TfmdDL`](/data.core#TfmdDL)\n", "\n", "A [`DataLoader`](/data.load#DataLoader) suitable for language modeling" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LMDataLoader, title_level=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`dataset` should be a collection of numericalized texts for this to work. `lens` (the length of each text) can be passed to speed up creation. Otherwise, the `LMDataLoader` gets them from the tokenizer if it can, or does a full pass over the `dataset` to compute them. `cache` is used to avoid reloading items unnecessarily.\n", "\n", "The `LMDataLoader` concatenates all texts (shuffled first if `shuffle=True`) into one big stream and splits it into `bs` contiguous streams. It then goes through those streams `seq_len` tokens at a time."
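] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal plain-Python sketch of that batching scheme follows. It is an illustration under simplifying assumptions: no shuffling, caching or precomputed `lens`, and `_lm_batches_sketch` is a hypothetical name. It mirrors the index arithmetic of `LMDataLoader.create_item` on toy data like the tests below. With 25 tokens, `bs=4` and `seq_len=3`, each of the 4 rows is 6 tokens long and is read in 2 batches, with targets shifted by one token:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _lm_batches_sketch(texts, bs, seq_len):\n", "    # Concatenate all token lists into one stream\n", "    stream = [tok for txt in texts for tok in txt]\n", "    # Keep one extra token for the last target, then drop what does not fill bs rows\n", "    corpus = (len(stream) - 1) // bs * bs\n", "    row_len = corpus // bs\n", "    batches = []\n", "    for start in range(0, row_len, seq_len):\n", "        # The last batch may be shorter than seq_len\n", "        cur = min(seq_len, row_len - start)\n", "        xs = [stream[r*row_len + start : r*row_len + start + cur] for r in range(bs)]\n", "        ys = [stream[r*row_len + start + 1 : r*row_len + start + cur + 1] for r in range(bs)]\n", "        batches.append((xs, ys))\n", "    return batches\n", "\n", "_toy_stream = [list(range(0, 5)), list(range(5, 11)), list(range(11, 19)), [19, 20], [21, 22, 23], [24]]\n", "_lm_b = _lm_batches_sketch(_toy_stream, bs=4, seq_len=3)\n", "assert len(_lm_b) == 2\n", "assert _lm_b[0][0] == [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]\n", "assert _lm_b[1][1] == [[4, 5, 6], [10, 11, 12], [16, 17, 18], [22, 23, 24]]"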
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "bs,sl = 4,3\n", "ints = L([0,1,2,3,4],[5,6,7,8,9,10],[11,12,13,14,15,16,17,18],[19,20],[21,22]).map(tensor)\n", "dl = LMDataLoader(ints, bs=bs, seq_len=sl)\n", "list(dl)\n", "test_eq(list(dl),\n", " [[tensor([[0, 1, 2], [5, 6, 7], [10, 11, 12], [15, 16, 17]]),\n", " tensor([[1, 2, 3], [6, 7, 8], [11, 12, 13], [16, 17, 18]])],\n", " [tensor([[3, 4], [8, 9], [13, 14], [18, 19]]),\n", " tensor([[4, 5], [9, 10], [14, 15], [19, 20]])]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs,sl = 4,3\n", "ints = L([0,1,2,3,4],[5,6,7,8,9,10],[11,12,13,14,15,16,17,18],[19,20],[21,22,23],[24]).map(tensor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = LMDataLoader(ints, bs=bs, seq_len=sl)\n", "test_eq(list(dl),\n", " [[tensor([[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]),\n", " tensor([[1, 2, 3], [7, 8, 9], [13, 14, 15], [19, 20, 21]])],\n", " [tensor([[3, 4, 5], [ 9, 10, 11], [15, 16, 17], [21, 22, 23]]),\n", " tensor([[4, 5, 6], [10, 11, 12], [16, 17, 18], [22, 23, 24]])]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Check lens work\n", "dl = LMDataLoader(ints, lens=ints.map(len), bs=bs, seq_len=sl)\n", "test_eq(list(dl),\n", " [[tensor([[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]),\n", " tensor([[1, 2, 3], [7, 8, 9], [13, 14, 15], [19, 20, 21]])],\n", " [tensor([[3, 4, 5], [ 9, 10, 11], [15, 16, 17], [21, 22, 23]]),\n", " tensor([[4, 5, 6], [10, 11, 12], [16, 17, 18], [22, 23, 24]])]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dl = LMDataLoader(ints, bs=bs, seq_len=sl, shuffle=True)\n", "for x,y in dl: test_eq(x[:,1:], y[:,:-1])\n", "((x0,y0), (x1,y1)) = tuple(dl)\n", "#Second batch begins where first batch ended\n", 
"test_eq(y0[:,-1], x1[:,0])\n", "test_eq(type(x0), LMTensorText)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test new works\n", "dl = LMDataLoader(ints, bs=bs, seq_len=sl, shuffle=True)\n", "dl1 = dl.new()\n", "test_eq(dl1.seq_len, sl)\n", "dl2 = dl.new(seq_len=2)\n", "test_eq(dl2.seq_len, 2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Showing -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_batch(x: TensorText, y, samples, ctxs=None, max_n=10, trunc_at=150, **kwargs):\n", "    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))\n", "    if trunc_at is not None: samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)\n", "    ctxs = show_batch[object](x, y, samples, max_n=max_n, ctxs=ctxs, **kwargs)\n", "    display_df(pd.DataFrame(ctxs))\n", "    return ctxs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@typedispatch\n", "def show_batch(x: LMTensorText, y, samples, ctxs=None, max_n=10, trunc_at=150, **kwargs):\n", "    samples = L((s[0].truncate(trunc_at), s[1].truncate(trunc_at)) for s in samples)\n", "    return show_batch[TensorText](x, None, samples, ctxs=ctxs, max_n=max_n, trunc_at=None, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For classification, we deal with the fact that texts don't all have the same length by using padding."
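] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In its simplest form, padding a batch means extending every sequence to the length of the longest one with a padding index. Here is a plain-Python sketch; the name `_pad_batch_sketch` is hypothetical, and `pad_input` below does the same on tensors, per field, with extra options:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _pad_batch_sketch(seqs, pad_idx=1, pad_first=False):\n", "    # Pad every sequence to the length of the longest one\n", "    max_len = max(len(s) for s in seqs)\n", "    out = []\n", "    for s in seqs:\n", "        pad = [pad_idx] * (max_len - len(s))\n", "        out.append(pad + s if pad_first else s + pad)\n", "    return out\n", "\n", "assert _pad_batch_sketch([[1, 2, 3], [4, 5], [6]], pad_idx=0) == [[1, 2, 3], [4, 5, 0], [6, 0, 0]]\n", "assert _pad_batch_sketch([[1, 2, 3], [4, 5], [6]], pad_idx=0, pad_first=True) == [[1, 2, 3], [0, 4, 5], [0, 0, 6]]"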
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def pad_input(samples, pad_idx=1, pad_fields=0, pad_first=False, backwards=False):\n", "    \"Collect `samples` and add padding\"\n", "    pad_fields = L(pad_fields)\n", "    max_len_l = pad_fields.map(lambda f: max([len(s[f]) for s in samples]))\n", "    if backwards: pad_first = not pad_first\n", "    def _f(field_idx, x):\n", "        if field_idx not in pad_fields: return x\n", "        idx = pad_fields.items.index(field_idx) #TODO: remove items if L.index is fixed\n", "        pad = x.new_zeros(max_len_l[idx]-x.shape[0])+pad_idx\n", "        x1 = torch.cat([pad, x] if pad_first else [x, pad])\n", "        if backwards: x1 = x1.flip(0)\n", "        return retain_type(x1, x)\n", "    return [tuple(map(lambda idxx: _f(*idxx), enumerate(s))) for s in samples]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`pad_idx` is used for the padding, and the padding is applied to the `pad_fields` of the samples. The padding is applied at the beginning if `pad_first` is `True`. If `backwards` is `True`, the tensors are flipped, with the padding kept on the side requested by `pad_first`."
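] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The interplay between `pad_first` and `backwards` is easy to miss. `pad_input` toggles `pad_first` before padding and then flips the result, so the padding ends up on the side requested by `pad_first`. Here is a plain-list sketch of that logic for a single sequence; `_pad_one_sketch` is a hypothetical name:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _pad_one_sketch(s, max_len, pad_idx=1, pad_first=False, backwards=False):\n", "    # Mirrors pad_input's logic for a single sequence, on plain lists\n", "    if backwards: pad_first = not pad_first\n", "    pad = [pad_idx] * (max_len - len(s))\n", "    out = pad + s if pad_first else s + pad\n", "    # Flipping after padding moves the padding back to the requested side\n", "    return out[::-1] if backwards else out\n", "\n", "assert _pad_one_sketch([4, 5], 3, pad_idx=0) == [4, 5, 0]\n", "assert _pad_one_sketch([4, 5], 3, pad_idx=0, backwards=True) == [5, 4, 0]\n", "assert _pad_one_sketch([4, 5], 3, pad_idx=0, pad_first=True, backwards=True) == [0, 5, 4]"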
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0), \n", " [(tensor([1,2,3]),1), (tensor([4,5,0]),2), (tensor([6,0,0]), 3)])\n", "test_eq(pad_input([(tensor([1,2,3]), (tensor([6]))), (tensor([4,5]), tensor([4,5])), (tensor([6]), (tensor([1,2,3])))], pad_idx=0, pad_fields=1), \n", " [(tensor([1,2,3]),(tensor([6,0,0]))), (tensor([4,5]),tensor([4,5,0])), ((tensor([6]),tensor([1, 2, 3])))])\n", "test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, pad_first=True), \n", " [(tensor([1,2,3]),1), (tensor([0,4,5]),2), (tensor([0,0,6]), 3)])\n", "test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, backwards=True), \n", " [(tensor([3,2,1]),1), (tensor([5,4,0]),2), (tensor([6,0,0]), 3)])\n", "x = test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, backwards=True), \n", " [(tensor([3,2,1]),1), (tensor([5,4,0]),2), (tensor([6,0,0]), 3)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Check retain type\n", "x = [(TensorText([1,2,3]),1), (TensorText([4,5]), 2), (TensorText([6]), 3)]\n", "y = pad_input(x, pad_idx=0)\n", "for s in y: test_eq(type(s[0]), TensorText)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def pad_input_chunk(samples, pad_idx=1, pad_first=True, seq_len=72):\n", " \"Pad `samples` by adding padding by chunks of size `seq_len`\"\n", " max_len = max([len(s[0]) for s in samples])\n", " def _f(x):\n", " l = max_len - x.shape[0]\n", " pad_chunk = x.new_zeros((l//seq_len) * seq_len) + pad_idx\n", " pad_res = x.new_zeros(l % seq_len) + pad_idx\n", " x1 = torch.cat([pad_chunk, x, pad_res]) if pad_first else torch.cat([x, pad_res, pad_chunk])\n", " return retain_type(x1, x)\n", " return [(_f(s[0]), 
*s[1:]) for s in samples]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The difference with the base `pad_input` is that the padding is split in two. The largest multiple of `seq_len` that fits is applied at the beginning (if `pad_first=True`) or at the end (if `pad_first=False`). The remainder, shorter than `seq_len`, is always placed right after the text. This is designed to work with `SentenceEncoder` in recurrent models." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),1), (tensor([1,2,3]), 2), (tensor([1,2]), 3)], pad_idx=0, seq_len=2), \n", " [(tensor([1,2,3,4,5,6]),1), (tensor([0,0,1,2,3,0]),2), (tensor([0,0,0,0,1,2]), 3)])\n", "test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),), (tensor([1,2,3]),), (tensor([1,2]),)], pad_idx=0, seq_len=2), \n", " [(tensor([1,2,3,4,5,6]),), (tensor([0,0,1,2,3,0]),), (tensor([0,0,0,0,1,2]),)])\n", "test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),), (tensor([1,2,3]),), (tensor([1,2]),)], pad_idx=0, seq_len=2, pad_first=False), \n", " [(tensor([1,2,3,4,5,6]),), (tensor([1,2,3,0,0,0]),), (tensor([1,2,0,0,0,0]),)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _default_sort(x): return len(x[0])\n", "\n", "@delegates(TfmdDL)\n", "class SortedDL(TfmdDL):\n", "    \"A `DataLoader` that goes through the items in the order given by `sort_func`\"\n", "    def __init__(self, dataset, sort_func=None, res=None, **kwargs):\n", "        super().__init__(dataset, **kwargs)\n", "        self.sort_func = _default_sort if sort_func is None else sort_func\n", "        if res is None and self.sort_func == _default_sort: res = _get_lengths(dataset)\n", "        self.res = [self.sort_func(self.do_item(i)) for i in range_of(self.dataset)] if res is None else res\n", "        if len(self.res) > 0: self.idx_max = np.argmax(self.res)\n", "\n", "    def get_idxs(self):\n", "        idxs = super().get_idxs()\n", "        if self.shuffle: return idxs\n",
"        return sorted(idxs, key=lambda i: self.res[i], reverse=True)\n", "\n", "    def shuffle_fn(self,idxs):\n", "        idxs = np.random.permutation(len(self.dataset))\n", "        idx_max = np.where(idxs==self.idx_max)[0][0]\n", "        idxs[0],idxs[idx_max] = idxs[idx_max],idxs[0]\n", "        sz = self.bs*50\n", "        chunks = [idxs[i:i+sz] for i in range(0, len(idxs), sz)]\n", "        chunks = [sorted(s, key=lambda i: self.res[i], reverse=True) for s in chunks]\n", "        sort_idx = np.concatenate(chunks)\n", "\n", "        sz = self.bs\n", "        batches = [sort_idx[i:i+sz] for i in range(0, len(sort_idx), sz)]\n", "        sort_idx = np.concatenate(np.random.permutation(batches[1:-1])) if len(batches) > 2 else np.array([],dtype=int)\n", "        sort_idx = np.concatenate((batches[0], sort_idx) if len(batches)==1 else (batches[0], sort_idx, batches[-1]))\n", "        return iter(sort_idx)\n", "\n", "    @delegates(TfmdDL.new)\n", "    def new(self, dataset=None, **kwargs):\n", "        if 'val_res' in kwargs and kwargs['val_res'] is not None: res = kwargs['val_res']\n", "        else: res = self.res if dataset is None else None\n", "        return super().new(dataset=dataset, res=res, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`res` is the result of `sort_func` applied to all elements of the `dataset`. If it is available, you can pass it to make the init much faster by avoiding an initial pass over the whole dataset. For example, if sorting by text length (as the default `sort_func`, `_default_sort`, does), pass a list with the length of each element in `dataset` as `res` to take advantage of this speed-up.\n", "\n", "To get the same init speed-up for the validation set, pass `val_res` (a list of text lengths for your validation set) as a keyword argument. `SortedDL.new` uses it when creating the validation `DataLoader`.
Below is an example that reduces the init time by passing a list of text lengths for both the training set and the validation set:\n", "\n", "```\n", "# Pass the training dataset text lengths to SortedDL\n", "srtd_dl = partial(SortedDL, res=train_text_lens)\n", "\n", "# Pass the validation dataset text lengths\n", "dl_kwargs = [{},{'val_res': val_text_lens}]\n", "\n", "# init our Datasets\n", "dsets = Datasets(...)\n", "\n", "# init our DataLoaders\n", "dls = dsets.dataloaders(..., dl_type=srtd_dl, dl_kwargs=dl_kwargs)\n", "```\n", "\n", "If `shuffle` is `True`, the sorted order is partially shuffled. Batches still contain items of roughly the same length, but they are no longer in exact sorted order." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds = [(tensor([1,2]),1), (tensor([3,4,5,6]),2), (tensor([7]),3), (tensor([8,9,10]),4)]\n", "dl = SortedDL(ds, bs=2, before_batch=partial(pad_input, pad_idx=0))\n", "test_eq(list(dl), [(tensor([[ 3, 4, 5, 6], [ 8, 9, 10, 0]]), tensor([2, 4])), \n", " (tensor([[1, 2], [7, 0]]), tensor([1, 3]))])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ds = [(tensor(range(random.randint(1,10))),i) for i in range(101)]\n", "dl = SortedDL(ds, bs=2, create_batch=partial(pad_input, pad_idx=-1), shuffle=True, num_workers=0)\n", "batches = list(dl)\n", "max_len = len(batches[0][0])\n", "for b in batches:\n", "    assert len(b[0]) <= max_len\n", "    test_ne(b[0][-1], -1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TransformBlock for text" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To use the data block API, you will need this building block for texts."
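] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conceptually, `TextBlock` chains a tokenizer, `Numericalize` and, when `backwards=True`, `reverse_text`. It then batches the result with `LMDataLoader` if `is_lm=True`, or with `SortedDL` otherwise. Here is a plain-Python sketch of the transform chain. A naive lowercase whitespace split stands in for the real tokenizer, and `_text_pipeline_sketch` and `_block_o2i` are hypothetical names:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _text_pipeline_sketch(text, o2i, backwards=False):\n", "    # 1. Tokenize (naive stand-in for tok_tfm)\n", "    toks = text.lower().split()\n", "    # 2. Numericalize; unknown tokens map to 0\n", "    ids = [o2i.get(t, 0) for t in toks]\n", "    # 3. Optionally reverse, as reverse_text does when backwards=True\n", "    return ids[::-1] if backwards else ids\n", "\n", "_block_o2i = {'xxunk': 0, 'xxpad': 1, 'a': 2, 'great': 3, 'movie': 4}\n", "assert _text_pipeline_sketch('A great movie', _block_o2i) == [2, 3, 4]\n", "assert _text_pipeline_sketch('A great film', _block_o2i, backwards=True) == [0, 3, 2]"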
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TextBlock(TransformBlock):\n", "    \"A `TransformBlock` for texts\"\n", "    @delegates(Numericalize.__init__)\n", "    def __init__(self, tok_tfm, vocab=None, is_lm=False, seq_len=72, backwards=False, **kwargs):\n", "        type_tfms = [tok_tfm, Numericalize(vocab, **kwargs)]\n", "        if backwards: type_tfms += [reverse_text]\n", "        super().__init__(type_tfms=type_tfms,\n", "                         dl_type=LMDataLoader if is_lm else SortedDL,\n", "                         dls_kwargs={'seq_len': seq_len} if is_lm else {'before_batch': partial(pad_input_chunk, seq_len=seq_len)})\n", "\n", "    @classmethod\n", "    @delegates(Tokenizer.from_df, keep=True)\n", "    def from_df(cls, text_cols, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, **kwargs):\n", "        \"Build a `TextBlock` from a dataframe using `text_cols`\"\n", "        return cls(Tokenizer.from_df(text_cols, **kwargs), vocab=vocab, is_lm=is_lm, seq_len=seq_len,\n", "                   backwards=backwards, min_freq=min_freq, max_vocab=max_vocab)\n", "\n", "    @classmethod\n", "    @delegates(Tokenizer.from_folder, keep=True)\n", "    def from_folder(cls, path, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, **kwargs):\n", "        \"Build a `TextBlock` from a `path`\"\n", "        return cls(Tokenizer.from_folder(path, **kwargs), vocab=vocab, is_lm=is_lm, seq_len=seq_len,\n", "                   backwards=backwards, min_freq=min_freq, max_vocab=max_vocab)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For efficient tokenization, you probably want to use one of the factory methods.
Otherwise, you can pass your own `tok_tfm` to handle tokenization (if your texts are already tokenized, you can pass `noop`). You can also pass a `vocab`, or let one be inferred from the texts using `min_freq` and `max_vocab`.\n", "\n", "`is_lm` indicates whether the texts are used for language modeling (batched with `LMDataLoader`) or for another task (batched with `SortedDL`). When `is_lm=True`, `seq_len` is the sequence length of each batch; otherwise it is passed along to `pad_input_chunk`. If `backwards=True`, `reverse_text` is added to the type transforms." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

TextBlock.from_df

\n", "\n", "> TextBlock.from_df(**`text_cols`**, **`vocab`**=*`None`*, **`is_lm`**=*`False`*, **`seq_len`**=*`72`*, **`backwards`**=*`False`*, **`min_freq`**=*`3`*, **`max_vocab`**=*`60000`*, **`tok`**=*`None`*, **`rules`**=*`None`*, **`sep`**=*`' '`*, **`n_workers`**=*`64`*, **`mark_fields`**=*`None`*, **`res_col_name`**=*`'text'`*, **\\*\\*`kwargs`**)\n", "\n", "Build a [`TextBlock`](/text.data#TextBlock) from a dataframe using `text_cols`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TextBlock.from_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an example using a sample of IMDB stored as a CSV file:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " label \\\n", "0 negative \n", "1 positive \n", "\n", " text \\\n", "0 Un-bleeping-believable! Meg Ryan doesn't even look her usual pert lovable self in this, which normally makes me forgive her shallow ticky acting schtick. Hard to believe she was the producer on this dog. Plus Kevin Kline: what kind of suicide trip has his career been on? Whoosh... Banzai!!! Finally this was directed by the guy who did Big Chill? Must be a replay of Jonestown - hollywood style. Wooofff! \n", "1 This is a extremely well-made film. The acting, script and camera-work are all first-rate. The music is good, too, though it is mostly early in the film, when things are still relatively cheery. There are no really superstars in the cast, though several faces will be familiar. The entire cast does an excellent job with the script.<br /><br />But it is hard to watch, because there is no good end to a situation like the one presented. It is now fashionable to blame the British for setting Hindus and Muslims against each other, and then cruelly separating them into two countries. There is som... \n", "\n", " is_valid \n", "0 False \n", "1 False " ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.IMDB_SAMPLE)\n", "df = pd.read_csv(path/'texts.csv')\n", "\n", "imdb_clas = DataBlock(\n", "    blocks=(TextBlock.from_df('text', seq_len=72), CategoryBlock),\n", "    get_x=ColReader('text'), get_y=ColReader('label'), splitter=ColSplitter())\n", "\n", "dls = imdb_clas.dataloaders(df, bs=64)\n", "dls.show_batch(max_n=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`vocab`, `is_lm`, `seq_len`, `backwards`, `min_freq` and `max_vocab` are passed to the main init; all other arguments are passed to `Tokenizer.from_df`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

TextBlock.from_folder

\n", "\n", "> TextBlock.from_folder(**`path`**, **`vocab`**=*`None`*, **`is_lm`**=*`False`*, **`seq_len`**=*`72`*, **`backwards`**=*`False`*, **`min_freq`**=*`3`*, **`max_vocab`**=*`60000`*, **`tok`**=*`None`*, **`rules`**=*`None`*, **`extensions`**=*`None`*, **`folders`**=*`None`*, **`output_dir`**=*`None`*, **`skip_if_exists`**=*`True`*, **`output_names`**=*`None`*, **`n_workers`**=*`64`*, **`encoding`**=*`'utf8'`*, **\\*\\*`kwargs`**)\n", "\n", "Build a [`TextBlock`](/text.data#TextBlock) from a `path`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TextBlock.from_folder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`vocab`, `is_lm`, `seq_len`, `min_freq` and `max_vocab` are passed to the main init, the other argument to `Tokenizer.from_folder`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TextDataLoaders -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TextDataLoaders(DataLoaders):\n", " \"Basic wrapper around several `DataLoader`s with factory methods for NLP problems\"\n", " @classmethod\n", " @delegates(DataLoaders.from_dblock)\n", " def from_folder(cls, path, train='train', valid='valid', valid_pct=None, seed=None, vocab=None, text_vocab=None, is_lm=False,\n", " tok_tfm=None, seq_len=72, backwards=False, **kwargs):\n", " \"Create from imagenet style dataset in `path` with `train` and `valid` subfolders (or provide `valid_pct`)\"\n", " splitter = GrandparentSplitter(train_name=train, valid_name=valid) if valid_pct is None else RandomSplitter(valid_pct, seed=seed)\n", " blocks = [TextBlock.from_folder(path, text_vocab, is_lm, seq_len, backwards) if tok_tfm is None else TextBlock(tok_tfm, text_vocab, is_lm, seq_len, backwards)]\n", " if not is_lm: blocks.append(CategoryBlock(vocab=vocab))\n", " get_items = partial(get_text_files, folders=[train,valid]) if valid_pct is None else get_text_files\n", " 
dblock = DataBlock(blocks=blocks,\n", " get_items=get_items,\n", " splitter=splitter,\n", " get_y=None if is_lm else parent_label)\n", " return cls.from_dblock(dblock, path, path=path, seq_len=seq_len, **kwargs)\n", "\n", " @classmethod\n", " @delegates(DataLoaders.from_dblock)\n", " def from_df(cls, df, path='.', valid_pct=0.2, seed=None, text_col=0, label_col=1, label_delim=None, y_block=None,\n", " text_vocab=None, is_lm=False, valid_col=None, tok_tfm=None, seq_len=72, backwards=False, **kwargs):\n", " \"Create from `df` in `path` with `valid_pct`\"\n", " blocks = [TextBlock.from_df(text_col, text_vocab, is_lm, seq_len, backwards) if tok_tfm is None else TextBlock(tok_tfm, text_vocab, is_lm, seq_len, backwards)]\n", " if y_block is None and not is_lm:\n", " blocks.append(MultiCategoryBlock if is_listy(label_col) and len(label_col) > 1 else CategoryBlock)\n", " if y_block is not None and not is_lm: blocks += (y_block if is_listy(y_block) else [y_block])\n", " splitter = RandomSplitter(valid_pct, seed=seed) if valid_col is None else ColSplitter(valid_col)\n", " dblock = DataBlock(blocks=blocks,\n", " get_x=ColReader(\"text\"),\n", " get_y=None if is_lm else ColReader(label_col, label_delim=label_delim),\n", " splitter=splitter)\n", " return cls.from_dblock(dblock, df, path=path, seq_len=seq_len, **kwargs)\n", "\n", " @classmethod\n", " def from_csv(cls, path, csv_fname='labels.csv', header='infer', delimiter=None, **kwargs):\n", " \"Create from `csv` file in `path/csv_fname`\"\n", " df = pd.read_csv(Path(path)/csv_fname, header=header, delimiter=delimiter)\n", " return cls.from_df(df, path=path, **kwargs)\n", "\n", "TextDataLoaders.from_csv = delegates(to=TextDataLoaders.from_df)(TextDataLoaders.from_csv)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class TextDataLoaders[source]

\n", "\n", "> TextDataLoaders(**\\*`loaders`**, **`path`**=*`'.'`*, **`device`**=*`None`*) :: [`DataLoaders`](/data.core#DataLoaders)\n", "\n", "Basic wrapper around several [`DataLoader`](/data.load#DataLoader)s with factory methods for NLP problems" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TextDataLoaders, title_level=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You should not use the init directly; use one of the following factory methods instead. All of them accept the following arguments:\n", "\n", "- `text_vocab`: the vocabulary used for numericalizing texts (if not passed, it's inferred from the data)\n", "- `is_lm`: whether the data is used for language modeling (if `True`, no labels are read)\n", "- `tok_tfm`: if passed, uses this `tok_tfm` instead of the default\n", "- `seq_len`: the sequence length used for each batch\n", "- `backwards`: if `True`, the texts are reversed (for instance to train a backward model to ensemble with a forward one)\n", "- `bs`: the batch size\n", "- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)\n", "- `shuffle_train`: whether to shuffle the training `DataLoader`\n", "- `device`: the PyTorch device to use (defaults to `default_device()`)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

TextDataLoaders.from_folder[source]

\n", "\n", "> TextDataLoaders.from_folder(**`path`**, **`train`**=*`'train'`*, **`valid`**=*`'valid'`*, **`valid_pct`**=*`None`*, **`seed`**=*`None`*, **`vocab`**=*`None`*, **`text_vocab`**=*`None`*, **`is_lm`**=*`False`*, **`tok_tfm`**=*`None`*, **`seq_len`**=*`72`*, **`backwards`**=*`False`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n", "\n", "Create from imagenet style dataset in `path` with `train` and `valid` subfolders (or provide `valid_pct`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TextDataLoaders.from_folder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `valid_pct` is provided, a random split is performed (with an optional `seed`) by setting aside that percentage of the data for the validation set (instead of splitting on the `train` and `valid` grandparent folders). If a `vocab` is passed, only the folders with names in `vocab` are kept.\n", "\n", "Here is an example on the IMDB movie review dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
textcategory
0▁xxbos ▁xxmaj ▁match ▁1: ▁xxmaj ▁tag ▁xxmaj ▁team ▁xxmaj ▁table ▁xxmaj ▁match ▁xxmaj ▁ bub ba ▁xxmaj ▁ray ▁and ▁xxmaj ▁spike ▁xxmaj ▁dudley ▁vs ▁xxmaj ▁eddie ▁xxmaj ▁guerrero ▁and ▁xxmaj ▁chris ▁xxmaj ▁benoit ▁xxmaj ▁ bub ba ▁xxmaj ▁ray ▁and ▁xxmaj ▁spike ▁xxmaj ▁dudley ▁started ▁things ▁off ▁with ▁a ▁xxmaj ▁tag ▁xxmaj ▁team ▁xxmaj ▁table ▁xxmaj ▁match ▁against ▁xxmaj ▁eddie ▁xxmaj ▁guerrero ▁and ▁xxmaj ▁chris ▁xxmaj ▁benoit . ▁xxmaj ▁according ▁to ▁the ▁rules ▁of ▁the ▁match , ▁both ▁opponents ▁have ▁to ▁go ▁through ▁tables ▁in ▁order ▁to ▁get ▁the ▁win . ▁xxmaj ▁benoit ▁and ▁xxmaj ▁guerrero ▁heated ▁up ▁early ▁on ▁by ▁taking ▁turns ▁hammer ing ▁first ▁xxmaj ▁spike ▁and ▁then ▁xxmaj ▁ bub ba ▁xxmaj ▁ray . ▁a ▁xxmaj ▁german ▁su plex ▁by ▁xxmaj ▁benoit ▁to ▁xxmaj ▁ bub ba ▁took ▁the ▁wind ▁out ▁of ▁the ▁xxmaj ▁dudley ▁brother . ▁xxmaj ▁spike ▁tried ▁to ▁help ▁his ▁brother , ▁but ▁thepos
1xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpadneg
2xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpadpos
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#slow\n", "path = untar_data(URLs.IMDB)\n", "dls = TextDataLoaders.from_folder(path)\n", "dls.show_batch(max_n=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

TextDataLoaders.from_df[source]

\n", "\n", "> TextDataLoaders.from_df(**`df`**, **`path`**=*`'.'`*, **`valid_pct`**=*`0.2`*, **`seed`**=*`None`*, **`text_col`**=*`0`*, **`label_col`**=*`1`*, **`label_delim`**=*`None`*, **`y_block`**=*`None`*, **`text_vocab`**=*`None`*, **`is_lm`**=*`False`*, **`valid_col`**=*`None`*, **`tok_tfm`**=*`None`*, **`seq_len`**=*`72`*, **`backwards`**=*`False`*, **`bs`**=*`64`*, **`val_bs`**=*`None`*, **`shuffle_train`**=*`True`*, **`device`**=*`None`*)\n", "\n", "Create from `df` in `path` with `valid_pct`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TextDataLoaders.from_df)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`seed` can optionally be passed for reproducibility. `text_col`, `label_col` and optionally `valid_col` are the indices or names of the columns holding the texts, the labels and the validation flag, respectively. `label_delim` can be passed for a multi-label problem if your labels are in one column, separated by a particular character. `y_block` can be passed to specify the type of your targets, in case the library does not infer it properly.\n", "\n", "Here are examples on a subset of IMDB (the `IMDB_SAMPLE` dataset):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.IMDB_SAMPLE)\n", "df = pd.read_csv(path/'texts.csv')\n", "dls = TextDataLoaders.from_df(df, path=path, text_col='text', label_col='label', valid_col='is_valid')\n", "dls.show_batch(max_n=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = TextDataLoaders.from_df(df, path=path, text_col='text', is_lm=True, valid_col='is_valid')\n", "dls.show_batch(max_n=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "show_doc(TextDataLoaders.from_csv)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Opens the csv file with `header` and `delimiter`, then passes all the other arguments to `TextDataLoaders.from_df`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = TextDataLoaders.from_csv(path=path, csv_fname='texts.csv', text_col='text', label_col='label', valid_col='is_valid')\n", "dls.show_batch(max_n=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }