{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tokenizing text"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nb_200 import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocessing the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.IMDB)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"from multiprocessing import Process, Queue\n",
"import spacy,html\n",
"from spacy.symbols import ORTH\n",
"from fastprogress import progress_bar,master_bar\n",
"import pickle,random"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before even tokenizing, we will apply a bit of preprocessing on the texts to clean them up (we saw the one up there had some HTML code). These rules are applied before we split the sentences in tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"#special tokens\n",
"UNK, PAD, BOS, EOS, FLD, TK_REP, TK_WREP, TK_UP, TK_MAJ = \"xxunk xxpad xxbos xxeos xxfld xxrep xxwrep xxup xxmaj\".split()\n",
"\n",
"def sub_br(t):\n",
" \"Replaces the
by \\n\"\n",
" re_br = re.compile(r'<\\s*br\\s*/?>', re.IGNORECASE)\n",
" return re_br.sub(\"\\n\", t)\n",
"\n",
"def spec_add_spaces(t):\n",
" \"Add spaces around / and #\"\n",
" return re.sub(r'([/#])', r' \\1 ', t)\n",
"\n",
"def rm_useless_spaces(t):\n",
" \"Remove multiple spaces\"\n",
" return re.sub(' {2,}', ' ', t)\n",
"\n",
"def replace_rep(t):\n",
" \"Replace repetitions at the character level: cccc -> TK_REP 4 c\"\n",
" def _replace_rep(m:Collection[str]) -> str:\n",
" c,cc = m.groups()\n",
" return f' {TK_REP} {len(cc)+1} {c} '\n",
" re_rep = re.compile(r'(\\S)(\\1{3,})')\n",
" return re_rep.sub(_replace_rep, t)\n",
" \n",
"def replace_wrep(t):\n",
" \"Replace word repetitions: word word word -> TK_WREP 3 word\"\n",
" def _replace_wrep(m:Collection[str]) -> str:\n",
" c,cc = m.groups()\n",
" return f' {TK_WREP} {len(cc.split())+1} {c} '\n",
" re_wrep = re.compile(r'(\\b\\w+\\W+)(\\1{3,})')\n",
" return re_wrep.sub(_replace_wrep, t)\n",
"\n",
"def fixup_text(x):\n",
" \"Various messy things we've seen in documents\"\n",
" re1 = re.compile(r' +')\n",
" x = x.replace('#39;', \"'\").replace('amp;', '&').replace('#146;', \"'\").replace(\n",
" 'nbsp;', ' ').replace('#36;', '$').replace('\\\\n', \"\\n\").replace('quot;', \"'\").replace(\n",
" '
', \"\\n\").replace('\\\\\"', '\"').replace('',UNK).replace(' @.@ ','.').replace(\n",
" ' @-@ ','-').replace('\\\\', ' \\\\ ')\n",
" return re1.sub(' ', html.unescape(x))\n",
" \n",
"default_pre_rules = [fixup_text, replace_rep, replace_wrep, spec_add_spaces, rm_useless_spaces, sub_br]\n",
"default_spec_tok = [UNK, PAD, BOS, EOS, FLD, TK_REP, TK_WREP, TK_UP, TK_MAJ]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"replace_rep('cccc')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"replace_wrep('word word word word word ')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These rules are applies after the tokenization on the list of tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def replace_all_caps(x):\n",
" \"Replace tokens in ALL CAPS by their lower version and add `TK_UP` before.\"\n",
" res = []\n",
" for t in x:\n",
" if t.isupper() and len(t) > 1: res.append(TK_UP); res.append(t.lower())\n",
" else: res.append(t)\n",
" return res\n",
"\n",
"def deal_caps(x):\n",
" \"Replace all Capitalized tokens in by their lower version and add `TK_MAJ` before.\"\n",
" res = []\n",
" for t in x:\n",
" if t == '': continue\n",
" if t[0].isupper() and len(t) > 1 and t[1:].islower(): res.append(TK_MAJ)\n",
" res.append(t.lower())\n",
" return res\n",
"\n",
"def add_eos_bos(x): return [BOS] + x + [EOS]\n",
"\n",
"default_post_rules = [deal_caps, replace_all_caps, add_eos_bos]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"replace_all_caps(['I', 'AM', 'SHOUTING'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"deal_caps(['My', 'name', 'is', 'Jeremy'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A Tokenizer should implement two methods: init with a certain language and some special tokens, then `tokenize_pipe` which returns a generator that yields the tokenized texts (should take a generator). `chunksize` is used for some tokenizers like spacy that can treat items as batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class BaseTokenizer():\n",
" def __init__(self, lang, special_toks): pass\n",
" def pipe(self, items): \n",
" for t in items: yield t.split(' ')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SpacyTokenizer():\n",
" def __init__(self, lang='en', special_toks=None, batch_size=5000):\n",
" special_toks = ifnone(special_toks, default_spec_tok)\n",
" self.nlp = spacy.blank(lang, disable=[\"parser\", \"tagger\", \"ner\"])\n",
" for w in default_spec_tok: self.nlp.tokenizer.add_special_case(w, [{ORTH: w}])\n",
" self.batch_size=batch_size\n",
" \n",
" def pipe(self, items):\n",
" for doc in self.nlp.pipe(items, batch_size=self.batch_size):\n",
" yield [d.text for d in doc]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def apply_rules(items, rules):\n",
" for o in items: yield apply_all(o, rules)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize1(text, tok_func=SpacyTokenizer, pre_rules=None, post_rules=None, **tok_kwargs):\n",
" pre_rules = listify(ifnone(pre_rules, default_pre_rules.copy()))\n",
" post_rules = listify(ifnone(post_rules, default_post_rules.copy()))\n",
" tokenizer = tok_func(**tok_kwargs)\n",
" for tok in tokenizer.pipe(apply_rules([text], pre_rules)):\n",
" tok = apply_all(tok, post_rules)\n",
" return tok"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Returns a generator from `items` after applying `rules` to them."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A basic function that reads the content of file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_text(fname):\n",
" with open(fname, 'r') as f: return f.read()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The main function that will be called during tokenization. It will create an instance of a tokenizer with `tok_func` and `tok_kwargs`, then iterate through the `items`, apply them `pre_rules`, tokenize them, apply them `post_rules`, then apply `output_func` to the original item and the tokens and put the result in `output_queue`.\n",
"\n",
"If a `data_queue` is passed, we count the different tokens and return the Counter in it at the end."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tok_items(items, tok_func, pre_rules, post_rules, output_func, output_queue, data_queue=None, **tok_kwargs):\n",
" tokenizer = tok_func(**tok_kwargs)\n",
" if data_queue: counts = Counter()\n",
" for i,tok in enumerate(tokenizer.pipe(apply_rules(items, pre_rules))):\n",
" tok = apply_all(tok, post_rules)\n",
" output_queue.put(output_func(items[i], tok))\n",
" if data_queue: counts.update(Counter(tok))\n",
" if data_queue: data_queue.put(counts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Helper function to create the same directory structure as in a given folder."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def create_folders(path, output_dir, include=None):\n",
" output_dir = Path(output_dir)\n",
" os.makedirs(output_dir, exist_ok=True)\n",
" for i,(p,d,f) in enumerate(os.walk(path)): # returns (dirpath, dirnames, filenames)\n",
" if include is not None and i==0: d[:] = [o for o in d if o in include]\n",
" else: d[:] = [o for o in d if not o.startswith('.')]\n",
" for x in d: os.makedirs(output_dir/(Path(p)/Path(x)).relative_to(path), exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Preprocessing function for texts in filenames. Tokenized texts will be saved in a similar fashion in a directory suffixed with `_tok` in the parent folder of `path` (override with `output_dir`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SEP = '▁'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fname = path/'labels.csv'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fname.suffix"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tok_folder(path, extensions=['.txt'], include=None, output_dir=None, n_workers=4,\n",
" pre_rules=None, post_rules=None, tok_func=SpacyTokenizer, **tok_kwargs):\n",
" path = Path(path)\n",
" fnames = get_files(path, extensions=extensions, recurse=True, include=include)\n",
" output_dir = Path(ifnone(output_dir, path.parent/f'{path.name}_tok'))\n",
" create_folders(path, output_dir, include=include)\n",
" pre_rules = [read_text] + listify(ifnone(pre_rules, default_pre_rules.copy()))\n",
" post_rules = listify(ifnone(post_rules, default_post_rules.copy()))\n",
" \n",
" output_queue,data_queue = Queue(maxsize=n_workers),Queue(maxsize=n_workers)\n",
" def _output(o, tok):\n",
" out = output_dir/o.relative_to(path)\n",
" with open(out, 'w') as f: f.write(SEP.join(tok))\n",
" with open(out.parent/f'{out.stem}.len', 'w') as f: f.write(str(len(tok)))\n",
" return 1\n",
" \n",
" processes = [Process(target=tok_items,\n",
" args=(batch, tok_func, pre_rules, post_rules, _output, output_queue),\n",
" kwargs={'data_queue': data_queue, **tok_kwargs})\n",
" for i,batch in enumerate(np.array_split(fnames, n_workers))]\n",
" \n",
" for p in processes: p.start()\n",
" counter = Counter()\n",
" for _ in progress_bar(fnames, leave=False): _ = output_queue.get()\n",
" for _ in processes: counter.update(data_queue.get())\n",
" for p in processes: p.join()\n",
" pickle.dump(counter, open(output_dir/'counter.pkl','wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.IMDB)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# test\n",
"fnames = get_files(path, extensions=['.txt'], recurse=True, include=['train', 'test', 'unsup'])\n",
"tok_path = path.parent/'imdb_tok'\n",
"assert tok_path.exists()\n",
"#Take one file randomly\n",
"idx = random.randint(0, len(fnames)-1)\n",
"#Check we have the corresponding tokenized version...\n",
"tok_fname = tok_path/(fnames[idx].relative_to(path))\n",
"assert tok_fname.exists()\n",
"text = read_text(fnames[idx])\n",
"tok = tokenize1(text)\n",
"assert SEP.join(tok) == read_text(tok_fname)\n",
"len_fname = tok_fname.parent/f'{tok_fname.stem}.len'\n",
"assert len(tok) == int(read_text(len_fname))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When text is in a dataframe, we need to merge the text columns, and maybe mark_fields."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def join_texts(idx, df, mark_fields=False):\n",
" return ' '.join([(f'{FLD} {i} ' if mark_fields else '') + t for i,t in enumerate(df.iloc[int(idx)].values)])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Preprocessing function for texts in a dataframe. Tokenized texts will be put in a similar dataframe with just one column of texts and the other columns the same."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tok_df(df, text_cols, n_workers=4, pre_rules=None, post_rules=None, mark_fields=None, \n",
" tok_func=SpacyTokenizer, **tok_kwargs):\n",
" text_cols = listify(text_cols)\n",
" mark_fields = ifnone(mark_fields, len(listify(text_cols)) > 1)\n",
" pre_rules = listify(ifnone(pre_rules, default_pre_rules.copy()))\n",
" pre_rules = [partial(join_texts, df=df[text_cols], mark_fields=mark_fields)] + pre_rules\n",
" post_rules = listify(ifnone(post_rules, default_post_rules.copy()))\n",
" \n",
" output_queue,data_queue = Queue(maxsize=n_workers),Queue(maxsize=n_workers)\n",
" def _output(o, tok): return (o,tok)\n",
" \n",
" processes = [Process(target=tok_items,\n",
" args=(batch, tok_func, pre_rules, post_rules, _output, output_queue),\n",
" kwargs={'data_queue': data_queue, **tok_kwargs})\n",
" for i,batch in enumerate(np.array_split(range(len(df)), n_workers))]\n",
" \n",
" for p in processes: p.start()\n",
" lengths,outputs,counter = np.zeros(len(df)),np.zeros(len(df), dtype=np.object),Counter()\n",
" for _ in progress_bar(range(len(df)), leave=False): \n",
" i,tok = output_queue.get()\n",
" lengths[i],outputs[i] = len(tok),SEP.join(tok)\n",
" for _ in processes: counter.update(data_queue.get())\n",
" for p in processes: p.join()\n",
" \n",
" other_cols = [c for c in df.columns if c not in text_cols]\n",
" res = df[other_cols].copy()\n",
" res['text'],res['text_lengths'] = outputs,lengths\n",
" return res, counter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# test\n",
"path = untar_data(URLs.IMDB_SAMPLE)\n",
"df = pd.read_csv(path/'texts.csv')\n",
"out,cnt = tok_df(df, text_cols='text')\n",
"test_eq(set(out.columns),set(list(df.columns)+['text_lengths']))\n",
"idx = random.randint(0, len(df)-1)\n",
"text = df['text'][idx]\n",
"tok = tokenize1(text)\n",
"test_eq(SEP.join(tok), out['text'][idx])\n",
"test_eq(len(tok), out['text_lengths'][idx])\n",
"#With two fields, mark fields become true by default\n",
"df['text1'] = df['text']\n",
"out,cnt = tok_df(df, text_cols=['text', 'text1'])\n",
"idx = random.randint(0, len(df)-1)\n",
"text = f\"{FLD} 0 {df['text'][idx]} {FLD} 1 {df['text1'][idx]}\"\n",
"tok = tokenize1(text)\n",
"test_eq(SEP.join(tok), out['text'][idx])\n",
"test_eq(len(tok), out['text_lengths'][idx])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tok_csv(fname, text_cols, outname=None, n_workers=4, pre_rules=None, post_rules=None, \n",
" mark_fields=None, tok_func=SpacyTokenizer, header='infer', chunksize=None, **tok_kwargs):\n",
" df = pd.read_csv(fname, header=header, chunksize=chunksize)\n",
" outname = Path(ifnone(outname, fname.parent/f'{fname.stem}_tok.csv'))\n",
" kwargs = dict(n_workers=n_workers, pre_rules=pre_rules, post_rules=post_rules, \n",
" mark_fields=mark_fields, tok_func=tok_func, **tok_kwargs)\n",
" if chunksize is None:\n",
" out,cnt = tok_df(df, text_cols, **kwargs)\n",
" out.to_csv(outname, header=header, index=False)\n",
" else:\n",
" cnt = Counter()\n",
" for i,dfp in enumerate(df):\n",
" out,c = tok_df(dfp, text_cols, **kwargs)\n",
" out.to_csv(outname, header=header if i==0 else None, index=False, mode='w' if i==0 else 'a')\n",
" cnt.update(c)\n",
" pickle.dump(cnt, open(outname.parent/'counter.pkl', 'wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#test\n",
"path = untar_data(URLs.IMDB_SAMPLE)\n",
"tok_csv(path/'texts.csv', 'text')\n",
"assert (path/'texts_tok.csv').exists()\n",
"df = pd.read_csv(path/'texts.csv')\n",
"df_tok = pd.read_csv(path/'texts_tok.csv')\n",
"idx = random.randint(0, len(df)-1)\n",
"text = df['text'][idx]\n",
"tok = tokenize1(text)\n",
"test_eq(SEP.join(tok), df_tok['text'][idx])\n",
"test_eq(len(tok), df_tok['text_lengths'][idx])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#test\n",
"path = untar_data(URLs.IMDB_SAMPLE)\n",
"tok_csv(path/'texts.csv', 'text', chunksize=500)\n",
"assert (path/'texts_tok.csv').exists()\n",
"df = pd.read_csv(path/'texts.csv')\n",
"df_tok = pd.read_csv(path/'texts_tok.csv')\n",
"test_eq(len(df_tok), len(df))\n",
"idx = random.randint(0, len(df)-1)\n",
"text = df['text'][idx]\n",
"tok = tokenize1(text)\n",
"test_eq(SEP.join(tok), df_tok['text'][idx])\n",
"test_eq(len(tok), df_tok['text_lengths'][idx])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting in a DataBunch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Text data blocks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ReadTokens(Transform):\n",
" def __call__(self, o):\n",
" text = read_text(o) if isinstance(o, Path) else str(o)\n",
" return text.split(SEP)\n",
" def decode(self, o): return SEP.join(o)\n",
" \n",
" def show(self, x, ax): print(x)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Numericalize(MultiCategorize):\n",
" _order = 5\n",
" def __init__(self, vocab): \n",
" self.vocab = vocab\n",
" self.o2i = collections.defaultdict(int, {w:i for i,w in enumerate(vocab)})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Text(Item):\n",
" tfm = [ReadTokens, Numericalize]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def text_getter(suf='', **kwargs):\n",
" def _inner(o, **kwargs):\n",
" return get_files(o/suf, extensions=['.txt'], recurse=True)\n",
" return _inner"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ImdbData(DataBlock):\n",
" types = (Text,Item)\n",
" get_items = text_getter()\n",
" split = random_splitter()\n",
" label_func = lambda fn,self: int(read_text(fn.parent/f'{fn.stem}.len'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.IMDB)\n",
"path_tok = path.parent/'imdb_tok'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"counter = pickle.load(open(path_tok/'counter.pkl', 'rb'))\n",
"vocab = [w for w,i in counter.most_common(60000) if i >= 2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dblk = ImdbData(path_tok, tfms_x=[ReadTokens(), Numericalize(vocab)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dsrc = dblk.datasource()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x,y = dsrc.get(0,0)\n",
"t = dsrc.decode((x,y))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Batching"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LM_PreLoader():\n",
" def __init__(self, fl, lengths=None, bs=64, bptt=70, shuffle=False):\n",
" self.fl,self.bs,self.bptt,self.shuffle = fl,bs,bptt,shuffle\n",
" self.lengths = [len(o[0]) for o in fl] if lengths is None else lengths\n",
" self.n_batch = sum(self.lengths) // bs\n",
" self.batchify()\n",
" \n",
" def __len__(self): return ((self.n_batch-1) // self.bptt) * self.bs\n",
" \n",
" def __getitem__(self, i):\n",
" k = (i % self.bs) * self.n_batch + (i // self.bs) * self.bptt\n",
" item_idx = (self.cumlen > k).nonzero().min().item()\n",
" offset = k if item_idx==0 else k-self.cumlen[item_idx-1]\n",
" text = self.fl[item_idx][0][offset:]\n",
" while len(text) <= self.bptt:\n",
" item_idx += 1\n",
" text += self.fl[item_idx][0]\n",
" return tensor(text[:self.bptt]),tensor(text[1:self.bptt+1])\n",
" \n",
" def batchify(self):\n",
" self.idxs = torch.randperm(len(fl)) if self.shuffle else tensor(range(len(self.fl)))\n",
" self.cumlen = (tensor(self.lengths)[idxs] if self.shuffle else tensor(self.lengths)).cumsum(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#test\n",
"ds = LM_PreLoader(dsrc[0], lengths=lengths)\n",
"x,y = ds[0]\n",
"test_equal(x[1:], y[:-1])\n",
"x0,x1 = dsrc.get(0,0)[0],dsrc.get(1,0)[0]\n",
"test_equal(x, tensor(x0+x1)[:70])\n",
"test_equal(ds[64][0], tensor(x0+x1)[70:140])\n",
"k = ds.n_batch\n",
"x,y = ds[1]\n",
"offset = k - ds.cumlen[1262]\n",
"test_equal(x, tensor(dsrc.get(1263,0)[0][offset:offset+70]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = DataLoader(ds, 64, shuffle=False, num_workers=4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%time for (x,y) in progress_bar(data): pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}