{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CS 579\n",
"\n",
"## Clustering Words with K-Means\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Motivation\n",
"\n",
"Often, we want to know which features appear together.\n",
"\n",
"- If you liked *Twilight* you might like *Nosferatu*.\n",
"- \"happy\" is a synonym of \"glad.\"\n",
"\n",
"Clusters of related words can be used to summarize a large collection of messages."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll use k-means to cluster together related words from Twitter.\n",
"\n",
"**Caution:** This uses live Twitter data, which often contains profanity."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100\n",
"200\n",
"300\n",
"400\n",
"500\n",
"600\n",
"700\n",
"800\n",
"900\n",
"1000\n",
"1100\n",
"1200\n",
"1300\n",
"1400\n",
"1500\n",
"1600\n",
"1700\n",
"1800\n",
"1900\n",
"2000\n",
"2100\n",
"2200\n",
"2300\n",
"2400\n",
"2500\n",
"2600\n",
"2700\n",
"2800\n",
"2900\n",
"3000\n",
"3100\n",
"3200\n",
"3300\n",
"3400\n",
"3500\n",
"3600\n",
"3700\n",
"3800\n",
"3900\n",
"4000\n",
"4100\n",
"4200\n",
"4300\n",
"4400\n",
"4500\n",
"4600\n",
"4700\n",
"4800\n",
"4900\n",
"5000\n",
"5100\n",
"5200\n",
"5300\n",
"5400\n",
"5500\n",
"5600\n",
"5700\n",
"5800\n",
"5900\n",
"6000\n",
"6100\n",
"6200\n",
"6300\n",
"6400\n",
"6500\n",
"6600\n",
"6700\n",
"6800\n",
"6900\n",
"7000\n",
"7100\n",
"7200\n",
"7300\n",
"7400\n",
"7500\n",
"7600\n",
"7700\n",
"7800\n",
"7900\n",
"8000\n",
"8100\n",
"8200\n",
"8300\n",
"8400\n",
"8500\n",
"8600\n",
"8700\n",
"8800\n",
"8900\n",
"9000\n",
"9100\n",
"9200\n",
"9300\n",
"9400\n",
"9500\n",
"9600\n",
"9700\n",
"9800\n",
"9900\n",
"10000\n"
]
}
],
"source": [
"# Get some tweets containing the word 'i'.\n",
"\n",
"import os\n",
"from TwitterAPI import TwitterAPI\n",
"\n",
"# Read Twitter credentials from environment variables.\n",
"api = TwitterAPI(os.environ.get('TW_CONSUMER_KEY'),\n",
"                 os.environ.get('TW_CONSUMER_SECRET'),\n",
"                 os.environ.get('TW_ACCESS_TOKEN'),\n",
"                 os.environ.get('TW_ACCESS_TOKEN_SECRET'))\n",
"\n",
"# Collect about 10000 tweets, reconnecting if the stream ends early.\n",
"tweets = []\n",
"while len(tweets) < 10000:\n",
"    r = api.request('statuses/filter', {'track': 'i',\n",
"                                        'language': 'en'})\n",
"    if r.status_code != 200:  # error\n",
"        break\n",
"    for item in r.get_iterator():\n",
"        tweets.append(item)\n",
"        if len(tweets) % 100 == 0:\n",
"            print(len(tweets))  # progress\n",
"        if len(tweets) >= 10000:\n",
"            # Stop reading; the outer loop condition then ends collection.\n",
"            break\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10002\n"
]
}
],
"source": [
"print(len(tweets))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"text @KYFriedComrade im rereading it ....maybe it is about twitter i dunno anymore i give up lol\n",
"description: gringa, adoptive santiaguina, ADHD stoner lela lez left stream of consciousness commentariat + retweets of dope ppl โ๐คฃ๐๐\n",
"name: naty ๐ค๐งก๐\n",
"location: Santiago, Chile\n"
]
}
],
"source": [
"# Each tweet is a Python dict.\n",
"print('text', tweets[0]['text'])\n",
"print('description:', tweets[0]['user']['description'])\n",
"print('name:', tweets[0]['user']['name'])\n",
"print('location:', tweets[0]['user']['location'])"
]
},
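{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Keep only items that have a 'text' field; the stream can also\n",
"# deliver non-tweet messages (e.g., limit notices).\n",
"tweets = [t for t in tweets if 'text' in t]"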
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9806"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(tweets)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['im',\n",
" 'rereading',\n",
" 'it',\n",
" 'maybe',\n",
" 'it',\n",
" 'is',\n",
" 'about',\n",
" 'twitter',\n",
" 'i',\n",
" 'dunno',\n",
" 'anymore',\n",
" 'i',\n",
" 'give',\n",
" 'up',\n",
" 'lol']"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Tokenize each tweet text.\n",
"import re\n",
"tokens = []\n",
"for tweet in tweets:\n",
"    text = tweet['text'].lower()\n",
"    text = re.sub(r'@\\S+', ' ', text)  # Remove mentions.\n",
"    text = re.sub(r'http\\S+', ' ', text)  # Remove urls.\n",
"    # Retain alphabetic words only; contractions split (\"don't\" -> don, t).\n",
"    tokens.append(re.findall('[A-Za-z]+', text))\n",
"tokens[0]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Count words.\n",
"from collections import Counter\n",
"\n",
"word_counts = Counter()\n",
"for tweet in tokens:\n",
"    word_counts.update(tweet)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13183 unique terms\n"
]
},
{
"data": {
"text/plain": [
"[('i', 11474),\n",
" ('rt', 5231),\n",
" ('the', 3741),\n",
" ('to', 3629),\n",
" ('a', 2863),\n",
" ('and', 2425),\n",
" ('you', 2414),\n",
" ('my', 2104),\n",
" ('it', 1820),\n",
" ('this', 1816)]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect word counts.\n",
"import math\n",
"\n",
"print(len(word_counts), 'unique terms')\n",
"word_counts.most_common(10)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4077 words occur at least three times.\n"
]
}
],
"source": [
"# Retain in vocabulary words occurring more than twice.\n",
"vocab = set([w for w, c in word_counts.items() if c > 2])\n",
"print('%d words occur at least three times.' % len(vocab))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Prune tokens: keep only in-vocabulary words and drop tweets left empty.\n",
"newtoks = []\n",
"for tweet in tokens:\n",
"    newtok = [token for token in tweet if token in vocab]\n",
"    if len(newtok) > 0:\n",
"        newtoks.append(newtok)\n",
"tokens = newtoks"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['im',\n",
" 'it',\n",
" 'maybe',\n",
" 'it',\n",
" 'is',\n",
" 'about',\n",
" 'twitter',\n",
" 'i',\n",
" 'anymore',\n",
" 'i',\n",
" 'give',\n",
" 'up',\n",
" 'lol']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# A sample pruned tweet.\n",
"tokens[0]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['rt',\n",
" 'from',\n",
" 'the',\n",
" 'bottom',\n",
" 'of',\n",
" 'my',\n",
" 'heart',\n",
" 'i',\n",
" 'hope',\n",
" 'is',\n",
" 'a',\n",
" 'better',\n",
" 'mental',\n",
" 'health',\n",
" 'year',\n",
" 'for',\n",
" 'everyone',\n",
" 'lt']"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens[2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Context features**\n",
"\n",
"To determine if two words are similar, we will create a feature vector that counts how often other words appear nearby.\n",
"\n",
"E.g.,\n",
"\n",
"> I really **love** school.\n",
"\n",
"> I really **like** school.\n",
"\n",
"> You **love** school.\n",
"\n",
"**love:** {really@-1: 1, school@1: 2, you@-1: 1}\n",
"\n",
"**like:** {really@-1: 1, school@1: 1}\n",
"\n",
"**Assumption**: words with similar meaning have similar context vectors.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"context for word it in ['im', 'it', 'maybe', 'it', 'is', 'about', 'twitter', 'i', 'anymore', 'i', 'give', 'up', 'lol']\n",
"['it@-2', 'maybe@-1', 'is@1', 'about@2']\n"
]
}
],
"source": [
"import numpy as np\n",
"def get_contexts(tweet, i, window):\n",
"    \"\"\"\n",
"    Get the context features for token at position i\n",
"    in this tweet, using the given window size.\n",
"    \"\"\"\n",
"    features = []\n",
"    for j in range(max(0, i - window), i):\n",
"        features.append(tweet[j] + \"@\" + str(j-i))\n",
"    for j in range(i+1, min(i + window + 1, len(tweet))):\n",
"        features.append(tweet[j] + \"@\" + str(j-i))\n",
"    return features\n",
"\n",
"# Use the same position for the printed word and the context lookup.\n",
"pos = 3\n",
"print('context for word %s in %s' % (tokens[0][pos], tokens[0]))\n",
"print(get_contexts(tokens[0], i=pos, window=2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Q: How would the approach differ if we ignored the location of each context word?**\n",
"\n",
"E.g., **love:** {really: 1, school: 2, you: 1} **vs** {really@-1: 1, school@1: 2, you@-1: 1}"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# For each term, create a context vector, indicating how often\n",
"# each word occurs to the left or right of it.\n",
"from collections import defaultdict\n",
"import numpy as np\n",
"\n",
"# dict from term to context vector.\n",
"contexts = defaultdict(Counter)\n",
"window = 2\n",
"for tweet in tokens:\n",
"    for i, token in enumerate(tweet):\n",
"        features = get_contexts(tweet, i, window)\n",
"        contexts[token].update(features)\n",
"        # Optionally: ignore word order\n",
"        # contexts[token].update(tweet[:i] + tweet[i+1:])\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('rt@-1', 1823),\n",
" ('m@1', 1667),\n",
" ('t@2', 798),\n",
" ('rt@-2', 529),\n",
" ('and@-1', 444),\n",
" ('to@2', 433),\n",
" ('love@1', 424),\n",
" ('a@2', 418),\n",
" ('have@1', 412),\n",
" ('don@1', 398),\n",
" ('am@1', 374),\n",
" ('can@1', 361),\n",
" ('you@2', 359),\n",
" ('ll@1', 334),\n",
" ('ve@1', 326),\n",
" ('when@-1', 323),\n",
" ('just@1', 320),\n",
" ('but@-1', 277),\n",
" ('was@1', 276),\n",
" ('the@2', 262)]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"contexts['i'].most_common(20)"
]
},
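{
"cell_type": "markdown",
"metadata": {},
"source": [
"**tf-idf vectors**\n",
"\n",
"- We will transform the context features by dividing each count by one plus the log of how often that feature occurs. The code below uses the feature's total count across all terms; the number of distinct terms the feature appears with is an alternative, also computed below."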
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('i@-1', 11338),\n",
" ('i@-2', 10905),\n",
" ('i@1', 9485),\n",
" ('i@2', 7121),\n",
" ('rt@-1', 5217)]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Total number of times each context feature appears,\n",
"# summed over all terms.\n",
"tweet_freq = Counter()\n",
"for context in contexts.values():\n",
"    tweet_freq.update(context)\n",
"tweet_freq.most_common(5)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": true
},
"outputs": [
{
"data": {
"text/plain": [
"Counter({1766: 1,\n",
" 25: 71,\n",
" 109: 8,\n",
" 55: 24,\n",
" 772: 1,\n",
" 88: 12,\n",
" 5200: 1,\n",
" 306: 1,\n",
" 2381: 1,\n",
" 3436: 1,\n",
" 34: 45,\n",
" 7: 842,\n",
" 412: 2,\n",
" 106: 8,\n",
" 649: 1,\n",
" 211: 2,\n",
" 44: 32,\n",
" 5: 1369,\n",
" 1032: 1,\n",
" 2021: 1,\n",
" 903: 1,\n",
" 4: 2130,\n",
" 2335: 2,\n",
" 3249: 1,\n",
" 1188: 1,\n",
" 2850: 1,\n",
" 1377: 1,\n",
" 670: 1,\n",
" 32: 43,\n",
" 7121: 1,\n",
" 62: 19,\n",
" 27: 62,\n",
" 2216: 1,\n",
" 5217: 1,\n",
" 14: 224,\n",
" 35: 41,\n",
" 60: 18,\n",
" 320: 1,\n",
" 10: 415,\n",
" 815: 2,\n",
" 219: 1,\n",
" 3558: 1,\n",
" 40: 32,\n",
" 118: 8,\n",
" 862: 1,\n",
" 75: 14,\n",
" 6: 1045,\n",
" 10905: 1,\n",
" 99: 7,\n",
" 97: 6,\n",
" 140: 4,\n",
" 385: 3,\n",
" 424: 2,\n",
" 131: 3,\n",
" 201: 4,\n",
" 33: 42,\n",
" 114: 7,\n",
" 91: 5,\n",
" 45: 26,\n",
" 21: 102,\n",
" 1453: 1,\n",
" 119: 4,\n",
" 78: 10,\n",
" 65: 25,\n",
" 468: 1,\n",
" 258: 1,\n",
" 144: 1,\n",
" 52: 26,\n",
" 585: 1,\n",
" 1447: 1,\n",
" 1: 240,\n",
" 158: 2,\n",
" 103: 10,\n",
" 116: 7,\n",
" 214: 2,\n",
" 92: 11,\n",
" 199: 3,\n",
" 59: 17,\n",
" 8: 669,\n",
" 410: 1,\n",
" 584: 1,\n",
" 16: 192,\n",
" 79: 13,\n",
" 510: 2,\n",
" 367: 1,\n",
" 41: 48,\n",
" 1652: 1,\n",
" 1807: 1,\n",
" 94: 11,\n",
" 113: 4,\n",
" 3: 3235,\n",
" 31: 59,\n",
" 586: 1,\n",
" 1196: 1,\n",
" 133: 3,\n",
" 173: 1,\n",
" 20: 132,\n",
" 3656: 1,\n",
" 503: 1,\n",
" 2202: 1,\n",
" 15: 203,\n",
" 390: 3,\n",
" 911: 1,\n",
" 227: 3,\n",
" 185: 4,\n",
" 574: 1,\n",
" 2: 831,\n",
" 803: 1,\n",
" 46: 23,\n",
" 453: 1,\n",
" 66: 17,\n",
" 19: 132,\n",
" 2037: 1,\n",
" 28: 70,\n",
" 73: 15,\n",
" 64: 13,\n",
" 2771: 1,\n",
" 93: 7,\n",
" 1270: 1,\n",
" 490: 1,\n",
" 146: 3,\n",
" 1305: 1,\n",
" 905: 1,\n",
" 653: 1,\n",
" 149: 5,\n",
" 161: 5,\n",
" 9485: 1,\n",
" 80: 12,\n",
" 49: 14,\n",
" 23: 89,\n",
" 84: 9,\n",
" 83: 12,\n",
" 9: 454,\n",
" 1475: 1,\n",
" 1036: 1,\n",
" 129: 6,\n",
" 293: 2,\n",
" 321: 1,\n",
" 2452: 1,\n",
" 102: 3,\n",
" 70: 14,\n",
" 12: 301,\n",
" 37: 47,\n",
" 1480: 1,\n",
" 297: 3,\n",
" 61: 12,\n",
" 966: 1,\n",
" 69: 9,\n",
" 1898: 1,\n",
" 1522: 1,\n",
" 48: 27,\n",
" 148: 3,\n",
" 403: 1,\n",
" 484: 2,\n",
" 1629: 1,\n",
" 215: 4,\n",
" 18: 123,\n",
" 13: 259,\n",
" 17: 172,\n",
" 162: 3,\n",
" 105: 5,\n",
" 176: 1,\n",
" 192: 2,\n",
" 1328: 1,\n",
" 24: 81,\n",
" 58: 21,\n",
" 29: 65,\n",
" 1194: 1,\n",
" 57: 22,\n",
" 206: 2,\n",
" 353: 1,\n",
" 873: 1,\n",
" 11: 369,\n",
" 122: 6,\n",
" 77: 14,\n",
" 226: 3,\n",
" 30: 64,\n",
" 76: 13,\n",
" 1645: 1,\n",
" 39: 34,\n",
" 1432: 1,\n",
" 445: 1,\n",
" 51: 12,\n",
" 26: 68,\n",
" 743: 1,\n",
" 438: 3,\n",
" 627: 1,\n",
" 1363: 1,\n",
" 1122: 1,\n",
" 197: 3,\n",
" 1624: 1,\n",
" 451: 1,\n",
" 619: 1,\n",
" 640: 1,\n",
" 301: 1,\n",
" 1372: 1,\n",
" 67: 15,\n",
" 180: 3,\n",
" 1366: 1,\n",
" 442: 1,\n",
" 536: 1,\n",
" 661: 1,\n",
" 81: 10,\n",
" 47: 33,\n",
" 22: 100,\n",
" 1219: 1,\n",
" 648: 1,\n",
" 72: 11,\n",
" 123: 2,\n",
" 407: 1,\n",
" 187: 3,\n",
" 339: 1,\n",
" 50: 16,\n",
" 1322: 1,\n",
" 449: 1,\n",
" 152: 6,\n",
" 63: 22,\n",
" 1548: 1,\n",
" 43: 32,\n",
" 901: 1,\n",
" 345: 1,\n",
" 881: 1,\n",
" 479: 1,\n",
" 183: 2,\n",
" 71: 15,\n",
" 364: 4,\n",
" 333: 1,\n",
" 387: 1,\n",
" 589: 1,\n",
" 147: 5,\n",
" 174: 2,\n",
" 355: 2,\n",
" 242: 2,\n",
" 181: 1,\n",
" 909: 1,\n",
" 1027: 1,\n",
" 446: 2,\n",
" 858: 1,\n",
" 2170: 1,\n",
" 492: 2,\n",
" 233: 2,\n",
" 179: 2,\n",
" 89: 7,\n",
" 1547: 1,\n",
" 107: 4,\n",
" 124: 5,\n",
" 139: 3,\n",
" 885: 1,\n",
" 402: 2,\n",
" 101: 7,\n",
" 86: 12,\n",
" 3542: 1,\n",
" 157: 1,\n",
" 245: 1,\n",
" 291: 2,\n",
" 324: 1,\n",
" 150: 5,\n",
" 2046: 1,\n",
" 216: 2,\n",
" 299: 2,\n",
" 3620: 1,\n",
" 621: 1,\n",
" 505: 1,\n",
" 56: 14,\n",
" 400: 1,\n",
" 172: 3,\n",
" 110: 1,\n",
" 325: 3,\n",
" 104: 5,\n",
" 126: 4,\n",
" 177: 3,\n",
" 1268: 1,\n",
" 1204: 1,\n",
" 130: 5,\n",
" 504: 1,\n",
" 3184: 1,\n",
" 134: 2,\n",
" 204: 2,\n",
" 1406: 1,\n",
" 1407: 1,\n",
" 121: 3,\n",
" 285: 2,\n",
" 738: 1,\n",
" 924: 1,\n",
" 283: 1,\n",
" 38: 35,\n",
" 1477: 1,\n",
" 53: 14,\n",
" 362: 1,\n",
" 1282: 1,\n",
" 704: 1,\n",
" 125: 9,\n",
" 155: 2,\n",
" 141: 1,\n",
" 686: 1,\n",
" 85: 13,\n",
" 383: 1,\n",
" 2700: 1,\n",
" 839: 2,\n",
" 570: 1,\n",
" 347: 1,\n",
" 210: 3,\n",
" 82: 6,\n",
" 365: 2,\n",
" 282: 1,\n",
" 357: 2,\n",
" 54: 10,\n",
" 1066: 1,\n",
" 170: 1,\n",
" 189: 2,\n",
" 222: 1,\n",
" 545: 1,\n",
" 317: 1,\n",
" 377: 1,\n",
" 356: 1,\n",
" 778: 1,\n",
" 87: 9,\n",
" 154: 2,\n",
" 635: 1,\n",
" 376: 1,\n",
" 68: 12,\n",
" 136: 3,\n",
" 265: 3,\n",
" 213: 4,\n",
" 253: 2,\n",
" 287: 1,\n",
" 1329: 1,\n",
" 184: 4,\n",
" 108: 8,\n",
" 239: 1,\n",
" 397: 2,\n",
" 74: 5,\n",
" 167: 2,\n",
" 255: 1,\n",
" 487: 1,\n",
" 1595: 1,\n",
" 90: 3,\n",
" 1452: 1,\n",
" 135: 2,\n",
" 1677: 1,\n",
" 432: 1,\n",
" 1709: 1,\n",
" 200: 5,\n",
" 250: 2,\n",
" 525: 2,\n",
" 883: 1,\n",
" 42: 17,\n",
" 188: 3,\n",
" 138: 4,\n",
" 414: 1,\n",
" 169: 3,\n",
" 723: 2,\n",
" 137: 3,\n",
" 1737: 1,\n",
" 127: 5,\n",
" 194: 1,\n",
" 768: 1,\n",
" 304: 1,\n",
" 511: 1,\n",
" 932: 1,\n",
" 163: 3,\n",
" 593: 1,\n",
" 120: 1,\n",
" 481: 1,\n",
" 1042: 1,\n",
" 708: 1,\n",
" 143: 10,\n",
" 145: 3,\n",
" 153: 4,\n",
" 337: 1,\n",
" 363: 1,\n",
" 515: 2,\n",
" 11338: 1,\n",
" 117: 6,\n",
" 430: 1,\n",
" 100: 7,\n",
" 112: 7,\n",
" 343: 1,\n",
" 132: 3,\n",
" 241: 4,\n",
" 1283: 1,\n",
" 844: 1,\n",
" 166: 2,\n",
" 352: 1,\n",
" 472: 1,\n",
" 394: 2,\n",
" 234: 3,\n",
" 220: 1,\n",
" 1467: 1,\n",
" 3582: 1,\n",
" 164: 1,\n",
" 368: 2,\n",
" 171: 2,\n",
" 98: 8,\n",
" 483: 1,\n",
" 1324: 1,\n",
" 96: 5,\n",
" 678: 1,\n",
" 36: 29,\n",
" 263: 1,\n",
" 354: 1,\n",
" 286: 1,\n",
" 142: 2,\n",
" 228: 1,\n",
" 378: 1,\n",
" 278: 2,\n",
" 361: 2,\n",
" 539: 1,\n",
" 159: 1,\n",
" 281: 1,\n",
" 186: 1,\n",
" 208: 3,\n",
" 346: 1,\n",
" 209: 2,\n",
" 2345: 1,\n",
" 195: 1,\n",
" 379: 1,\n",
" 259: 2,\n",
" 160: 1,\n",
" 202: 1,\n",
" 115: 1,\n",
" 277: 2,\n",
" 745: 1,\n",
" 623: 1,\n",
" 319: 1,\n",
" 274: 2,\n",
" 212: 1,\n",
" 232: 1,\n",
" 657: 1,\n",
" 328: 1,\n",
" 348: 2,\n",
" 370: 1,\n",
" 221: 1,\n",
" 331: 1,\n",
" 646: 1,\n",
" 332: 1,\n",
" 225: 1,\n",
" 554: 1,\n",
" 360: 1,\n",
" 271: 1,\n",
" 231: 1,\n",
" 251: 1,\n",
" 244: 2,\n",
" 452: 1,\n",
" 436: 1,\n",
" 637: 1,\n",
" 266: 1,\n",
" 128: 2,\n",
" 165: 1,\n",
" 111: 4,\n",
" 230: 1,\n",
" 302: 1,\n",
" 312: 1,\n",
" 1683: 1,\n",
" 899: 1,\n",
" 193: 1,\n",
" 1568: 1,\n",
" 551: 1,\n",
" 289: 1,\n",
" 660: 1,\n",
" 1774: 1,\n",
" 156: 2,\n",
" 532: 1,\n",
" 429: 1,\n",
" 249: 1,\n",
" 295: 1,\n",
" 606: 2,\n",
" 493: 1,\n",
" 405: 2,\n",
" 409: 1,\n",
" 466: 1,\n",
" 393: 1,\n",
" 528: 1,\n",
" 168: 1,\n",
" 513: 1,\n",
" 151: 1,\n",
" 330: 1,\n",
" 191: 1,\n",
" 229: 1,\n",
" 238: 1,\n",
" 95: 1,\n",
" 175: 1})"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Counter(tweet_freq.values())"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('i@2', 1335),\n",
" ('i@1', 1331),\n",
" ('i@-2', 1296),\n",
" ('the@-1', 1087),\n",
" ('and@1', 978)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# As opposed to the above, this computes the number of unique terms that this feature\n",
"# appears in. Q: How do you expect this to affect the output?\n",
"tweet_freq_2 = Counter()\n",
"for context in contexts.values():\n",
"    tweet_freq_2.update(context.keys())\n",
"tweet_freq_2.most_common(5)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('m@1', 0.4992255422401076),\n",
" ('rt@-1', 0.4843410432712745),\n",
" ('t@2', 0.24288406965963458),\n",
" ('love@1', 0.14382229255205484),\n",
" ('rt@-2', 0.1405945803985488)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Transform each context vector to be term freq / (1 + log of feature frequency).\n",
"# Then normalize each vector to unit length.\n",
"for term, context in contexts.items():\n",
"    for term2, frequency in context.items():\n",
"        # tf / [ 1 + log(df) ]\n",
"        context[term2] = frequency / (1. + math.log(tweet_freq[term2]))\n",
"    length = math.sqrt(sum([v*v for v in context.values()]))\n",
"    for term2, frequency in context.items():\n",
"        context[term2] = 1. * frequency / length\n",
"\n",
"contexts['i'].most_common(5)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('worthless@-2', 0.33764514643816557),\n",
" ('high@-1', 0.2816246484908819),\n",
" ('holding@2', 0.2323291846302529),\n",
" ('at@-1', 0.2204425813331341),\n",
" ('to@-1', 0.20596171539422523),\n",
" ('son@-2', 0.18688995834710687),\n",
" ('and@1', 0.18452879080306742),\n",
" ('her@2', 0.16723062126872543),\n",
" ('in@-2', 0.16473427852389158),\n",
" ('uneducated@1', 0.15939656236107255)]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"contexts['school'].most_common(10)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('i@-1', 0.8517303851828948),\n",
" ('you@1', 0.341304167014421),\n",
" ('rt@-2', 0.14556771763853768),\n",
" ('so@2', 0.11603489382180723),\n",
" ('i@-2', 0.11090169450417534),\n",
" ('this@1', 0.09080734081934172),\n",
" ('u@1', 0.08118539291674556),\n",
" ('it@1', 0.080832127605461),\n",
" ('in@-1', 0.07776084336903491),\n",
" ('i@2', 0.06941393383495353)]"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"contexts['love'].most_common(10)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('i@-1', 0.6736723361573105),\n",
" ('grocery@1', 0.3146058368934483),\n",
" ('store@2', 0.2657744340570913),\n",
" ('fucking@-1', 0.23378948537609445),\n",
" ('pressure@2', 0.23209310502084446),\n",
" ('the@1', 0.2063211337802833),\n",
" ('i@-2', 0.18442359838340952),\n",
" ('rt@-2', 0.17387587244484196),\n",
" ('amp@-2', 0.141357574805482),\n",
" ('me@1', 0.07726901404021158)]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"contexts['hate'].most_common(10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"At this point we have a dictionary mapping each term to its context vector: the features that co-occur with it, down-weighted by how common each feature is and normalized to unit length.\n",
"\n",
"Next, we have to cluster these vectors. To do this, we'll need to be able to compute the Euclidean distance between two vectors."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.4142135623730951\n",
"2.23606797749979\n"
]
}
],
"source": [
"# n.b. This is not efficient!\n",
"def distance(c1, c2):\n",
"    \"\"\"Euclidean distance between two sparse vectors stored as dicts.\"\"\"\n",
"    if len(c1.keys()) == 0 or len(c2.keys()) == 0:\n",
"        return 1e9\n",
"    keys = set(c1.keys()) | set(c2.keys())\n",
"    total = 0.\n",
"    for k in keys:\n",
"        # Treat a missing key as 0 so plain dicts work as well as Counters.\n",
"        total += (c1.get(k, 0.) - c2.get(k, 0.)) ** 2\n",
"    return math.sqrt(total)\n",
"\n",
"print(distance({'hi':10, 'bye': 5}, {'hi': 9, 'bye': 4}))\n",
"print(distance({'hi':10, 'bye': 5}, {'hi': 8, 'bye': 4}))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['love', 'hope', 'miss', 'm', 'am', 'guess', 'll', 'mean', 'cant',\n",
" 'wish'], dtype=' 1]\n",
"contexts = dict([(term, contexts[term]) for term in nz_contexts])\n",
"print(len(nz_contexts), 'nonzero contexts')"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"im\n",
"[('it@1', 0.01264806377201966), ('maybe@2', 0.025412211140884603), ('y@-2', 0.01883753470306806)]\n"
]
}
],
"source": [
"# e.g., what are three context features for the first term (here, \"im\")?\n",
"print(list(contexts.keys())[0])\n",
"print(list(list(contexts.values())[0].items())[:3])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a@-1' 'a@-2' 'a@1' 'a@2' 'aaaaaa@-1' 'aaaaaa@-2' 'aaaaaa@1' 'aaaaaa@2'\n",
" 'aaron@-1' 'aaron@-2']\n",
" (0, 1)\t0.012176616904934769\n",
" (0, 2)\t0.10774881812308722\n",
" (0, 3)\t0.08406853449098409\n",
" (0, 234)\t0.02946119105248044\n",
" (0, 265)\t0.024956944731107572\n",
" (0, 332)\t0.02065859263633043\n",
" (0, 368)\t0.014543761318494223\n",
" (0, 369)\t0.014329995051336608\n",
" (0, 393)\t0.02368590894509358\n",
" (0, 413)\t0.01978485270160371\n",
" (0, 434)\t0.018043994507438255\n",
" (0, 436)\t0.015269990983388075\n",
" (0, 475)\t0.015840112090628022\n",
" (0, 480)\t0.01604913069072621\n",
" (0, 483)\t0.036733833080258765\n",
" (0, 485)\t0.02443477905749157\n",
" (0, 486)\t0.04897844410701169\n",
" (0, 611)\t0.04022372022569253\n",
" (0, 641)\t0.03840265052217021\n",
" (0, 880)\t0.05818855015204386\n",
" (0, 881)\t0.02916261855721708\n",
" (0, 1053)\t0.01678036618972246\n",
" (0, 1063)\t0.020161956350368875\n",
" (0, 1206)\t0.0446093161177403\n",
" (0, 1209)\t0.022203132192976167\n",
" :\t:\n",
" (0, 14849)\t0.04492780311830858\n",
" (0, 14975)\t0.4561416854630144\n",
" (0, 15102)\t0.019454490377035092\n",
" (0, 15137)\t0.03246274061678437\n",
" (0, 15216)\t0.018578551803747474\n",
" (0, 15217)\t0.018633750350250212\n",
" (0, 15224)\t0.01978485270160371\n",
" (0, 15490)\t0.014819584801508275\n",
" (0, 15491)\t0.015207803801313627\n",
" (0, 15502)\t0.3141806233384448\n",
" (0, 15532)\t0.03659194385094082\n",
" (0, 15671)\t0.013794628438236247\n",
" (0, 15735)\t0.04108584562516659\n",
" (0, 15836)\t0.024359499004844704\n",
" (0, 15870)\t0.025924583046390885\n",
" (0, 15873)\t0.024548591946210438\n",
" (0, 15882)\t0.022409430110929447\n",
" (0, 15893)\t0.01883753470306806\n",
" (0, 15897)\t0.02417935783176741\n",
" (0, 15944)\t0.01820112202264155\n",
" (0, 15971)\t0.051086598423935724\n",
" (0, 16003)\t0.012327178077380431\n",
" (0, 16004)\t0.0497998532898863\n",
" (0, 16006)\t0.03695460482404303\n",
" (0, 16059)\t0.03639315451392994\n"
]
}
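],
"source": [
"# Transform context dicts to a sparse matrix\n",
"# for sklearn (one row per term, one column per context feature).\n",
"from sklearn.feature_extraction import DictVectorizer\n",
"\n",
"vec = DictVectorizer()\n",
"X = vec.fit_transform(contexts.values())\n",
"# Note: newer scikit-learn releases (1.0+) provide get_feature_names_out()\n",
"# in place of get_feature_names().\n",
"names = np.array(vec.get_feature_names())\n",
"print(names[:10])\n",
"print(X[0])"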
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"56\n",
" (0, 0)\t0.01166315589869699\n",
" (0, 1)\t0.009432626921748103\n",
" (0, 2)\t0.0069556430065184915\n",
" (0, 3)\t0.020932638324993653\n",
" (0, 12)\t0.009893569523380101\n",
" (0, 28)\t0.0029242019578079237\n",
" (0, 29)\t0.005935227260208678\n",
" (0, 31)\t0.005850253394142066\n",
" (0, 40)\t0.00502062522553737\n",
" (0, 134)\t0.01569827105681456\n",
" (0, 136)\t0.00393482667128263\n",
" (0, 181)\t0.008700840482568875\n",
" (0, 219)\t0.009893569523380101\n",
" (0, 233)\t0.0058241406815781925\n",
" (0, 241)\t0.003442994829099182\n",
" (0, 251)\t0.003660037555588828\n",
" (0, 278)\t0.0057055343458282184\n",
" (0, 334)\t0.0038489132041543044\n",
" (0, 367)\t0.002790339436949068\n",
" (0, 368)\t0.008449753057068984\n",
" (0, 369)\t0.002775185795961421\n",
" (0, 370)\t0.019557649156184565\n",
" (0, 402)\t0.005416543436788495\n",
" (0, 410)\t0.009893569523380101\n",
" (0, 412)\t0.003966821569702731\n",
" :\t:\n",
" (0, 15787)\t0.004102954333698908\n",
" (0, 15828)\t0.03718458281719579\n",
" (0, 15829)\t0.0031113083765471677\n",
" (0, 15831)\t0.006289728235394356\n",
" (0, 15847)\t0.022014344111415302\n",
" (0, 15857)\t0.005957911837271409\n",
" (0, 15863)\t0.004299918936961106\n",
" (0, 15871)\t0.00502062522553737\n",
" (0, 15885)\t0.004193444669798336\n",
" (0, 15894)\t0.007005863991464066\n",
" (0, 15895)\t0.003619603953271363\n",
" (0, 15946)\t0.003511581004101353\n",
" (0, 15969)\t0.020762766580557777\n",
" (0, 15971)\t0.009893569523380101\n",
" (0, 15990)\t0.004428087090688781\n",
" (0, 16003)\t0.01193657408187391\n",
" (0, 16004)\t0.012055468703578194\n",
" (0, 16005)\t0.341304167014421\n",
" (0, 16006)\t0.035783646133738856\n",
" (0, 16014)\t0.004876067603250058\n",
" (0, 16019)\t0.011392909025650088\n",
" (0, 16020)\t0.002893415638699762\n",
" (0, 16021)\t0.031212968706279383\n",
" (0, 16022)\t0.0028439084729393136\n",
" (0, 16060)\t0.007437161692980253\n",
"while@1\n"
]
}
],
"source": [
"# Which row of X is the word \"love\"?\n",
"love_idx = list(contexts.keys()).index('love')\n",
"print(love_idx)\n",
"# What are the context feature values for love?\n",
"print(X[love_idx])\n",
"# Print the name of one feature by column index. The index is hard-coded,\n",
"# so it is not guaranteed to be a high-ranking feature for \"love\" on a new run.\n",
"print(names[15534])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"KMeans(algorithm='auto', copy_x=True, init='k-means++', max_iter=300,\n",
" n_clusters=20, n_init=10, n_jobs=None, precompute_distances='auto',\n",
" random_state=None, tol=0.0001, verbose=0)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Let's cluster!\n",
"# http://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html\n",
"from sklearn.cluster import KMeans\n",
"num_clusters = 20\n",
"kmeans = KMeans(num_clusters)\n",
"kmeans.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 to@1 i@-2 i@-1 be@2 the@2\n",
"1 a@-1 i@1 i@2 a@-2 the@-1\n",
"2 the@1 up@1 i@-2 i@2 my@1\n",
"3 my@-1 and@1 i@2 the@-1 i@1\n",
"4 i@-2 m@-1 i@1 m@-2 for@1\n",
"5 of@1 the@-1 the@2 a@-1 a@-2\n",
"6 it@-1 so@-2 much@-1 many@-1 i@1\n",
"7 rt@-1 i@2 i@1 and@1 is@1\n",
"8 in@1 the@2 i@-2 my@2 the@-1\n",
"9 i@-1 rt@-2 i@-2 you@1 it@1\n",
"10 the@2 but@1 i@2 on@1 be@-1\n",
"11 i@1 m@2 rt@-1 i@2 i@-2\n",
"12 me@1 and@-1 a@1 i@-2 i@2\n",
"13 t@1 i@-1 you@-1 be@2 rt@-2\n",
"14 with@1 my@2 i@-2 the@2 to@-1\n",
"15 i@2 i@1 and@1 the@-2 a@-2\n",
"16 to@-1 i@-2 i@-1 t@-1 i@2\n",
"17 the@-1 in@-2 i@1 i@2 of@-2\n",
"18 i@-2 i@2 just@-1 was@-1 and@2\n",
"19 of@-1 and@1 i@1 the@-2 the@-1\n"
]
}
],
"source": [
"# Let's print out the top features for each mean vector.\n",
"# This is swamped by common terms.\n",
"for i in range(num_clusters):\n",
"    print(i, ' '.join(names[np.argsort(\n",
"        kmeans.cluster_centers_[i])[::-1][:5]]))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"distance from term \"love\" to each cluster:\n",
"[1.00874162 1.0266301 0.97164003 1.03189144 1.03358099 1.04555106\n",
" 0.9921334 1.0001127 0.99631078 0.5906229 0.98516281 1.01236052\n",
" 0.96816683 1.1055978 0.99184346 0.98134482 0.95310019 1.02250443\n",
" 0.95707704 1.00474645]\n",
"closest cluster to \"love\":\n",
"9\n"
]
}
],
"source": [
"# .transform computes the distance from each context vector to each cluster center.\n",
"distances = kmeans.transform(X)\n",
"# e.g., what is the distance from the word \"love\" to each cluster?\n",
"print('distance from term \"love\" to each cluster:')\n",
"print(distances[love_idx])\n",
"# what is the closest cluster for the word \"love\"?\n",
"print('closest cluster to \"love\":')\n",
"print(np.argmin(distances[love_idx]))"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 supposed listening going wanted used decided listen able needs \n",
"\n",
"1 good few bit wrap couple dream virgin requirement little \n",
"\n",
"2 for on of at with all is during from \n",
"\n",
"3 favorite face hair life heart boyfriend wife brother head \n",
"\n",
"4 not screaming sorry dying doin scared gonna assuming sobbing \n",
"\n",
"5 part rest one majority type habit tired instead benefit \n",
"\n",
"6 everyday aaaaaa whew and but yard sounds things seemed \n",
"\n",
"7 hello hi okay i mingyu nah psa mutual age \n",
"\n",
"8 participate doctor punched zoomed manga engage lived live jump \n",
"\n",
"9 am love hope just ll have guess mean miss \n",
"\n",
"10 shade one me pumped this and okay that here \n",
"\n",
"11 but when because and omg what before where ok \n",
"\n",
"12 with tell making call is help on in make \n",
"\n",
"13 didn wasn ain wouldn couldn haven doesn isn can \n",
"\n",
"14 ruv communicate agree familiar wrong flights along interact dealing \n",
"\n",
"15 that me time but the day this my is \n",
"\n",
"16 be see make do say save hear go cry \n",
"\n",
"17 way best world weekends first saddest waist bathroom bus \n",
"\n",
"18 a never like you sorry u done to so \n",
"\n",
"19 us mone nowhere the concerning them luck people prof \n",
"\n"
]
}
],
"source": [
"# Finally, we'll print the words that are closest\n",
"# to the mean of each cluster.\n",
"# Note: the slice [1:10] skips the single closest word; use [:10] to include it.\n",
"terms = np.array(list(contexts.keys()))\n",
"for i in range(distances.shape[1]):\n",
"    print(i, ' '.join(terms[np.argsort(distances[:,i])[1:10]]), '\\n')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Clearly, interpreting these results requires a bit of investigation.\n",
"\n",
"As the number of tweets increases, we expect these clusters to become more coherent."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**How does error decrease with number of cluster?**"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-3727.0865918080995"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"kmeans.score(X)"
]
},
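{
"cell_type": "markdown",
"metadata": {},
"source": [
"The value returned by `kmeans.score(X)` is negative because scikit-learn reports the *negative* of the k-means objective, which is the sum of squared distances from each point to its nearest cluster center. The following is a minimal NumPy sketch of that objective on toy data; the helper name `kmeans_objective` and the toy arrays are illustrative assumptions, not part of this notebook.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def kmeans_objective(X, centers):\n",
"    # Squared Euclidean distance from every point to every center,\n",
"    # shape (n_points, n_centers).\n",
"    diffs = X[:, None, :] - centers[None, :, :]\n",
"    sq_dists = (diffs ** 2).sum(axis=2)\n",
"    # Each point contributes the squared distance to its nearest center.\n",
"    return sq_dists.min(axis=1).sum()\n",
"\n",
"# Toy data: two points near the first center, one on the second.\n",
"X_toy = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0]])\n",
"centers_toy = np.array([[0.0, 1.0], [10.0, 0.0]])\n",
"print(kmeans_objective(X_toy, centers_toy))  # 1 + 1 + 0 = 2.0\n",
"```\n",
"\n",
"Negating `kmeans.score(X)`, as the next cell does, turns the score into an error that shrinks as the clusters fit the data more tightly."
]
},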
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"k=5 score=3860.63\n",
"k=10 score=3786.09\n",
"k=20 score=3722.77\n",
"k=50 score=3638.46\n",
"k=100 score=3553.49\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"scores = []\n",
"num_cluster_options = [5,10,20,50,100]\n",
"\n",
"for num_clusters in num_cluster_options:\n",
" kmeans = KMeans(num_clusters, n_init=10, max_iter=10)\n",
" kmeans.fit(X)\n",
" score = -1 * kmeans.score(X)\n",
" scores.append(score)\n",
" print('k=%d score=%g' % (num_clusters, score))\n",
" \n",
"plt.figure()\n",
"plt.plot(num_cluster_options, scores, 'bo-')\n",
"plt.xlabel('num clusters')\n",
"plt.ylabel('error')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"** How does error vary by initalization? **"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"score=3758.22\n",
"score=3747.06\n",
"score=3741.92\n",
"score=3749.15\n",
"score=3741.44\n",
"score=3752.9\n",
"score=3739.67\n",
"score=3748.15\n",
"score=3744.9\n",
"score=3741.4\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"scores = []\n",
"for i in range(10):\n",
" kmeans = KMeans(20, n_init=1,\n",
" max_iter=10,\n",
" init='random')\n",
" kmeans.fit(X)\n",
" score = -1 * kmeans.score(X)\n",
" scores.append(score)\n",
" print('score=%g' % (score))\n",
" \n",
" \n",
"plt.figure()\n",
"plt.plot(range(10), sorted(scores), 'bo-')\n",
"plt.xlabel('sample')\n",
"plt.ylabel('error')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We now have a way to represent words in 20-dimensional space.\n",
"- The distance from each word vector to the means of each of the 20 clusters."
]
},
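{
"cell_type": "markdown",
"metadata": {},
"source": [
"Assuming `distances` was produced by something like scikit-learn's `KMeans.transform`, which maps each row to its Euclidean distances from the cluster centers, the representation can be sketched in NumPy as follows. The helper `distances_to_centers` and the toy points are hypothetical and only illustrate the idea.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def distances_to_centers(X, centers):\n",
"    # Broadcast to shape (n_points, n_centers, n_features), then take\n",
"    # the Euclidean norm over the feature axis.\n",
"    return np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)\n",
"\n",
"# Two toy points and two toy centers in 2-D.\n",
"points = np.array([[0.0, 0.0], [3.0, 4.0]])\n",
"centers = np.array([[0.0, 0.0], [3.0, 0.0]])\n",
"print(distances_to_centers(points, centers))\n",
"```\n",
"\n",
"Each row of the result is one point's new representation, with one entry per cluster center; with 20 centers, every word gets a 20-dimensional vector."
]
},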
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1.00874162 1.0266301 0.97164003 1.03189144 1.03358099 1.04555106\n",
" 0.9921334 1.0001127 0.99631078 0.5906229 0.98516281 1.01236052\n",
" 0.96816683 1.1055978 0.99184346 0.98134482 0.95310019 1.02250443\n",
" 0.95707704 1.00474645]\n",
"[0.94998976 0.99347211 0.92797163 1.00760459 0.87369431 1.02867307\n",
" 0.97299124 0.97489501 0.97819442 0.89745615 0.95924255 0.90359807\n",
" 0.93594526 1.2176145 0.98537882 0.96513233 0.93622821 0.9995155\n",
" 0.85784706 0.98290799]\n",
"[1.01027941 1.0360027 0.94782441 1.04189133 1.02496428 1.05721972\n",
" 1.00490278 1.00783092 1.00071431 0.71358098 0.99042856 1.02849363\n",
" 0.97045892 1.14231512 1.01697286 0.98572949 0.96373614 1.03701544\n",
" 0.95307083 1.00673499]\n",
"[1.0771791 1.0382791 1.01844359 1.04147931 1.09082894 1.06598877\n",
" 1.00127222 1.01626417 1.02439718 1.09741815 0.99488921 1.00199894\n",
" 0.99562545 1.28351384 1.04215818 0.98982216 1.04985219 1.04174968\n",
" 1.01364399 1.01655176]\n"
]
}
],
"source": [
"def get_distances(word, contexts, distances):\n",
" wd_idx = list(contexts.keys()).index(word)\n",
" return distances[wd_idx]\n",
"\n",
"print(get_distances('love', contexts, distances))\n",
"print(get_distances('like', contexts, distances))\n",
"print(get_distances('hate', contexts, distances))\n",
"print(get_distances('pizza', contexts, distances))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use these vectors to compute how similar two words are."
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9958552518939588\n"
]
}
],
"source": [
"from math import sqrt\n",
"def sim(v1, v2):\n",
" \"\"\" cosine similarity of two vectors. \"\"\"\n",
" return np.dot(v1, v2) / (sqrt(np.dot(v1, v1)) * sqrt(np.dot(v2,v2)))\n",
" \n",
"# FIXME: sqrt norm\n",
"print(sim(get_distances('love', contexts, distances),\n",
" get_distances('like', contexts, distances)))"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9995853218896937\n"
]
}
],
"source": [
"print(sim(get_distances('love', contexts, distances),\n",
" get_distances('hate', contexts, distances)))"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.993983274406897\n"
]
}
],
"source": [
"print(sim(get_distances('love', contexts, distances),\n",
" get_distances('pizza', contexts, distances)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, `love` is more similar to `like` than to `pizza`.\n",
"\n",
"\n",
"
\n",
"**However**, this approach treats each word the same when computing similarity. \n",
"\n",
"- Presumably, some context words are more important than others (e.g., `the` versus `hippopotamus`). \n",
"- `tf-idf` captures this to some extent.\n",
"- Can we use machine learning to weight features based on how predictive they are?\n",
" - But what is the classification task?\n",
" \n",
"- See `Word2Vec.ipynb`\n",
"\n"
]
},
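{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three cosine scores above all fall between 0.99 and 1. This is expected when every vector has positive entries of similar size, as these distance vectors do. The sketch below uses made-up toy vectors (not data from this notebook) to illustrate the effect and one possible adjustment, subtracting the mean vector before comparing. Both the helper `cosine` and the centering step are assumptions for illustration, not something this notebook does.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def cosine(v1, v2):\n",
"    # Cosine similarity: dot product divided by the product of the norms.\n",
"    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))\n",
"\n",
"# Toy vectors with all-positive entries of similar size.\n",
"a = np.array([1.0, 1.0, 0.5])\n",
"b = np.array([1.0, 0.5, 1.0])\n",
"c = np.array([0.7, 1.0, 1.0])\n",
"\n",
"raw = cosine(a, b)\n",
"\n",
"# Subtract the mean of the toy vectors so that the shared\n",
"# positive offset no longer dominates the comparison.\n",
"mean = (a + b + c) / 3\n",
"centered = cosine(a - mean, b - mean)\n",
"\n",
"print('raw=%.3f centered=%.3f' % (raw, centered))\n",
"```\n",
"\n",
"On these toy vectors the raw score is high even though `a` and `b` differ in opposite coordinates, while the centered score is negative."
]
},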
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"code_folding": [
0
]
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#\n",
"from IPython.core.display import HTML\n",
"HTML(open('../custom.css').read())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 1
}