{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 9. Recursive Neural Networks and Constituency Parsing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "I recommend you take a look at these material first." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* http://web.stanford.edu/class/cs224n/lectures/cs224n-2017-lecture14-TreeRNNs.pdf\n", "* https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.autograd import Variable\n", "import torch.optim as optim\n", "import torch.nn.functional as F\n", "import nltk\n", "import random\n", "import numpy as np\n", "from collections import Counter, OrderedDict\n", "import nltk\n", "from copy import deepcopy\n", "import os\n", "from IPython.display import Image, display\n", "from nltk.draw import TreeWidget\n", "from nltk.draw.util import CanvasFrame\n", "from nltk.tree import Tree as nltkTree\n", "flatten = lambda l: [item for sublist in l for item in sublist]\n", "random.seed(1024)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "USE_CUDA = torch.cuda.is_available()\n", "gpus = [0]\n", "torch.cuda.set_device(gpus[0])\n", "\n", "FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor\n", "LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor\n", "ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def getBatch(batch_size, train_data):\n", " random.shuffle(train_data)\n", " sindex = 0\n", " eindex = batch_size\n", " while eindex < len(train_data):\n", " batch = train_data[sindex: eindex]\n", " temp = eindex\n", " eindex = eindex + batch_size\n", " sindex = temp\n", " yield batch\n", " \n", " if eindex >= len(train_data):\n", " batch = train_data[sindex:]\n", " yield batch" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Borrowed from https://stackoverflow.com/questions/31779707/how-do-you-make-nltk-draw-trees-that-are-inline-in-ipython-jupyter\n", "\n", "def draw_nltk_tree(tree):\n", " cf = CanvasFrame()\n", " tc = TreeWidget(cf.canvas(), tree)\n", " tc['node_font'] = 'arial 15 bold'\n", " tc['leaf_font'] = 'arial 15'\n", " tc['node_color'] = '#005990'\n", " tc['leaf_color'] = '#3F8F57'\n", " tc['line_color'] = '#175252'\n", " cf.add_widget(tc, 50, 50)\n", " cf.print_to_file('tmp_tree_output.ps')\n", " cf.destroy()\n", " os.system('convert tmp_tree_output.ps tmp_tree_output.png')\n", " display(Image(filename='tmp_tree_output.png'))\n", " os.system('rm tmp_tree_output.ps tmp_tree_output.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data load and Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Stanford Sentiment Treebank(https://nlp.stanford.edu/sentiment/index.html)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3 (2 (1 Deflated) (2 (2 ending) (2 aside))) (4 (2 ,) (4 (2 there) (3 (3 (2 's) (3 (2 much) (2 (2 to) (3 (3 recommend) (2 (2 the) (2 film)))))) (2 .)))))\n", "\n" ] } ], "source": [ "sample = random.choice(open('../dataset/trees/train.txt', 'r', encoding='utf-8').readlines())\n", 
"print(sample)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhEAAAGVCAMAAAB3iVNzAAAJJGlDQ1BpY2MAAHjalZVnUJNZF8fv\n8zzphUASQodQQ5EqJYCUEFoo0quoQOidUEVsiLgCK4qINEUQUUDBVSmyVkSxsCgoYkE3yCKgrBtX\nERWUF/Sd0Xnf2Q/7n7n3/OY/Z+4995wPFwCCOFgSvLQnJqULvJ3smIFBwUzwg8L4aSkcT0838I96\nPwyg5XhvBfj3IkREpvGX4sLSyuWnCNIBgLKXWDMrPWWZDy8xPTz+K59dZsFSgUt8Y5mjv/Ho15xv\nLPqa4+vNXXoVCgAcKfoHDv+B/3vvslQ4gvTYqMhspk9yVHpWmCCSmbbcCR6Xy/QUJEfFJkT+UPC/\nSv4HpUdmpy9HbnLKBkFsdEw68/8ONTIwNATfZ/HW62uPIUb//85nWd+95HoA2LMAIHu+e+GVAHTu\nAED68XdPbamvlHwAOu7wMwSZ3zzU8oYGBEABdCADFIEq0AS6wAiYAUtgCxyAC/AAviAIrAN8EAMS\ngQBkgVywDRSAIrAH7AdVoBY0gCbQCk6DTnAeXAHXwW1wFwyDJ0AIJsArIALvwTwEQViIDNEgGUgJ\nUod0ICOIDVlDDpAb5A0FQaFQNJQEZUC50HaoCCqFqqA6qAn6BToHXYFuQoPQI2gMmob+hj7BCEyC\n6bACrAHrw2yYA7vCvvBaOBpOhXPgfHg3XAHXwyfgDvgKfBsehoXwK3gWAQgRYSDKiC7CRriIBxKM\nRCECZDNSiJQj9Ugr0o30IfcQITKDfERhUDQUE6WLskQ5o/xQfFQqajOqGFWFOo7qQPWi7qHGUCLU\nFzQZLY/WQVugeehAdDQ6C12ALkc3otvR19DD6An0ewwGw8CwMGYYZ0wQJg6zEVOMOYhpw1zGDGLG\nMbNYLFYGq4O1wnpgw7Dp2AJsJfYE9hJ2CDuB/YAj4pRwRjhHXDAuCZeHK8c14y7ihnCTuHm8OF4d\nb4H3wEfgN+BL8A34bvwd/AR+niBBYBGsCL6EOMI2QgWhlXCNMEp4SyQSVYjmRC9iLHErsYJ4iniD\nOEb8SKKStElcUggpg7SbdIx0mfSI9JZMJmuQbcnB5HTybnIT+Sr5GfmDGE1MT4wnFiG2RaxarENs\nSOw1BU9Rp3Ao6yg5lHLKGcodyow4XlxDnCseJr5ZvFr8nPiI+KwETcJQwkMiUaJYolnipsQUFUvV\noDpQI6j51CPUq9RxGkJTpXFpfNp2WgPtGm2CjqGz6Dx6HL2IfpI+QBdJUiWNJf0lsyWrJS9IChkI\nQ4PBYyQwShinGQ8Yn6QUpDhSkVK7pFqlhqTmpOWkbaUjpQul26SHpT/JMGUcZOJl9sp0yjyVRclq\ny3rJZskekr0mOyNHl7OU48sVyp2WeywPy2vLe8tvlD8i3y8/q6Co4KSQolCpcFVhRpGhaKsYp1im\neFFxWommZK0Uq1SmdEnpJVOSyWEmMCuYvUyRsryys3KGcp3ygPK8CkvFTyVPpU3lqSpBla0apVqm\n2qMqUlNSc1fLVWtRe6yOV2erx6gfUO9Tn9NgaQRo7NTo1JhiSbN4rBxWC2tUk6xpo5mqWa95Xwuj\nxdaK1zqodVcb1jbRjtGu1r6jA+uY6sTqHNQZXIFeYb4iaUX9ihFdki5HN1O3RXdMj6Hnppen16n3\nWl9NP1h/r36f/hcDE4MEgwaDJ4ZUQxfDPMNuw7+NtI34RtVG91eSVzqu3LKya+UbYx3jSONDxg9N\naCbuJjtNekw+m5qZCkxbTafN1MxCzWrMRth0tie7mH3DHG1uZ77F/Lz5RwtTi3SL0xZ/Wepaxls2\nW06tYq2KXNWwatxKxSrMqs5KaM20DrU+bC20UbYJs6m3eW6rahth22g7ydHixHFOcF7bGdgJ7Nrt\n5rgW3E3cy/aIvZN9of2AA9XBz6HK4ZmjimO0Y4ujyMnEaaPTZWe0s6vzXucRngKPz2viiVzMXDa5\n9LqSXH1cq1yfu2m7Cdy63WF3F/d97qOr1Vcnre70AB48j30eTz1Znqmev3phvDy9qr1eeBt653r3\n+dB81vs0+7z3tfMt8X3ip+mX4dfjT/EP8W/ynwuwDygNEAbqB24KvB0kGxQb1BWMDfYPbgyeXeOw\nZv+aiRCTkIKQB2tZa7PX3lwnuy5h3YX1lPVh68+EokMDQptDF8I8wurDZsN54TXhIj6Xf4D/KsI2\noixiOtIqsjRyMsoqqjRqKtoqel/0dIxNTHnMTCw3tir2TZxzXG3cXLxH/LH4xYSAhLZEXGJo4rkk\nalJ8Um+yYnJ28mCKTkpBijDVInV/qkjgKmhMg9LWpnWl05c+xf4MzYwdGWOZ1pnVmR+y/LPOZEtk\nJ2X3b9DesGvDZI5jztGNqI38jT25yrnbcsc2cTbVbYY2h2/u2aK6JX/LxFanrce3EbbFb/stzyCv\nNO/d9oDt3fkK+Vvzx3c47WgpECsQFIzstNxZ+xPqp9ifBnat3FW560thROGtIoOi8qKFYn7xrZ8N\nf674eXF31O6BEtOSQ3swe5L2PNhrs/d4qURpTun4Pvd9HWXMssKyd/vX779Zblxee4BwIOOAsMKt\noqtSrXJP5UJVTNVwtV11W418za6auYMRB4cO2R5qrVWoLar9dDj28MM6p7qOeo368iOYI5lHXjT4\nN/QdZR9tapRtLGr8fCzpmPC49/HeJrOmpmb55pIWuCWjZfpEyIm7J+1PdrXqtta1MdqKToFTGade\n/hL6y4PTrqd7zrDPtJ5VP1vTTmsv7IA6NnSIOmM6hV1BXYPnXM71dFt2t/+q9+ux88rnqy9IXii5\nSLiYf3HxUs6l2cspl2euRF8Z71nf8+Rq4NX7vV69A9dcr9247nj9ah+n79INqxvnb1rcPHeLfavz\ntuntjn6T/vbfTH5rHzAd6Lhjdqfrrvnd7sFVgxeHbIau3LO/d/0+7/7t4dXDgw/8HjwcCRkRPox4\nOPUo4dGbx5mP559sHUWPFj4Vf1r+TP5Z/e9av7cJTYUXxuzH+p/7PH8yzh9/9UfaHwsT+S/IL8on\nlSabpoymzk87Tt99ueblxKuUV/MzBX9K/FnzWvP12b9s/+oXBYom3gjeLP5d/Fbm7bF3xu96Zj1n\nn71PfD8/V/hB5sPxj+yPfZ8CPk3OZy1gFyo+a33u/uL6ZXQxcXHxPy6ikLxyKdSVAAAAIGNIUk0A\nAHomAACAhAAA+gAAAIDoAAB1MAAA6mAAADqYAAAXcJy6UTwAAACrUExURf///wBZkABZkABZkABZ\nkABZkABZkABZkABZkABZkABZkABZkBdSUhdSUhdTUxdSUhdSUhdSUhdSUhd
SUhdSUhdSUhdSUhdS\nUhdSUhdSUhdSUhdSUgBZkABZkBdSUgBZkBdSUj+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+PVz+P\nVz+PVz+PVz+PVz+PVxdSUhdSUhdSUhdTUxdSUhhUVABZkBdSUj+PV////2WyAXMAAAA1dFJOUwAR\niO6ZuyJVZjPdqkR3dZnMu6OIZjMRqt3uVSJEzMd31rsziN2ZIkQRqlXuZsx3r+HSW+wg7fJpnQAA\nAAFiS0dEAIgFHUgAAAAJcEhZcwAAAEgAAABIAEbJaz4AAAAHdElNRQfhCwINNSfD9n+TAAASs0lE\nQVR42u2dCZuiuBqFU9pVXVWtrQ6uZffMZVFwq5m7SP7/P7skiFuxhiAJnPd5ulAgIZpD8sVODoQA\nAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA\nAAAAAAAAAAAAAAAAAAAAaDZPne6x++2p7mIAZfh27D53j891FwMoQ6fzQp6Ox7qLAZTiO9oIcEUn\niCNe6i4EUIjX5273te5CAKV4O3brLgJQhvfjE3lBZAnOPPPR54+6iwGU4ekbfqECAAAAgCi9/s9+\nr+5CADXo9QdD3x/9MfL94QCyaDcnMRjjCX87GRuQRWu5E8MFyKJ1THvGLF4MF06ymBm9ad3FBVXC\nxDD3FyPjY5Lj7MmHMVr4c8iimVzEsCyUbglZNA9RMVyALBpDWJVlxFBNXqAGqrmvy7c3oAaqbuQh\nC4143Pig2LgF1EAdvyHk+W0D1EDNvzMm/v4JakCZH50hCxUYqyGGCydZjOsuR2vpKSQG1UsFAABA\nUZ5+/aq7CK1EXReQH1geVgcdZV1Afh2hiDp47bwp6gLy3lWyWK3g5fhedxG+0jl+hyJq4un3UT3P\nh5fuDwJF1MPbu4KCID+6L1BEPbx2uwoO8t6Ovzud47GjoFabztvxXUXjKDbQYKg4CGo478f354C3\nussRB3qNOjjdiwp2HFAEAAAAoAG9n0pOTZmOf46x9uvxLI354s/F3FBrycT0wxj68z/n/tD4gCoe\nyHQ88kfBjRht1aBnjPzFbMw0uhzPFv7IULINayC9wVXbwNqKQf3f/KQ/8/1Zf5K+C1RAjAJuFFJH\nkZIbhKtmA1RBUi9RX+8xHQ8ygoYwtBgo07c1idS2oIbeI39dZ+sGFCZHjT+09yjcHyDYlEneXuFB\nvYdwzIhgUw6F7v2qe4/StzqCzZII1HBlvYescADBpjCivUAFvYfsWkSwKUCpe11q71FRS49gswgS\nalRO71FxNIhgMxeyWv2y+TzoJkawmYUhMTJkbY0hWIwHdvRhmCJY0ObTk3tTimbXe/BgYDpGRAEA\nAIVpryFHez95Kr/e5S53EHQeqcOwBFYkcfx1fJb6vQg6j9RhWAIrklj+epW7JErQeaQOwxJYkSQh\n/XsRdB55sGEJrEgSkf29CDqPPNiwBFYkyUj+XgSdRx5tWAIrkmTkfi+CziOPNiyBFUkKUhUh6Dzy\ncMMSWJEk8YvfKR1p96eg80g9hiXoNWLohLdKR1Z+gs4j9RiWQBEAAACAaix/6jFfrdf/2cdzqGPx\npebWX/xLbF6d3GKkMflgDw/1R3+MfH+OR4h+xZd4Ty9HizH5WAwFJqzJLEYS015/wGQwix4wvPzg\njxAdDvDY+it8efMNjcWMfbFTwzcKf8ESixEHe34xaxiMmCe99canx9Zj6iVHWlVMhouP08ve5eXD\ni3HPqWFgzy9epp5msIdFDmd4QLmsqjD82VXDYCxGxb7YKhTBnlg7L3TzXzUl7e1G5FTFZDi/zScI\nKQpFmHIVcbnje8Xv+KBVMdocdY4kVEVs5FAswpRRDM5EUlTQ3qhTQlX05vO4TApFmBKKETUMMuuw\njVFn6apIqfgCEWa5Ylwqrop2vmVR56jk+sfefJhSC7kjTNFiLHuPatxbE3WWU8R05qenzxthChQj\nbBjmFTUMCZ+3BVFnKUUE0WPmF5MvwixUjFPDMBrUdLs2O+osoYjpLNf9nyvCzFsMhe7RpkadA2FF\njHP/CpUjwswuxvL0P1R1NQzxNDDqNEQVMVj0C1xlMShVjPHpf6hqbxjiiaLOcd0FqZVivwYuS9pc\n9dVvmac6FBIAAEQxqUlMK+0EK3deQDdsGuCs1le7XGp5hLqXHeZ9ouuDdzx9K+QBUvB0qakfxose\nxTxh0+CPtaHbyy6X7bqqdIveJ0pRxO9iHiAFT5ea+mGwlWkaFPOEHVb3ZucRsnXobuW5QaNhskr3\nVkHrsSXsvX06xs9xtimK6HSeyEv+RVEFT5ea+lG8sGI+2hdFnJMiLLon7s4l68Pq3EasHI/sqcff\nR8e8QBbeJkUR/Cso5gFSzjLkwYYjovylWxtBgmZgswq2W3pWhLcOt+x9dGxPg53rdEU8/T5+L1CC\ngqdLTf0ono/vesURhCvCpBw3UsR6w99yRUTHwtNTFfGr2y3ivlDwdKmpH0YnkIQulojnXmNLTDvc\nFSnC3KzPbUR0LFsRr8ffRRb3FzxdaupH8qZH50YukaUT/DuEuyJFsHpfR73G6RjvNawURQR1VKR5\nLHi61NSP4lf3XYcAOOI0+txZTAmfwfhic1aEY5O1ySJOHl2Gx9a7FRuDJCoi+OC/mQdIzqvnOd2W\ndbG6eOry0ee3usuRE/4L1S78hSoYWdKgp4gU4Tr0YNlBgHEIR5/sWDj6TP7NMnLtyXn1PKdTW9LF\nauPlx1GjX6jUx/TqLgFQCm9TdwkAAKBaemOjaVMuC1FwNp7w5D0Jqatm0jPYHFx/9PffI59NxTV6\nis7+qxQogs2jGxujIZ8a3o8mGS7ZbPG57w/ZHG2FpghXT7sVMfkwmCuFPxokzAbOPKF5tFQRxZqA\nuEaksbROEUHlzgTDhHOgMWty6NkeRUirzxKa0oEWKKKaNr+xoWejFfGAuLB5oWczFfHoG7hJoWfT\nFFFrJ9+I0LMxilCoNvQOPfVXhKottq6hp86K0CKq06KQ12ipCP1uP1Ubshi0U8RU4y76KtjRQsea\nUHvgWJ6gwai7CAAAIIZlZqwDvMGll1QFLgA0gq8lzz2l3nMvqfJg5j81F3rZfSShtreKWIXlTMXs\nTaQqQi+7jySU8VZhpiOHLSEb16S7T8KXiJmflPcal307utleKtFmq8MsQvZsG7pTRKmi7JLh9iYu\nDdI61uX81dah/DqOTQqimd1HEsp4q5grj3zSNaGHNd/azGTACRVx3ues2YLBKMme7oOaNPkaY2/D\nV5meU0XZpRDamwRKYmuXL5ffu8QOrufu9iIfQx+7jzRU8FaxeO0dbEJZRbDFw8x45NRGnPeFbccl\nEeH1atFw8WjwMkp1zi4Frgh2nnt1Pl8aysVgCwSdOtl9pKCEt4ob+oyY4dCCqcAmYaWdzAbO20vX\n760cloaJ5GDzFennVOfsUi95yu
z6/PBSIcU/hU52H8mo4a1yrucCijAP1unter/a2beKyHfJiyLu\nLi+KPnYfiSjirWJFpodRldz0GtE+3pqfK5vvjt6GrleXXmObeclrRdxeflc4qmRoZveRhDLeKiaP\n5qxz7ds7N4wirxSxOtxElkH1e5876n2ywYLt8MgyShVll4J7Ms8LvZCuL29zO73Pgp9AM7uPBNTx\nVuGWl5+XNoK9N/fmjSLYvqvRp8WMMdeOGY1Cg5o9p4qyS+PARp8kVMTN5dnoc2cXNptohN2HLt4q\nV2x3dZcAKANzyGXmqACcsE8GygDowbTfx/QjcGYyWPzzz2Kg4Xy6O1RYE90AxiN/1iOkN9P/2WrG\noHwebWdpLBbG8utrPTFGdZdAd760C6f2QlegiFJM+/OY2CGIKebaRplQRAmSaz5eKVoARQiT0Tvo\nGmX2oAgh8kSQekaZUIQIue9/DaNMKKIwxWIE7aLMnl93CTSjeA1rFmVCEYUQ7AV0ijKhiPyUiRT1\niTKhiLyUvs81iTKXUEQuljJiARaDqN9OQBH5kDNemPbr/hwAAAAAeCyyfDaa4ToC5PlsNMN1BEjz\n2WiI6wgIkeWz0QzXESDLZ6MhriNAms9GM1xHgEyfjQa4jgBZPhsNcR0B0nw29HAdcYU8dtqFNJ8N\nLVxHbJm+rwAAAACQB1xHGoMUn42Pmf/PP/7so+4PAyRQ3mdjaYRTNdlkSz1mZYM0Sq6Yno5H/mg8\njXkDNKWUIr42C1GDAbRFXBHT/jw2dAiCCr0WhIIbRBWRVu9JWgE6ILSGPrtvQJipLcUVkTN+RJip\nKUUVUeTmR5ipI4VWTBcPEBBmakcBRYjVLsJMzciriDI9AMJMnciliNJRIsJMfcjhsyHnFkeYqQtZ\niljKCwNYIILOQ39kDhVgPAIAAPpiiT83GjSSVeYKBlmWI1XkBuSTvaZFluVIFbkB+bgmsQ6U8kfZ\nxyPJcqSS3EAlODYhe5r6bHpZliNV5Aaks/vMOkOS5UgluQH57HfOapt2gizLkSpyA1XgbW3nkHxY\nouWI9NxAVXg08XcJSZYjleQGKmEXRJZbuk44KstypIrcQDWw0echMZCQZjlSQW4AAAAAaBnLfh+T\nqHRHiuUIZ9kf+v/+tz+EKPRGkiLYbGyuBaYLiEJnZChiOp75V9O5IQqtKa2IUA538/IhCn0pp4hY\nOYRAFJpSRhEfbO1PyqodiEJHhBURyGExyHx6NEShHWKK4HLIuRQMotALAUX0CsghBKLQiKKKmLAV\nvwILRSEKXSikiEAO/kzYFQCi0IL8iignhxCIQn1yKmIpQQ6nnCAKtcmjCFaJMi1kIAqVyVbEsorq\n46KAJjSlort5CfsRAADQGcusuwRALdwYv4lORSYhT1VlDCQSo4hOVSYh3+A+oggb16S7veVQxw0V\nwP5YG0o3XvByH+y/MZt47byRpyqWaXU6L9VkDApCD2vySc01WTlnRXjOyvM2JnF3gSw2zn2SykxC\nvqONUAC6P/0JlBApgi8VXrM2Yx3TdTz9PlbiCRF0SN9e6v46AOHOAezPlSIiB7NzL3LN23s1giCv\nz7AfUYGiinjtdn9VVZa3Y7furwN8VcSe9RpBNLm24xTxdnyvpmV/P8LXTg2uFbGmn2R9YJHlJogs\nD3GKeD++M5MQ+V5Cz3z0+aPurwPcKIJ87qjpmlejT3KviJNJiPyOA/64AAAAAChB72fmijCgJEWe\nIZub6Xjo/+kP8fw/HfHlT6pbGou5MSVTY77Ag0L1w5fduH+M/NFHzGugCXIV8aVdOLUXQB9kKqI3\niIkdWEyR7TsAlEGaIlJqPlYpQFFGchSR0TsgytQHKYrIE0EiytSE8orIff8jytSCUUl/w0IxAqJM\nDSilCIEaRpSpOiUUIdgLIMpUG2FFlIkUEWUqzEBIEaXvc0SZyiLkeGlIiAVYDCLtWQ+gZnpyxguS\nsgEAAADA4/GEXGekuYXAdUQ5tlQklSy3kMrsTIAoNqXUJZ8OpRurQDJZbiGV2ZkAYfhSQLon3mbn\nFUspyy2kMjsTIARTxGETvLDotkg6aW4hVdmZAEGYIqjNXoV/8yLLLaQyOxMgiKgiJLmFVGlnAoQQ\n6zVkuYVUZmcChHGpR/ZBZLk2nQKpZLmFVGZnAoTxDtHoc10glSy3kMrsTAAAAADQBqb/wby6hiDH\nhmQ6/O8QkmgGUhSxHA7/N8Tz3ZqBDEVMFoPg72AxqfvDAAlIUEQoCEiiIZRXRCQISKIZlFbEx+Ky\nVMNYYHmX9pRVxNgfJ74DOlJSEfcSgCS0p5wivgoAktCdUoqIixuu4wqgIWUUET+2uIw9gI6UUETS\nYBOS0BpxRST/+gBJ6IyoIqZp/4+xHOL/vbRFVBGj1DqfDkd1fzDwYCbpjcAUv2cDAEBrMYusOAct\ngD+O9oosBxE4jDSdO0VkOYjAYURJtuY+qMitQ3crjxDboc6e8O3OJsQ1XYeaa5PutkGfsDWpY9mU\nLyKOEphucPCT+RnR3f5OEVkOInAYURKXblzP3blkfViRz2DrBvVqB//cQBJuUPtbtiRw5QQtgMm2\nq2DHmkQJCD2syWewY7WzvNV9r0GyHUTgMKIcLg3CwU1Qt8ykitUxcdfhNlABWzlM6P5kM7Anpx3u\nOQHfyXYESiHrr4rIchCBw4h6sLomJuW4kasE39o0PMjqmSvitGUvogRh6MC2NvkaWWY6iMBhREFC\nRVwrgeRRRGRIkqqILAcROIyoCK/jzSF8w91G9m64XR2SFRElOCsirtfIchCBw4iS8Dp26SfxVpsg\nRmSR5TaMLKmdrIgowVkRq93as+8UkeUgAocRJeF1zAaT3GWEjT6DsSTf2iRZEecEkSL46NO8tTLK\nchCBwwgAAAAAAMhP1mw8OU4lQB8yFQF/kZaRqQg8yq1lQBHgFigC3AJFgFuyFDGCIloGFAFugSLA\nLZmKgK9Iy4AiwC1QBLgFigC3ZCliAEW0jCxFGFAEAAAAAL5gmcnHTGrylWFwIGkNZrQQJBaXWh5x\nPRKzlhw0E4tmKCLcQhFtwaWU2i7dO9QJ+gVvRelhe3eY9xpsueC1ZwloLKwRcHcbj2ycoAdZedx0\n5PZwpIhrzxLQWLgimAiCFxYXw8G+O3xWxJVnCWgsLj0vOud9BO8mbg+fFUGu1hiDpnKriNjDUESr\nuFaERbdxh6GIVsGig0gRxDwwbzvr5jCBIlrGgY0+SVjPbPTJbUrOQBEAAAAAAAAAAABQnf8DmwKy\nJJha+OIAAAAldEVYdGRhdGU6Y3JlYXRlADIwMTctMTEtMDJUMjI6NTM6MzkrMDk6MDBXyJTvAAAA\nJXRFWHRkYXRlOm1vZGlmeQAyMDE3LTExLTAyVDIyOjUzOjM5KzA5OjAwJpUsUwAAACN0RVh0cHM6\nSGlSZXNCb3VuZGluZ0JveAA1Mjl4NDA1LTI2NC0yMDKzNU7+AAAAHHRFWHRwczpMZXZlbAB
BZG9i\nZS0zLjAgRVBTRi0zLjAKm3C74wAAACJ0RVh0cHM6U3BvdENvbG9yLTAAZm9udCBMaWJlcmF0aW9u\nU2Fuc/4Zp8YAAAAASUVORK5CYII=\n", "text/plain": [ "<IPython.core.display.Image object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "draw_nltk_tree(nltkTree.fromstring(sample))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Tree Class " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Code borrowed from https://github.com/bogatyy/cs224d/tree/master/assignment3" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Node:  # a node in the tree\n", "    def __init__(self, label, word=None):\n", "        self.label = label\n", "        self.word = word\n", "        self.parent = None  # reference to parent\n", "        self.left = None  # reference to left child\n", "        self.right = None  # reference to right child\n", "        # true if this node is a leaf (equivalently: it carries a word)\n", "        self.isLeaf = False\n", "\n", "    def __str__(self):\n", "        if self.isLeaf:\n", "            return '[{0}:{1}]'.format(self.word, self.label)\n", "        return '({0} <- [{1}:{2}] -> {3})'.format(self.left, self.word, self.label, self.right)\n", "\n", "\n", "class Tree:\n", "\n", "    def __init__(self, treeString, openChar='(', closeChar=')'):\n", "        tokens = []\n", "        self.open = openChar\n", "        self.close = closeChar\n", "        for toks in treeString.strip().split():\n", "            tokens += list(toks)\n", "        self.root = self.parse(tokens)\n", "        # list of labels as obtained through a post-order traversal\n", "        self.labels = get_labels(self.root)\n", "        self.num_words = len(self.labels)  # counts all labeled nodes, not just leaf words\n", "\n", "    def parse(self, tokens, parent=None):\n", "        assert tokens[0] == self.open, \"Malformed tree\"\n", "        assert tokens[-1] == self.close, \"Malformed tree\"\n", "\n", "        split = 2  # position after open and label\n", "        countOpen = countClose = 0\n", "\n", "        if tokens[split] == self.open:\n", "            countOpen += 1\n", "            split += 1\n", "        # Find where left child and right child split\n", "        while countOpen != countClose:\n", "            if tokens[split] == self.open:\n", "                countOpen += 1\n", "            if tokens[split] == self.close:\n", "                countClose += 1\n", "            split += 1\n", "\n", "        # New node\n", "        node = Node(int(tokens[1]))  # zero-indexed labels\n", "\n", "        node.parent = parent\n", "\n", "        # leaf node\n", "        if countOpen == 0:\n", "            node.word = ''.join(tokens[2: -1]).lower()  # lowercase the token\n", "            node.isLeaf = True\n", "            return node\n", "\n", "        node.left = self.parse(tokens[2: split], parent=node)\n", "        node.right = self.parse(tokens[split: -1], parent=node)\n", "\n", "        return node\n", "\n", "    def get_words(self):\n", "        leaves = getLeaves(self.root)\n", "        words = [node.word for node in leaves]\n", "        return words\n", "\n", "def get_labels(node):\n", "    if node is None:\n", "        return []\n", "    return get_labels(node.left) + get_labels(node.right) + [node.label]\n", "\n", "def getLeaves(node):\n", "    if node is None:\n", "        return []\n", "    if node.isLeaf:\n", "        return [node]\n", "    else:\n", "        return getLeaves(node.left) + getLeaves(node.right)\n", "\n", "\n", "def loadTrees(dataSet='train'):\n", "    \"\"\"\n", "    Loads the trees of the given split as a list of Tree objects.\n", "    \"\"\"\n", "    file = '../dataset/trees/%s.txt' % dataSet\n", "    print(\"Loading %s trees..\" % dataSet)\n", "    with open(file, 'r', encoding='utf-8') as fid:\n", "        trees = [Tree(l) for l in fid.readlines()]\n", "\n", "    return trees" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading train trees..\n" ] } ], "source": [ "train_data = loadTrees('train')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Build Vocab " ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "vocab = list(set(flatten([t.get_words() for t in train_data])))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "word2index = {'<UNK>': 0}\n", "for vo in vocab:\n", "    if word2index.get(vo) is None:\n", "        word2index[vo] = len(word2index)\n", "\n", "index2word = {v: k for k, v in word2index.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Modeling " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The RNTN composition layer (image from https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf): a parent vector $p$ is computed from its two child vectors $a, b$ as $p = \\tanh\\left([a;b]^{\\top} V^{[1:d]} [a;b] + W[a;b] + b\\right)$, where the tensor $V^{[1:d]}$ is what distinguishes the RNTN from a plain recursive network." ] }
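, { "cell_type": "markdown", "metadata": {}, "source": [ "Before defining the model, here is a quick shape check of the composition above (a sketch with a toy dimensionality `D = 3`; the names `x`, `V`, `W` are local illustrations, not part of the model):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Shape check for the RNTN composition: [a;b] is 1 x 2D, each slice V_i is\n", "# 2D x 2D, so x V_i x^T is 1 x 1; concatenating the D slices gives 1 x D,\n", "# which matches the shape of x W (W is 2D x D).\n", "D = 3\n", "x = torch.randn(1, 2 * D)  # concatenated children [a;b]\n", "V = [torch.randn(2 * D, 2 * D) for _ in range(D)]\n", "W = torch.randn(2 * D, D)\n", "xVx = torch.cat([torch.matmul(torch.matmul(x, v), x.transpose(0, 1)) for v in V], 1)\n", "print(xVx.size(), torch.matmul(x, W).size())  # both (1, D)" ] }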
" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class RNTN(nn.Module):\n", " \n", " def __init__(self, word2index, hidden_size, output_size):\n", " super(RNTN,self).__init__()\n", " \n", " self.word2index = word2index\n", " self.embed = nn.Embedding(len(word2index), hidden_size)\n", "# self.V = nn.ModuleList([nn.Linear(hidden_size*2,hidden_size*2) for _ in range(hidden_size)])\n", "# self.W = nn.Linear(hidden_size*2,hidden_size)\n", " self.V = nn.ParameterList([nn.Parameter(torch.randn(hidden_size * 2, hidden_size * 2)) for _ in range(hidden_size)]) # Tensor\n", " self.W = nn.Parameter(torch.randn(hidden_size * 2, hidden_size))\n", " self.b = nn.Parameter(torch.randn(1, hidden_size))\n", "# self.W_out = nn.Parameter(torch.randn(hidden_size,output_size))\n", " self.W_out = nn.Linear(hidden_size, output_size)\n", " \n", " def init_weight(self):\n", " nn.init.xavier_uniform(self.embed.state_dict()['weight'])\n", " nn.init.xavier_uniform(self.W_out.state_dict()['weight'])\n", " for param in self.V.parameters():\n", " nn.init.xavier_uniform(param)\n", " nn.init.xavier_uniform(self.W)\n", " self.b.data.fill_(0)\n", "# nn.init.xavier_uniform(self.W_out)\n", " \n", " def tree_propagation(self, node):\n", " \n", " recursive_tensor = OrderedDict()\n", " current = None\n", " if node.isLeaf:\n", " tensor = Variable(LongTensor([self.word2index[node.word]])) if node.word in self.word2index.keys() \\\n", " else Variable(LongTensor([self.word2index['']]))\n", " current = self.embed(tensor) # 1xD\n", " else:\n", " recursive_tensor.update(self.tree_propagation(node.left))\n", " recursive_tensor.update(self.tree_propagation(node.right))\n", " \n", " concated = torch.cat([recursive_tensor[node.left], recursive_tensor[node.right]], 1) # 1x2D\n", " xVx = [] \n", " for i, v in enumerate(self.V):\n", "# xVx.append(torch.matmul(v(concated),concated.transpose(0,1)))\n", " xVx.append(torch.matmul(torch.matmul(concated, v), concated.transpose(0, 1)))\n", " \n", " xVx = torch.cat(xVx, 1) # 1xD\n", "# Wx = self.W(concated)\n", " Wx = torch.matmul(concated, self.W) # 1xD\n", "\n", " current = F.tanh(xVx + Wx + self.b) # 1xD\n", " recursive_tensor[node] = current\n", " return recursive_tensor\n", " \n", " def forward(self, Trees, root_only=False):\n", " \n", " propagated = []\n", " if not isinstance(Trees, list):\n", " Trees = [Trees]\n", " \n", " for Tree in Trees:\n", " recursive_tensor = self.tree_propagation(Tree.root)\n", " if root_only:\n", " recursive_tensor = recursive_tensor[Tree.root]\n", " propagated.append(recursive_tensor)\n", " else:\n", " recursive_tensor = [tensor for node,tensor in recursive_tensor.items()]\n", " propagated.extend(recursive_tensor)\n", " \n", " propagated = torch.cat(propagated) # (num_of_node in batch, D)\n", " \n", "# return F.log_softmax(propagated.matmul(self.W_out))\n", " return F.log_softmax(self.W_out(propagated),1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It takes for a while... It builds its computational graph dynamically. So Its computation is difficult to train with batch." 
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "HIDDEN_SIZE = 30\n", "ROOT_ONLY = False\n", "BATCH_SIZE = 20\n", "EPOCH = 20\n", "LR = 0.01\n", "LAMBDA = 1e-5\n", "RESCHEDULED = False" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "model = RNTN(word2index, HIDDEN_SIZE,5)\n", "model.init_weight()\n", "if USE_CUDA:\n", " model = model.cuda()\n", "\n", "loss_function = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=LR)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0/20] mean_loss : 1.62\n", "[0/20] mean_loss : 1.25\n", "[0/20] mean_loss : 0.95\n", "[0/20] mean_loss : 0.90\n", "[0/20] mean_loss : 0.88\n", "[1/20] mean_loss : 0.88\n", "[1/20] mean_loss : 0.84\n", "[1/20] mean_loss : 0.83\n", "[1/20] mean_loss : 0.82\n", "[1/20] mean_loss : 0.82\n", "[2/20] mean_loss : 0.81\n", "[2/20] mean_loss : 0.79\n", "[2/20] mean_loss : 0.78\n", "[2/20] mean_loss : 0.76\n", "[2/20] mean_loss : 0.75\n", "[3/20] mean_loss : 0.68\n", "[3/20] mean_loss : 0.73\n", "[3/20] mean_loss : 0.74\n", "[3/20] mean_loss : 0.72\n", "[3/20] mean_loss : 0.72\n", "[4/20] mean_loss : 0.74\n", "[4/20] mean_loss : 0.69\n", "[4/20] mean_loss : 0.69\n", "[4/20] mean_loss : 0.68\n", "[4/20] mean_loss : 0.67\n", "[5/20] mean_loss : 0.73\n", "[5/20] mean_loss : 0.65\n", "[5/20] mean_loss : 0.64\n", "[5/20] mean_loss : 0.64\n", "[5/20] mean_loss : 0.65\n", "[6/20] mean_loss : 0.67\n", "[6/20] mean_loss : 0.62\n", "[6/20] mean_loss : 0.62\n", "[6/20] mean_loss : 0.62\n", "[6/20] mean_loss : 0.62\n", "[7/20] mean_loss : 0.57\n", "[7/20] mean_loss : 0.59\n", "[7/20] mean_loss : 0.59\n", "[7/20] mean_loss : 0.59\n", "[7/20] mean_loss : 0.59\n", "[8/20] mean_loss : 0.60\n", "[8/20] mean_loss : 0.58\n", "[8/20] mean_loss : 0.59\n", "[8/20] mean_loss : 0.60\n", "[8/20] mean_loss : 0.60\n", "[9/20] mean_loss : 0.52\n", "[9/20] mean_loss : 0.58\n", "[9/20] mean_loss : 0.60\n", "[9/20] mean_loss : 0.59\n", "[9/20] mean_loss : 0.59\n", "[10/20] mean_loss : 0.56\n", "[10/20] mean_loss : 0.56\n", "[10/20] mean_loss : 0.56\n", "[10/20] mean_loss : 0.56\n", "[10/20] mean_loss : 0.56\n", "[11/20] mean_loss : 0.52\n", "[11/20] mean_loss : 0.54\n", "[11/20] mean_loss : 0.54\n", "[11/20] mean_loss : 0.54\n", "[11/20] mean_loss : 0.55\n", "[12/20] mean_loss : 0.55\n", "[12/20] mean_loss : 0.53\n", "[12/20] mean_loss : 0.53\n", "[12/20] mean_loss : 0.53\n", "[12/20] mean_loss : 0.53\n", "[13/20] mean_loss : 0.59\n", "[13/20] mean_loss : 0.52\n", "[13/20] mean_loss : 0.52\n", "[13/20] mean_loss : 0.53\n", "[13/20] mean_loss : 0.53\n", "[14/20] mean_loss : 0.49\n", "[14/20] mean_loss : 0.51\n", "[14/20] mean_loss : 0.51\n", "[14/20] mean_loss : 0.52\n", "[14/20] mean_loss : 0.52\n", "[15/20] mean_loss : 0.43\n", "[15/20] mean_loss : 0.51\n", "[15/20] mean_loss : 0.51\n", "[15/20] mean_loss : 0.51\n", "[15/20] mean_loss : 0.51\n", "[16/20] mean_loss : 0.46\n", "[16/20] mean_loss : 0.50\n", "[16/20] mean_loss : 0.50\n", "[16/20] mean_loss : 0.50\n", "[16/20] mean_loss : 0.50\n", "[17/20] mean_loss : 0.50\n", "[17/20] mean_loss : 0.50\n", "[17/20] mean_loss : 0.50\n", "[17/20] mean_loss : 0.50\n", "[17/20] mean_loss : 0.51\n", "[18/20] mean_loss : 0.46\n", "[18/20] mean_loss : 0.50\n", "[18/20] mean_loss : 0.50\n", "[18/20] mean_loss : 0.49\n", "[18/20] mean_loss : 0.49\n", "[19/20] mean_loss 
: 0.49\n", "[19/20] mean_loss : 0.49\n", "[19/20] mean_loss : 0.49\n", "[19/20] mean_loss : 0.50\n", "[19/20] mean_loss : 0.50\n" ] } ], "source": [ "for epoch in range(EPOCH):\n", "    losses = []\n", "\n", "    # learning rate annealing\n", "    if not RESCHEDULED and epoch == EPOCH // 2:\n", "        LR *= 0.1\n", "        optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=LAMBDA)  # L2 regularization\n", "        RESCHEDULED = True\n", "\n", "    for i, batch in enumerate(getBatch(BATCH_SIZE, train_data)):\n", "\n", "        if ROOT_ONLY:\n", "            labels = [tree.labels[-1] for tree in batch]\n", "            labels = Variable(LongTensor(labels))\n", "        else:\n", "            labels = [tree.labels for tree in batch]\n", "            labels = Variable(LongTensor(flatten(labels)))\n", "\n", "        model.zero_grad()\n", "        preds = model(batch, ROOT_ONLY)\n", "\n", "        loss = loss_function(preds, labels)\n", "        losses.append(loss.data.tolist()[0])\n", "\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if i % 100 == 0:\n", "            print('[%d/%d] mean_loss : %.2f' % (epoch, EPOCH, np.mean(losses)))\n", "            losses = []" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convergence is unstable and sensitive to the initial weights; I had to rerun training 5~6 times to get this result." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading test trees..\n" ] } ], "source": [ "test_data = loadTrees('test')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "accuracy = 0\n", "num_node = 0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fine-grained all" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the paper, they achieved 80.2% fine-grained accuracy over all nodes." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "79.33705899068254\n" ] } ], "source": [ "for test in test_data:\n", "    model.zero_grad()\n", "    preds = model(test, ROOT_ONLY)\n", "    labels = test.labels[-1:] if ROOT_ONLY else test.labels\n", "    for pred, label in zip(preds.max(1)[1].data.tolist(), labels):\n", "        num_node += 1\n", "        if pred == label:\n", "            accuracy += 1\n", "\n", "print(accuracy / num_node * 100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TODO " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* https://github.com/nearai/pytorch-tools # dynamic batching (a PyTorch port of TensorFlow Fold)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Further topics " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks\n", "* A Fast Unified Model for Parsing and Sentence Understanding (SPINN)\n", "* Blog posts about SPINN" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }