{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# SMASAC - fastText Embedding\n", "\n", "This notebook shows how to use [fastText](https://fasttext.cc) to generate word and tweet representations in an embedding space.\n", "\n", "This notebook is structured as follows:\n", "\n", "1. Preprocessing the data\n", "2. Training the fastText embedding model\n", "3. Querying similar words with the embedding model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "import fastText\n", "import sklearn\n", "import sklearn.metrics\n", "import numpy as np\n", "import re" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Configuration\n", "\n", "Folder structure of this project:\n", "\n", "* data: data directory\n", "  - twitter_las_vegas_shooting : Text for training, a sample of 50k tweets\n", "  - twitter_las_vegas_shooting.preprocessed : Preprocessed training text\n", "  - twitter_las_vegas_shooting.labels : Hashtags in the training corpus\n", "  - twitter_las_vegas_shooting.embedding : Hashtag embedding vectors\n", "  - twitter_las_vegas_shooting.low_dim_embedding : Hashtag embedding vectors in 2D\n", "* model: model directory\n", "\n", "\n", "We will use `twitter_las_vegas_shooting` for training, which contains 50,000 tweets crawled during the Las Vegas mass shooting. 
" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "root_dir = Path(\"..\")\n", "data_dir = root_dir / \"data\" / \"3-entity-extraction\"\n", "notebook_dir = root_dir / \"notebooks\"\n", "model_dir = data_dir / \"model\"\n", "\n", "if not model_dir.exists():\n", "    model_dir.mkdir()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Raw corpus\n", "data_path = data_dir / \"twitter_las_vegas_shooting\"\n", "# Training corpus filename\n", "input_filename = str(data_path)\n", "# Model filename\n", "model_filename = str(model_dir / \"twitter.bin\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Preprocessing\n", "\n", "Preprocess each tweet so that the embedding model is trained on cleaner text.\n", "\n", "* Remove hashtags\n", "* Remove mentions\n", "* Remove punctuation\n", "* Remove URLs\n", "* Convert the tweet to lowercase" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Preprocessing config: set a step to False to skip it\n", "preprocess_config = {\n", "    \"hashtag\": True,\n", "    \"mentioned\": True,\n", "    \"punctuation\": True,\n", "    \"url\": True,\n", "}\n", "\n", "# Regex patterns (raw strings keep the backslashes intact)\n", "hashtag_pattern = r\"#\\w+\"\n", "mentioned_pattern = r\"@\\w+\"\n", "url_pattern = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'\n", "\n", "# Punctuation characters that are replaced by spaces\n", "trans_str = \"!\\\"$%&\\'()*+,-./:;<=>?[\\\\]^_`{|}~\" + \"…\"\n", "translate_table = str.maketrans(trans_str, \" \" * len(trans_str))\n", "\n", "def preprocess(s):\n", "    s = s.lower()\n", "    if preprocess_config[\"hashtag\"]:\n", "        s = re.sub(hashtag_pattern, \"\", s)\n", "    if preprocess_config[\"mentioned\"]:\n", "        s = re.sub(mentioned_pattern, \"\", s)\n", "    if preprocess_config[\"url\"]:\n", "        s = re.sub(url_pattern, \"\", s)\n", "    if preprocess_config[\"punctuation\"]:\n", "        # Replace punctuation with spaces, then collapse repeated whitespace\n", "        s = \" \".join(s.translate(translate_table).split())\n", "    return s\n" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "**Preprocessing Example**  \n", "Here is an example of the preprocessing output." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Original Tweet:\n", "RT @TheLeadCNN: Remembering Keri Lynn Galvan, from Thousand Oaks, California. #LasVegasLost https://t.co/QuvXa6WvlE https://t.co/hDF2d3Owgn\n", "\n", "Preprocessed Tweet:\n", "rt remembering keri lynn galvan from thousand oaks california\n" ] } ], "source": [ "# example of preprocessing\n", "example_tweet = \"RT @TheLeadCNN: Remembering Keri Lynn Galvan, from Thousand Oaks, California. #LasVegasLost https://t.co/QuvXa6WvlE https://t.co/hDF2d3Owgn\"\n", "\n", "print(\"Original Tweet:\")\n", "print(example_tweet)\n", "print()\n", "print(\"Preprocessed Tweet:\")\n", "print(preprocess(example_tweet))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Preprocessing the corpus**" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# Preprocess every tweet (one per line) and write the result to disk\n", "preprocessed_data_path = data_dir / \"twitter_las_vegas_shooting.preprocessed\"\n", "\n", "with data_path.open() as f:\n", "    lines = [l.strip() for l in f.readlines()]\n", "\n", "with preprocessed_data_path.open(\"w\") as f:\n", "    for l in lines:\n", "        f.write(preprocess(l))\n", "        f.write(\"\\n\")\n", "\n", "# use preprocessed data as input\n", "input_filename = str(preprocessed_data_path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training fastText embedding model\n", "\n", "Train a 100-dimensional embedding model on the preprocessed corpus." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# fastText Config\n", "embedding_model = \"skipgram\"\n", "lr = 0.05\n", "dim = 100\n", "ws = 5\n", "epoch = 5\n", "minCount = 5\n", "minCountLabel = 0\n", "minn = 3\n", "maxn = 6\n", "neg = 5\n", "wordNgrams = 1\n", "loss = \"ns\"\n", "bucket = 2000000\n", "thread = 12\n", "lrUpdateRate = 100\n", "t = 1e-4\n", "verbose = 2" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training finished.\n", "Dimension: 100\n", "Number of words: 6040\n" ] } ], "source": [ "model = fastText.train_unsupervised(\n", "    input=input_filename,\n", "    model=embedding_model,\n", "    lr=lr,\n", "    dim=dim,\n", "    ws=ws,\n", "    epoch=epoch,\n", "    minCount=minCount,\n", "    minCountLabel=minCountLabel,\n", "    minn=minn,\n", "    maxn=maxn,\n", "    neg=neg,\n", "    wordNgrams=wordNgrams,\n", "    loss=loss,\n", "    bucket=bucket,\n", "    thread=thread,\n", "    lrUpdateRate=lrUpdateRate,\n", "    t=t,\n", "    verbose=verbose,\n", ")\n", "\n", "print(\"Training finished.\")\n", "print(\"Dimension: {}\".format(model.get_dimension()))\n", "print(\"Number of words: {}\".format(len(model.get_words())))\n", "\n", "# Output model to disk if needed\n", "model.save_model(model_filename)\n", "\n", "# Load saved model if needed\n", "model = fastText.load_model(model_filename)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Query" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Get the word vectors of the corpus**" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "words = np.array(model.get_words())\n", "word_vectors = np.array([model.get_word_vector(w) for w in words])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Similarity of word vectors**\n", "\n", "In the embedding space, cosine similarity can be used to measure the similarity between words." ] }, { "cell_type": 
"code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Find the N nearest rows of X to inX, ranked by cosine distance\n", "def calc_n_cosine_neighbor(inX, X, N):\n", "    if inX.ndim == 1:\n", "        inX = [inX]\n", "    distances = sklearn.metrics.pairwise.pairwise_distances(\n", "        X, inX, metric=\"cosine\")\n", "    sortedDist = distances.reshape((distances.shape[0],)).argsort()\n", "    return sortedDist[:N], distances\n", "\n", "# calculate nearest neighbours based on cosine similarity\n", "def nn(query, words=words, word_vectors=word_vectors, k=10):\n", "    \"\"\"\n", "    query: word to look up\n", "    words: numpy array of words\n", "    word_vectors: numpy array of vectors, aligned with words\n", "    k: (optional, 10 by default) top k labels\n", "    \"\"\"\n", "    global model\n", "    v = model.get_word_vector(query)\n", "    idx, _ = calc_n_cosine_neighbor(v, word_vectors, k)\n", "    return words[idx]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Query nearest words" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Neighbours of word \"lasvegasshooting\":\n", "shooting\n", "lasvegas\n", "vegas”\n", "las\n", "vegas\n", "rt\n", "vega\n", "“shooting”\n", "shootin\n", "\n", "shooting”\n", "shooti\n", "shootings\n", "❤\n", "🙏🙏🙏\n", "👍\n", "cc\n", "🙏🏾\n", "😢💔\n", "😓\n" ] } ], "source": [ "q = \"lasvegasshooting\"\n", "\n", "neighbours = nn(q, k=20)\n", "\n", "print(\"Neighbours of word \\\"{}\\\":\".format(q))\n", "for word in neighbours:\n", "    print(word)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get sentence vector\n", "\n", "Use the `get_sentence_vector` API to get a vector representation of a sentence." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Tweet vector in embedding space:\n", "RT @TheLeadCNN: Remembering Keri Lynn Galvan, from Thousand Oaks, California. 
#LasVegasLost https://t.co/QuvXa6WvlE https://t.co/hDF2d3Owgn\n", "\n", "[-0.02461572 0.04784836 -0.05343785 0.00153351 0.04367601 0.10020498\n", " -0.01127366 -0.00975734 -0.01951972 0.07512145 0.03622668 -0.00580111\n", " 0.08758368 0.031007 -0.00507403 0.07074952 -0.05185707 -0.11242248\n", " -0.03888126 -0.01926897 0.08175821 -0.01120457 -0.07555435 -0.04022888\n", " 0.00478477 -0.0012044 0.05348494 0.0350855 0.0982817 0.01342872\n", " 0.00545024 0.00250413 0.03077969 -0.0874893 -0.03390906 0.14996992\n", " -0.01272367 -0.02368226 -0.01887075 -0.02408492 -0.03291685 -0.05095126\n", " -0.04614896 0.10122891 0.07110424 -0.12804917 -0.05888803 -0.03085945\n", " -0.01463612 0.11134949 -0.08774657 -0.01715528 -0.08862083 0.00346183\n", " 0.09192748 0.05510866 -0.04465136 -0.0433164 0.02116909 -0.06731256\n", " -0.00497376 -0.02442945 0.04918417 -0.03386533 0.05390133 0.01210842\n", " -0.03669443 0.00295777 -0.00802929 0.05568004 0.03773327 0.02532181\n", " -0.00200854 -0.02188686 -0.09255282 0.01222703 0.02790884 -0.03890502\n", " -0.08059786 -0.02257247 -0.00428823 -0.00145929 0.07885873 0.07522529\n", " 0.02859196 0.06673713 0.04707326 -0.04525249 0.04185518 0.05594757\n", " 0.03690847 0.0279574 0.05838018 0.034359 0.02512365 -0.06622847\n", " 0.00108856 -0.06983574 0.02984929 -0.01983222]\n", "\n", "Words similar to this tweet\n", "['➜', '😢💔', 'nw', 'pittsburgh', '🇨🇦', '31', '🙏🏾', 'novato', 'md', '38', '💜', '👍', 'quartz', 'pasadena', '🙏🙏🙏', 'jusswaggtv', '→', '😓', 'hrc', 'umc']\n" ] } ], "source": [ "example_tweet = \"RT @TheLeadCNN: Remembering Keri Lynn Galvan, from Thousand Oaks, California. 
#LasVegasLost https://t.co/QuvXa6WvlE https://t.co/hDF2d3Owgn\"\n", "\n", "tweet_vector = model.get_sentence_vector(example_tweet)\n", "print(\"Tweet vector in embedding space:\")\n", "print(example_tweet)\n", "print()\n", "print(tweet_vector)\n", "\n", "print()\n", "print(\"Words similar to this tweet\")\n", "idx, _ = calc_n_cosine_neighbor(tweet_vector, word_vectors, 20)\n", "print([words[i] for i in idx])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }