{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: unpickling can execute arbitrary code; only load a vocabulary file from a trusted source.\n",
    "# The context manager closes the file handle instead of leaving it open.\n",
    "with open('itos_tfm.pkl', 'rb') as f: itos = pickle.load(f)\n",
    "vocab_sz = len(itos); vocab_sz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The positional hyper-parameters are presumably ctx_len, n_layers, n_heads, d_model, d_head, d_inner\n",
    "# followed by four dropout rates - TODO confirm against the signature of get_transformer_lm.\n",
    "model = get_transformer_lm(vocab_sz, 512, 12, 12, 768, 64, 768*4, 0.1, 0.1, 0.1, 0., act=Activation.GeLU, double_drop=False,\n",
    "                           out_bias=False).cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fill the freshly built architecture with the saved weights.\n",
    "model.load_state_dict(torch.load('tfmer.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the loaded weights live on the GPU (does nothing if they already do).\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reverse lookup: `itos` maps id -> token string, `stoi` maps token string -> id.\n",
    "stoi = {s:i for i,s in itos.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Careful: words have a `</w>` flag in the vocabulary."
] },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "stoi['vanilla</w>']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "itos[15000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOW = '</w>'  # end-of-word flag carried by the word entries of the vocabulary\n",
    "\n",
    "def textify(ids):\n",
    "    \"Turn token `ids` back into a space-separated string with the end-of-word flag removed.\"\n",
    "    return ' '.join([itos[i].replace(EOW, '') for i in ids])\n",
    "\n",
    "def numericalize(text):\n",
    "    \"Map every space-separated word of `text` to the id of its flagged vocabulary entry.\"\n",
    "    toks = [f'{w}{EOW}' for w in text.split(' ')]\n",
    "    # Only whole-word lookups are done here, so report unknown words explicitly.\n",
    "    missing = [t for t in toks if t not in stoi]\n",
    "    if missing: raise KeyError(f'not in the vocabulary: {missing}')\n",
    "    return [stoi[t] for t in toks]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(text, n_words, topk=10, temperature=1.):\n",
    "    \"Continue `text` with `n_words` tokens, each sampled among the `topk` most likely candidates.\"\n",
    "    ids = numericalize(text)\n",
    "    model.eval()\n",
    "    new_idx = []\n",
    "    # Inference only: without no_grad every forward pass would keep its autograd graph alive.\n",
    "    with torch.no_grad():\n",
    "        for _ in range(n_words):\n",
    "            # The whole sequence so far is fed again on each step, so the model state is cleared first.\n",
    "            model.reset()\n",
    "            x = LongTensor(ids + new_idx)[None].cuda()\n",
    "            # Keep the scores of the last position of the single batch item.\n",
    "            out = F.softmax(model(x)[0][0,-1], dim=-1)\n",
    "            # p ** (1/T) is proportional to softmax(logits / T); multinomial accepts unnormalised weights.\n",
    "            if temperature != 1.: out.pow_(1 / temperature)\n",
    "            # Never ask for more candidates than there are vocabulary entries.\n",
    "            values, indices = out.topk(min(topk, out.size(-1)))\n",
    "            next_idx = indices.gather(-1, torch.multinomial(values, 1)).item()\n",
    "            new_idx.append(next_idx)\n",
    "    model.reset()\n",
    "    return text + ' ' + textify(new_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict(\"this state has a population of\", 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}