{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# fast.ai ULMFiT\n",
    "\n",
    "Fine-tune an AWD-LSTM language model on the text column (column 5) of `train-processed.csv`, save its encoder, then train a text classifier on the same file (labels in column 0) starting from that encoder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.text import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One directory for both DataBunches: a fastai Learner saves and loads weights\n",
    "# under <data path>/models, so load_encoder below only finds the file written by\n",
    "# save_encoder when the classifier data is built from the same path.\n",
    "twitter_data_path = \"./twitter-data/\"\n",
    "batchsize = 32   # batch size of the language-model DataBunch\n",
    "split_seed = 42  # makes the random train/validation split repeatable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language model fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_lm = (TextList\n",
    "           .from_csv(twitter_data_path, 'train-processed.csv', cols=5)\n",
    "           .split_by_rand_pct(seed=split_seed)\n",
    "           .label_for_lm()\n",
    "           .databunch(bs=batchsize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn = language_model_learner(data_lm, AWD_LSTM, drop_mult=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.fit_one_cycle(10, 1e-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the encoder; the classifier below loads it by this name.\n",
    "learn.save_encoder('fine_tuned_enc')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuse the language model's vocab so token ids match the saved encoder.\n",
    "data_class = (TextList\n",
    "              .from_csv(twitter_data_path, 'train-processed.csv', cols=5, vocab=data_lm.vocab)\n",
    "              .split_by_rand_pct(seed=split_seed)\n",
    "              .label_from_df(cols=0)\n",
    "              .databunch())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "twitter_classifer_learner = text_classifier_learner(data_class, AWD_LSTM, drop_mult=0.5)\n",
    "twitter_classifer_learner.load_encoder('fine_tuned_enc')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "twitter_classifer_learner.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "twitter_classifer_learner.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "twitter_classifer_learner.fit_one_cycle(5, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leave only the last two layer groups trainable for the next cycle.\n",
    "twitter_classifer_learner.freeze_to(-2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "twitter_classifer_learner.fit_one_cycle(1, 1e-3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}