{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"lstm_seq2seq_tf_addons_clr.ipynb","provenance":[],"collapsed_sections":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"code","metadata":{"id":"LNBbqbI2mpVm","colab_type":"code","outputId":"3becdf6a-88f1-4a64-f64f-e90d8f2d8363","executionInfo":{"status":"ok","timestamp":1588208735131,"user_tz":-480,"elapsed":1545,"user":{"displayName":"如子","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gi3ItGjzEGzUOlXTUHjOgeuVA5TICdNcY-Q1TGicA=s64","userId":"01997730851420384589"}},"colab":{"base_uri":"https://localhost:8080/","height":36}},"source":["from google.colab import drive\n","drive.mount('/content/gdrive')\n","import os\n","os.chdir('/content/gdrive/My Drive/finch/tensorflow2/semantic_parsing/tree_slu/main')"],"execution_count":1,"outputs":[{"output_type":"stream","text":["Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount(\"/content/gdrive\", force_remount=True).\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"WiC1HjFxkRD6","colab_type":"code","outputId":"2b5bae80-4516-49a3-cb03-6a747f7d2d23","executionInfo":{"status":"ok","timestamp":1588208746428,"user_tz":-480,"elapsed":6427,"user":{"displayName":"如子","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gi3ItGjzEGzUOlXTUHjOgeuVA5TICdNcY-Q1TGicA=s64","userId":"01997730851420384589"}},"colab":{"base_uri":"https://localhost:8080/","height":55}},"source":["%tensorflow_version 2.x\n","!pip install tensorflow-addons"],"execution_count":2,"outputs":[{"output_type":"stream","text":["Requirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (0.8.3)\n","Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons) (2.7.1)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"_lUy-OAjkYWg","colab_type":"code","outputId":"9338e6d1-7f15-4a73-f0e1-342ccaf7ce91","executionInfo":{"status":"ok","timestamp":1588208750546,"user_tz":-480,"elapsed":4107,"user":{"displayName":"如子","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gi3ItGjzEGzUOlXTUHjOgeuVA5TICdNcY-Q1TGicA=s64","userId":"01997730851420384589"}},"colab":{"base_uri":"https://localhost:8080/","height":112}},"source":["from tensorflow_addons.optimizers.cyclical_learning_rate import ExponentialCyclicalLearningRate\n","\n","import tensorflow as tf\n","import tensorflow_addons as tfa\n","\n","import numpy as np\n","import pprint\n","import logging\n","import time\n","import nltk\n","\n","print(\"TensorFlow Version\", tf.__version__)\n","print('GPU Enabled:', tf.test.is_gpu_available())"],"execution_count":3,"outputs":[{"output_type":"stream","text":["TensorFlow Version 2.2.0-rc3\n","WARNING:tensorflow:From :13: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n","Instructions for updating:\n","Use `tf.config.list_physical_devices('GPU')` instead.\n","GPU Enabled: True\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"1A-a0EfkkmCv","colab_type":"code","colab":{}},"source":["# stream data from text files\n","def data_generator(f_path, params):\n"," with open(f_path) as f:\n"," print('Reading', f_path)\n"," for line in f:\n"," text_raw, text_tokenized, label = line.split('\\t')\n"," text_tokenized = text_tokenized.lower().split()\n"," label = label.replace('[', '[ ').lower().split()\n"," source = [params['tgt2idx'].get(w, len(params['tgt2idx'])) for w in text_tokenized]\n"," target = [params['tgt2idx'].get(w, len(params['tgt2idx'])) for w in label]\n"," target_in = [1] + target\n"," target_out = target + [2]\n"," yield (source, target_in, target_out)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"chVJEEq8kdEB","colab_type":"code","colab":{}},"source":["def dataset(is_training, params):\n"," _shapes = ([None], [None], [None])\n"," _types = (tf.int32, tf.int32, tf.int32)\n"," _pads = (0, 0, 0)\n"," \n"," if is_training:\n"," ds = tf.data.Dataset.from_generator(\n"," lambda: data_generator(params['train_path'], params),\n"," output_shapes = _shapes,\n"," output_types = _types,)\n"," ds = ds.shuffle(params['buffer_size'])\n"," ds = ds.padded_batch(params['train_batch_size'], _shapes, _pads)\n"," ds = ds.prefetch(tf.data.experimental.AUTOTUNE)\n"," else:\n"," ds = tf.data.Dataset.from_generator(\n"," lambda: data_generator(params['test_path'], params),\n"," output_shapes = _shapes,\n"," output_types = _types,)\n"," ds = ds.padded_batch(params['eval_batch_size'], _shapes, _pads)\n"," ds = ds.prefetch(tf.data.experimental.AUTOTUNE)\n"," \n"," return ds"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"1-sY5GtEgINL","colab_type":"code","colab":{}},"source":["class Embed(tf.keras.Model):\n"," def __init__(self):\n"," super().__init__()\n"," self.embedding = tf.Variable(np.load('../vocab/word.npy'),\n"," dtype=tf.float32,\n"," name='pretrained_embedding')\n"," \n"," def call(self, inputs):\n"," if inputs.dtype != tf.int32:\n"," inputs = tf.cast(inputs, tf.int32)\n"," x = tf.nn.embedding_lookup(self.embedding, inputs)\n"," return x"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"kHJdP4Txsqfs","colab_type":"code","colab":{}},"source":["class Encoder(tf.keras.Model):\n"," def __init__(self, params):\n"," super().__init__()\n"," self.dropout = tf.keras.layers.Dropout(params['dropout_rate'])\n"," self.bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(\n"," params['rnn_units'], return_state=True, return_sequences=True, zero_output_for_mask=True))\n"," self.state_fc = tf.keras.layers.Dense(params['rnn_units'], params['activation'], name='state_fc')\n"," \n"," def call(self, inputs, mask, training):\n"," if mask.dtype != tf.bool:\n"," mask = tf.cast(mask, tf.bool)\n"," x = self.dropout(inputs, training=training)\n"," \n"," encoder_o, state_fw_h, state_fw_c, state_bw_h, state_bw_c = self.bilstm(x, mask=mask)\n"," encoder_s = [\n"," self.state_fc(tf.concat((state_fw_h, state_bw_h), -1)),\n"," self.state_fc(tf.concat((state_fw_c, state_bw_c), -1)),]\n"," \n"," return encoder_o, encoder_s"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"C70-t7VKYJvd","colab_type":"code","colab":{}},"source":["class TiedDense(tf.keras.layers.Layer):\n"," def __init__(self, tied_embed, out_dim):\n"," super().__init__()\n"," self.tied_embed = tied_embed\n"," self.out_dim = out_dim\n"," \n"," def build(self, input_shape):\n"," self.bias = self.add_weight(name='bias',\n"," shape=[self.out_dim],\n"," trainable=True)\n"," super().build(input_shape)\n"," \n"," def call(self, inputs):\n"," x = tf.matmul(inputs, self.tied_embed, transpose_b=True)\n"," x = tf.nn.bias_add(x, self.bias)\n"," return x\n"," \n"," def compute_output_shape(self, input_shape):\n"," return input_shape[:-1].concatenate(self.out_dim)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"s5MnR-B1QRzZ","colab_type":"code","colab":{}},"source":["class Model(tf.keras.Model):\n"," def __init__(self, params):\n"," super().__init__()\n"," self.embed = Embed()\n","\n"," self.encoder = Encoder(params)\n","\n"," self.dropout = tf.keras.layers.Dropout(params['dropout_rate'])\n","\n"," self.attn = tfa.seq2seq.BahdanauAttention(params['rnn_units'])\n","\n"," self.decoder_cell = tfa.seq2seq.AttentionWrapper(\n"," tf.keras.layers.LSTMCell(params['rnn_units']),\n"," self.attn,\n"," attention_layer_size=params['rnn_units'])\n"," \n"," self.proj_layer = TiedDense(self.embed.embedding, len(params['tgt2idx'])+1)\n","\n"," self.teach_forcing = tfa.seq2seq.BasicDecoder(\n"," self.decoder_cell,\n"," tfa.seq2seq.sampler.TrainingSampler(),\n"," output_layer = self.proj_layer)\n","\n"," self.beam_search = tfa.seq2seq.BeamSearchDecoder(\n"," self.decoder_cell,\n"," beam_width = params['beam_width'],\n"," embedding_fn = lambda x: self.embed(x),\n"," output_layer = self.proj_layer,\n"," maximum_iterations = 80,)\n","\n"," \n"," def call(self, inputs, training=True):\n"," if training:\n"," source, target_in = inputs\n"," else:\n"," source = inputs\n"," batch_sz = tf.shape(source)[0]\n","\n"," encoder_o, encoder_s = self.encoder(self.embed(source), mask=tf.sign(source), training=training)\n","\n"," if training:\n"," self.attn([encoder_o, tf.math.count_nonzero(source, 1)], setup_memory=True)\n"," attn_state = self.decoder_cell.get_initial_state(batch_size=batch_sz, dtype=tf.float32)\n"," attn_state = attn_state.clone(cell_state=encoder_s)\n","\n"," decoder_o, _, _ = self.teach_forcing(\n"," inputs = self.dropout(self.embed(target_in), training=training),\n"," initial_state = attn_state,\n"," sequence_length = tf.math.count_nonzero(target_in, 1, dtype=tf.int32))\n","\n"," logits_or_ids = decoder_o.rnn_output\n"," else:\n"," encoder_o_t = tfa.seq2seq.tile_batch(encoder_o, params['beam_width'])\n"," encoder_len_t = tfa.seq2seq.tile_batch(tf.math.count_nonzero(source, 1), params['beam_width'])\n"," encoder_s_t = tfa.seq2seq.tile_batch(encoder_s, params['beam_width'])\n","\n"," self.attn([encoder_o_t, encoder_len_t], setup_memory=True)\n"," attn_state = self.decoder_cell.get_initial_state(batch_size=batch_sz*params['beam_width'], dtype=tf.float32)\n"," attn_state = attn_state.clone(cell_state=encoder_s_t)\n","\n"," decoder_o, _, _ = self.beam_search(\n"," None,\n"," start_tokens = tf.tile(tf.constant([1], tf.int32), [batch_sz]),\n"," end_token = 2,\n"," initial_state = attn_state,)\n","\n"," logits_or_ids = decoder_o.predicted_ids[:, :, 0]\n","\n"," return logits_or_ids"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"P62cv6uCoa9t","colab_type":"code","colab":{}},"source":["def get_vocab(f_path):\n"," word2idx = {}\n"," with open(f_path) as f:\n"," for i, line in enumerate(f):\n"," line = line.rstrip()\n"," word2idx[line] = i\n"," return word2idx"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"RQh1p0dLpNki","colab_type":"code","colab":{}},"source":["params = {\n"," 'train_path': '../data/train.tsv',\n"," 'test_path': '../data/test.tsv',\n"," 'vocab_src_path': '../vocab/source.txt',\n"," 'vocab_tgt_path': '../vocab/target.txt',\n"," 'model_path': '../model/',\n"," 'dropout_rate': .2,\n"," 'rnn_units': 300,\n"," 'embed_dim': 300,\n"," 'activation': tf.nn.elu,\n"," 'beam_width': 10,\n"," 'init_lr': 1e-4,\n"," 'max_lr': 8e-4,\n"," 'clip_norm': .1,\n"," 'buffer_size': 31279,\n"," 'train_batch_size': 32,\n"," 'eval_batch_size': 128,\n"," 'num_patience': 10,\n","}"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"nevYc9pfojnd","colab_type":"code","colab":{}},"source":["params['tgt2idx'] = get_vocab(params['vocab_tgt_path'])\n","params['idx2tgt'] = {idx: tgt for tgt, idx in params['tgt2idx'].items()}"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"kWlDyfC58-uH","colab_type":"code","outputId":"1109a5a6-abb2-4658-c376-113f79f861c5","executionInfo":{"status":"ok","timestamp":1588208873742,"user_tz":-480,"elapsed":10090,"user":{"displayName":"如子","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gi3ItGjzEGzUOlXTUHjOgeuVA5TICdNcY-Q1TGicA=s64","userId":"01997730851420384589"}},"colab":{"base_uri":"https://localhost:8080/","height":436}},"source":["model = Model(params)\n","model.build(input_shape=[[None, None], [None, None]])\n","pprint.pprint([(v.name, v.shape) for v in model.trainable_variables])"],"execution_count":13,"outputs":[{"output_type":"stream","text":["[('pretrained_embedding:0', TensorShape([8692, 300])),\n"," ('encoder/bidirectional/forward_lstm/lstm_cell_1/kernel:0',\n"," TensorShape([300, 1200])),\n"," ('encoder/bidirectional/forward_lstm/lstm_cell_1/recurrent_kernel:0',\n"," TensorShape([300, 1200])),\n"," ('encoder/bidirectional/forward_lstm/lstm_cell_1/bias:0', TensorShape([1200])),\n"," ('encoder/bidirectional/backward_lstm/lstm_cell_2/kernel:0',\n"," TensorShape([300, 1200])),\n"," ('encoder/bidirectional/backward_lstm/lstm_cell_2/recurrent_kernel:0',\n"," TensorShape([300, 1200])),\n"," ('encoder/bidirectional/backward_lstm/lstm_cell_2/bias:0',\n"," TensorShape([1200])),\n"," ('encoder/state_fc/kernel:0', TensorShape([600, 300])),\n"," ('encoder/state_fc/bias:0', TensorShape([300])),\n"," ('BahdanauAttention/attention_v:0', TensorShape([300])),\n"," ('attention_wrapper/BahdanauAttention/kernel:0', TensorShape([300, 300])),\n"," ('BahdanauAttention/kernel:0', TensorShape([600, 300])),\n"," ('attention_wrapper/attention_layer/kernel:0', TensorShape([900, 300])),\n"," ('attention_wrapper/lstm_cell_3/kernel:0', TensorShape([600, 1200])),\n"," ('attention_wrapper/lstm_cell_3/recurrent_kernel:0', TensorShape([300, 1200])),\n"," ('attention_wrapper/lstm_cell_3/bias:0', TensorShape([1200])),\n"," ('tied_dense/bias:0', TensorShape([8692]))]\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"7msXJmQ4q3L3","colab_type":"code","colab":{}},"source":["decay_lr = ExponentialCyclicalLearningRate(\n"," initial_learning_rate=params['init_lr'],\n"," maximal_learning_rate=params['max_lr'],\n"," step_size=4*params['buffer_size']//params['train_batch_size'],)\n","optim = tf.optimizers.Adam(params['init_lr'])\n","global_step = 0"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"qQ86G5vBq37N","colab_type":"code","colab":{}},"source":["best_acc = .0\n","count = 0"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"cQFdB9Nwq6FA","colab_type":"code","colab":{}},"source":["t0 = time.time()\n","logger = logging.getLogger('tensorflow')\n","logger.propagate = False\n","logger.setLevel(logging.INFO)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"BaEXdULasUxQ","colab_type":"code","colab":{}},"source":["def minimal_test(model, params):\n"," test_str = ['what', 'times', 'are', 'the', 'nutcracker', 'show', 'playing', 'near', 'me']\n"," test_arr = tf.convert_to_tensor([[params['tgt2idx'][w] for w in test_str]])\n"," generated = model(inputs=test_arr, training=False)\n","\n"," print('-'*12)\n"," print('minimal test')\n"," print('utterance:', ' '.join(test_str))\n"," parsed = ' '.join([params['idx2tgt'][idx] for idx in generated[0].numpy() if (idx != 0 and idx != 2)])\n"," print('parsed:', parsed)\n"," print()\n"," try:\n"," nltk.tree.Tree.fromstring(parsed.replace('[ ', '(').replace(' ]', ')')).pretty_print()\n"," except:\n"," pass\n"," print('-'*12)"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"ql3yTFjSmNwU","colab_type":"code","outputId":"20dcef11-ab01-403a-9715-486c1c9bda30","executionInfo":{"status":"ok","timestamp":1588232470979,"user_tz":-480,"elapsed":23590097,"user":{"displayName":"如子","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14Gi3ItGjzEGzUOlXTUHjOgeuVA5TICdNcY-Q1TGicA=s64","userId":"01997730851420384589"}},"colab":{"base_uri":"https://localhost:8080/","height":1000}},"source":["while True:\n"," # TRAINING\n"," is_training = True\n"," for i, (source, target_in, target_out) in enumerate(dataset(is_training=is_training, params=params)):\n"," with tf.GradientTape() as tape:\n"," logits_or_ids = model((source, target_in), training=is_training)\n"," \n"," loss = tf.compat.v1.losses.softmax_cross_entropy(\n"," onehot_labels = tf.one_hot(target_out, len(params['tgt2idx'])+1),\n"," logits = logits_or_ids,\n"," weights = tf.cast(tf.sign(target_out), tf.float32),\n"," label_smoothing = .2) \n","\n"," variables = model.trainable_variables\n"," optim.lr.assign(decay_lr(global_step))\n"," grads = tape.gradient(loss, variables)\n"," grads, _ = tf.clip_by_global_norm(grads, params['clip_norm'])\n"," optim.apply_gradients(zip(grads, variables))\n"," \n"," if global_step % 50 == 0:\n"," logger.info(\"Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}\".format(\n"," global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))\n"," t0 = time.time()\n"," \n"," global_step += 1\n","\n"," # EVALUATION\n"," is_training = False\n"," minimal_test(model, params)\n"," m = tf.keras.metrics.Mean()\n"," \n"," parse_fn = lambda x: [e for e in x if (e != 0 and e != 2)]\n","\n"," for i, (source, target_in, target_out) in enumerate(dataset(is_training=is_training, params=params)):\n"," generated = model(inputs=source, training=is_training)\n"," for pred, tgt in zip(generated.numpy(), target_out.numpy()):\n"," matched = np.array_equal(parse_fn(pred), parse_fn(tgt))\n"," m.update_state(int(matched))\n"," \n"," acc = m.result().numpy()\n"," logger.info(\"Evaluation: Testing EM: {:.3f}\".format(acc))\n","\n"," if acc > best_acc:\n"," best_acc = acc\n"," count = 0\n"," else:\n"," count += 1\n"," logger.info(\"Best EM: {:.3f}\".format(best_acc))\n","\n"," if count == params['num_patience']:\n"," print(params['num_patience'], \"times not improve the best result, therefore stop training\")\n"," break"],"execution_count":18,"outputs":[{"output_type":"stream","text":["Reading ../data/train.tsv\n","INFO:tensorflow:Step 0 | Loss: 8.9311 | Spent: 15.6 secs | LR: 0.000100\n","INFO:tensorflow:Step 50 | Loss: 6.1842 | Spent: 22.1 secs | LR: 0.000109\n","INFO:tensorflow:Step 100 | Loss: 5.6022 | Spent: 22.8 secs | LR: 0.000118\n","INFO:tensorflow:Step 150 | Loss: 4.9817 | Spent: 21.6 secs | LR: 0.000127\n","INFO:tensorflow:Step 200 | Loss: 4.6414 | Spent: 23.7 secs | LR: 0.000136\n","INFO:tensorflow:Step 250 | Loss: 4.4588 | Spent: 22.5 secs | LR: 0.000145\n","INFO:tensorflow:Step 300 | Loss: 4.5149 | Spent: 22.8 secs | LR: 0.000154\n","INFO:tensorflow:Step 350 | Loss: 4.1099 | Spent: 21.0 secs | LR: 0.000163\n","INFO:tensorflow:Step 400 | Loss: 4.0889 | Spent: 21.8 secs | LR: 0.000172\n","INFO:tensorflow:Step 450 | Loss: 4.0253 | Spent: 23.5 secs | LR: 0.000181\n","INFO:tensorflow:Step 500 | Loss: 3.7661 | Spent: 21.6 secs | LR: 0.000190\n","INFO:tensorflow:Step 550 | Loss: 3.6650 | Spent: 22.3 secs | LR: 0.000198\n","INFO:tensorflow:Step 600 | Loss: 3.6460 | Spent: 21.6 secs | LR: 0.000207\n","INFO:tensorflow:Step 650 | Loss: 3.5861 | Spent: 21.2 secs | LR: 0.000216\n","INFO:tensorflow:Step 700 | Loss: 3.6209 | Spent: 22.3 secs | LR: 0.000225\n","INFO:tensorflow:Step 750 | Loss: 3.4799 | Spent: 21.6 secs | LR: 0.000234\n","INFO:tensorflow:Step 800 | Loss: 3.5453 | Spent: 21.3 secs | LR: 0.000243\n","INFO:tensorflow:Step 850 | Loss: 3.2477 | Spent: 22.1 secs | LR: 0.000252\n","INFO:tensorflow:Step 900 | Loss: 3.2967 | Spent: 21.5 secs | LR: 0.000261\n","INFO:tensorflow:Step 950 | Loss: 3.1771 | Spent: 22.7 secs | LR: 0.000270\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what time are the [ sl:location nutcracker show ] [ sl:date_time playing ] [ sl:location [ in:get_location [ sl:search_radius near ] ] ] ]\n","\n"," in:get_event \n"," _________________|_______________________________________________________ \n"," | | | | | | sl:location \n"," | | | | | | | \n"," | | | | | | in:get_location \n"," | | | | | | | \n"," | | | | sl:location sl:date_time sl:search_radius\n"," | | | | ___________|_______ | | \n","what time are the nutcracker show playing near \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.115\n","INFO:tensorflow:Best EM: 0.115\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 1000 | Loss: 3.3968 | Spent: 167.4 secs | LR: 0.000279\n","INFO:tensorflow:Step 1050 | Loss: 3.0195 | Spent: 21.2 secs | LR: 0.000288\n","INFO:tensorflow:Step 1100 | Loss: 2.9976 | Spent: 21.6 secs | LR: 0.000297\n","INFO:tensorflow:Step 1150 | Loss: 2.9960 | Spent: 21.9 secs | LR: 0.000306\n","INFO:tensorflow:Step 1200 | Loss: 3.0892 | Spent: 21.6 secs | LR: 0.000315\n","INFO:tensorflow:Step 1250 | Loss: 2.9465 | Spent: 21.4 secs | LR: 0.000324\n","INFO:tensorflow:Step 1300 | Loss: 2.9528 | Spent: 21.2 secs | LR: 0.000333\n","INFO:tensorflow:Step 1350 | Loss: 2.9703 | Spent: 21.0 secs | LR: 0.000342\n","INFO:tensorflow:Step 1400 | Loss: 2.8736 | Spent: 21.5 secs | LR: 0.000351\n","INFO:tensorflow:Step 1450 | Loss: 3.0001 | Spent: 21.8 secs | LR: 0.000360\n","INFO:tensorflow:Step 1500 | Loss: 3.0549 | Spent: 21.5 secs | LR: 0.000369\n","INFO:tensorflow:Step 1550 | Loss: 2.8442 | Spent: 21.7 secs | LR: 0.000378\n","INFO:tensorflow:Step 1600 | Loss: 2.8264 | Spent: 21.3 secs | LR: 0.000387\n","INFO:tensorflow:Step 1650 | Loss: 2.7400 | Spent: 20.9 secs | LR: 0.000395\n","INFO:tensorflow:Step 1700 | Loss: 2.7917 | Spent: 20.7 secs | LR: 0.000404\n","INFO:tensorflow:Step 1750 | Loss: 2.7420 | Spent: 22.1 secs | LR: 0.000413\n","INFO:tensorflow:Step 1800 | Loss: 2.6783 | Spent: 21.2 secs | LR: 0.000422\n","INFO:tensorflow:Step 1850 | Loss: 2.7636 | Spent: 21.2 secs | LR: 0.000431\n","INFO:tensorflow:Step 1900 | Loss: 2.7380 | Spent: 22.1 secs | LR: 0.000440\n","INFO:tensorflow:Step 1950 | Loss: 2.8155 | Spent: 21.9 secs | LR: 0.000449\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are the nutcracker show playing near me ]\n","\n"," in:get_event \n"," ______________________|__________________________ \n","what times are the nutcracker show playing near me\n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.436\n","INFO:tensorflow:Best EM: 0.436\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 2000 | Loss: 2.6276 | Spent: 172.5 secs | LR: 0.000458\n","INFO:tensorflow:Step 2050 | Loss: 2.6713 | Spent: 21.4 secs | LR: 0.000467\n","INFO:tensorflow:Step 2100 | Loss: 2.6842 | Spent: 20.9 secs | LR: 0.000476\n","INFO:tensorflow:Step 2150 | Loss: 2.6747 | Spent: 21.1 secs | LR: 0.000485\n","INFO:tensorflow:Step 2200 | Loss: 2.6105 | Spent: 21.6 secs | LR: 0.000494\n","INFO:tensorflow:Step 2250 | Loss: 2.6687 | Spent: 21.3 secs | LR: 0.000503\n","INFO:tensorflow:Step 2300 | Loss: 2.5524 | Spent: 21.2 secs | LR: 0.000512\n","INFO:tensorflow:Step 2350 | Loss: 2.5493 | Spent: 22.1 secs | LR: 0.000521\n","INFO:tensorflow:Step 2400 | Loss: 2.5027 | Spent: 21.3 secs | LR: 0.000530\n","INFO:tensorflow:Step 2450 | Loss: 2.5642 | Spent: 21.6 secs | LR: 0.000539\n","INFO:tensorflow:Step 2500 | Loss: 2.5276 | Spent: 21.5 secs | LR: 0.000548\n","INFO:tensorflow:Step 2550 | Loss: 2.5371 | Spent: 21.7 secs | LR: 0.000557\n","INFO:tensorflow:Step 2600 | Loss: 2.5557 | Spent: 21.3 secs | LR: 0.000566\n","INFO:tensorflow:Step 2650 | Loss: 2.4885 | Spent: 21.1 secs | LR: 0.000575\n","INFO:tensorflow:Step 2700 | Loss: 2.6464 | Spent: 20.9 secs | LR: 0.000583\n","INFO:tensorflow:Step 2750 | Loss: 2.5262 | Spent: 21.5 secs | LR: 0.000592\n","INFO:tensorflow:Step 2800 | Loss: 2.4980 | Spent: 21.5 secs | LR: 0.000601\n","INFO:tensorflow:Step 2850 | Loss: 2.5096 | Spent: 20.8 secs | LR: 0.000610\n","INFO:tensorflow:Step 2900 | Loss: 2.5100 | Spent: 21.2 secs | LR: 0.000619\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:unsupported what times are the nutcracker show playing near me ]\n","\n"," in:unsupported \n"," _______________________|___________________________ \n","what times are the nutcracker show playing near me\n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.561\n","INFO:tensorflow:Best EM: 0.561\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 2950 | Loss: 2.5117 | Spent: 171.3 secs | LR: 0.000628\n","INFO:tensorflow:Step 3000 | Loss: 2.5492 | Spent: 21.2 secs | LR: 0.000637\n","INFO:tensorflow:Step 3050 | Loss: 2.4878 | Spent: 21.5 secs | LR: 0.000646\n","INFO:tensorflow:Step 3100 | Loss: 2.5425 | Spent: 21.3 secs | LR: 0.000655\n","INFO:tensorflow:Step 3150 | Loss: 2.4586 | Spent: 21.3 secs | LR: 0.000664\n","INFO:tensorflow:Step 3200 | Loss: 2.5240 | Spent: 21.4 secs | LR: 0.000673\n","INFO:tensorflow:Step 3250 | Loss: 2.4603 | Spent: 21.8 secs | LR: 0.000682\n","INFO:tensorflow:Step 3300 | Loss: 2.4707 | Spent: 20.9 secs | LR: 0.000691\n","INFO:tensorflow:Step 3350 | Loss: 2.4419 | Spent: 22.7 secs | LR: 0.000700\n","INFO:tensorflow:Step 3400 | Loss: 2.4727 | Spent: 21.1 secs | LR: 0.000709\n","INFO:tensorflow:Step 3450 | Loss: 2.5017 | Spent: 21.8 secs | LR: 0.000718\n","INFO:tensorflow:Step 3500 | Loss: 2.4495 | Spent: 22.3 secs | LR: 0.000727\n","INFO:tensorflow:Step 3550 | Loss: 2.4709 | Spent: 21.7 secs | LR: 0.000736\n","INFO:tensorflow:Step 3600 | Loss: 2.4798 | Spent: 21.6 secs | LR: 0.000745\n","INFO:tensorflow:Step 3650 | Loss: 2.4529 | Spent: 21.5 secs | LR: 0.000754\n","INFO:tensorflow:Step 3700 | Loss: 2.4668 | Spent: 21.8 secs | LR: 0.000763\n","INFO:tensorflow:Step 3750 | Loss: 2.4798 | Spent: 21.7 secs | LR: 0.000772\n","INFO:tensorflow:Step 3800 | Loss: 2.4363 | Spent: 20.9 secs | LR: 0.000780\n","INFO:tensorflow:Step 3850 | Loss: 2.4644 | Spent: 22.3 secs | LR: 0.000789\n","INFO:tensorflow:Step 3900 | Loss: 2.4728 | Spent: 21.6 secs | LR: 0.000798\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.648\n","INFO:tensorflow:Best EM: 0.648\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 3950 | Loss: 2.4486 | Spent: 178.3 secs | LR: 0.000793\n","INFO:tensorflow:Step 4000 | Loss: 2.4474 | Spent: 21.4 secs | LR: 0.000784\n","INFO:tensorflow:Step 4050 | Loss: 2.4709 | Spent: 21.5 secs | LR: 0.000775\n","INFO:tensorflow:Step 4100 | Loss: 2.4214 | Spent: 21.5 secs | LR: 0.000766\n","INFO:tensorflow:Step 4150 | Loss: 2.4645 | Spent: 21.4 secs | LR: 0.000757\n","INFO:tensorflow:Step 4200 | Loss: 2.4090 | Spent: 21.7 secs | LR: 0.000748\n","INFO:tensorflow:Step 4250 | Loss: 2.4534 | Spent: 20.9 secs | LR: 0.000739\n","INFO:tensorflow:Step 4300 | Loss: 2.4195 | Spent: 21.1 secs | LR: 0.000730\n","INFO:tensorflow:Step 4350 | Loss: 2.4353 | Spent: 21.0 secs | LR: 0.000721\n","INFO:tensorflow:Step 4400 | Loss: 2.4376 | Spent: 21.7 secs | LR: 0.000712\n","INFO:tensorflow:Step 4450 | Loss: 2.4351 | Spent: 21.3 secs | LR: 0.000703\n","INFO:tensorflow:Step 4500 | Loss: 2.4597 | Spent: 21.4 secs | LR: 0.000694\n","INFO:tensorflow:Step 4550 | Loss: 2.4291 | Spent: 21.1 secs | LR: 0.000685\n","INFO:tensorflow:Step 4600 | Loss: 2.4042 | Spent: 21.7 secs | LR: 0.000676\n","INFO:tensorflow:Step 4650 | Loss: 2.4440 | Spent: 22.1 secs | LR: 0.000667\n","INFO:tensorflow:Step 4700 | Loss: 2.4347 | Spent: 20.2 secs | LR: 0.000658\n","INFO:tensorflow:Step 4750 | Loss: 2.4030 | Spent: 21.1 secs | LR: 0.000649\n","INFO:tensorflow:Step 4800 | Loss: 2.4182 | Spent: 20.7 secs | LR: 0.000640\n","INFO:tensorflow:Step 4850 | Loss: 2.3977 | Spent: 21.4 secs | LR: 0.000631\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.685\n","INFO:tensorflow:Best EM: 0.685\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 4900 | Loss: 2.4061 | Spent: 155.5 secs | LR: 0.000623\n","INFO:tensorflow:Step 4950 | Loss: 2.4061 | Spent: 21.4 secs | LR: 0.000614\n","INFO:tensorflow:Step 5000 | Loss: 2.4095 | Spent: 21.5 secs | LR: 0.000605\n","INFO:tensorflow:Step 5050 | Loss: 2.3989 | Spent: 21.5 secs | LR: 0.000596\n","INFO:tensorflow:Step 5100 | Loss: 2.4122 | Spent: 21.7 secs | LR: 0.000587\n","INFO:tensorflow:Step 5150 | Loss: 2.4467 | Spent: 21.4 secs | LR: 0.000578\n","INFO:tensorflow:Step 5200 | Loss: 2.3916 | Spent: 20.8 secs | LR: 0.000569\n","INFO:tensorflow:Step 5250 | Loss: 2.4007 | Spent: 21.2 secs | LR: 0.000560\n","INFO:tensorflow:Step 5300 | Loss: 2.3913 | Spent: 21.7 secs | LR: 0.000551\n","INFO:tensorflow:Step 5350 | Loss: 2.3920 | Spent: 20.9 secs | LR: 0.000542\n","INFO:tensorflow:Step 5400 | Loss: 2.4219 | Spent: 21.0 secs | LR: 0.000533\n","INFO:tensorflow:Step 5450 | Loss: 2.4154 | Spent: 20.6 secs | LR: 0.000524\n","INFO:tensorflow:Step 5500 | Loss: 2.3987 | Spent: 21.7 secs | LR: 0.000515\n","INFO:tensorflow:Step 5550 | Loss: 2.3987 | Spent: 22.4 secs | LR: 0.000506\n","INFO:tensorflow:Step 5600 | Loss: 2.4496 | Spent: 22.3 secs | LR: 0.000497\n","INFO:tensorflow:Step 5650 | Loss: 2.4207 | Spent: 21.3 secs | LR: 0.000488\n","INFO:tensorflow:Step 5700 | Loss: 2.4028 | Spent: 21.7 secs | LR: 0.000479\n","INFO:tensorflow:Step 5750 | Loss: 2.3911 | Spent: 21.7 secs | LR: 0.000470\n","INFO:tensorflow:Step 5800 | Loss: 2.3929 | Spent: 21.0 secs | LR: 0.000461\n","INFO:tensorflow:Step 5850 | Loss: 2.3942 | Spent: 21.2 secs | LR: 0.000452\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.701\n","INFO:tensorflow:Best EM: 0.701\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 5900 | Loss: 2.4120 | Spent: 168.6 secs | LR: 0.000443\n","INFO:tensorflow:Step 5950 | Loss: 2.3945 | Spent: 21.1 secs | LR: 0.000435\n","INFO:tensorflow:Step 6000 | Loss: 2.3827 | Spent: 21.3 secs | LR: 0.000426\n","INFO:tensorflow:Step 6050 | Loss: 2.3849 | Spent: 21.6 secs | LR: 0.000417\n","INFO:tensorflow:Step 6100 | Loss: 2.3593 | Spent: 21.5 secs | LR: 0.000408\n","INFO:tensorflow:Step 6150 | Loss: 2.3779 | Spent: 21.9 secs | LR: 0.000399\n","INFO:tensorflow:Step 6200 | Loss: 2.3678 | Spent: 21.8 secs | LR: 0.000390\n","INFO:tensorflow:Step 6250 | Loss: 2.4246 | Spent: 21.9 secs | LR: 0.000381\n","INFO:tensorflow:Step 6300 | Loss: 2.3906 | Spent: 21.3 secs | LR: 0.000372\n","INFO:tensorflow:Step 6350 | Loss: 2.3678 | Spent: 22.2 secs | LR: 0.000363\n","INFO:tensorflow:Step 6400 | Loss: 2.3848 | Spent: 22.1 secs | LR: 0.000354\n","INFO:tensorflow:Step 6450 | Loss: 2.3761 | Spent: 22.0 secs | LR: 0.000345\n","INFO:tensorflow:Step 6500 | Loss: 2.3694 | Spent: 22.0 secs | LR: 0.000336\n","INFO:tensorflow:Step 6550 | Loss: 2.3602 | Spent: 22.1 secs | LR: 0.000327\n","INFO:tensorflow:Step 6600 | Loss: 2.3771 | Spent: 21.6 secs | LR: 0.000318\n","INFO:tensorflow:Step 6650 | Loss: 2.3674 | Spent: 22.0 secs | LR: 0.000309\n","INFO:tensorflow:Step 6700 | Loss: 2.3779 | Spent: 22.0 secs | LR: 0.000300\n","INFO:tensorflow:Step 6750 | Loss: 2.3815 | Spent: 21.7 secs | LR: 0.000291\n","INFO:tensorflow:Step 6800 | Loss: 2.3829 | Spent: 22.1 secs | LR: 0.000282\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.718\n","INFO:tensorflow:Best EM: 0.718\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 6850 | Loss: 2.3764 | Spent: 164.0 secs | LR: 0.000273\n","INFO:tensorflow:Step 6900 | Loss: 2.3665 | Spent: 21.0 secs | LR: 0.000264\n","INFO:tensorflow:Step 6950 | Loss: 2.3675 | Spent: 21.7 secs | LR: 0.000255\n","INFO:tensorflow:Step 7000 | Loss: 2.3875 | Spent: 21.2 secs | LR: 0.000246\n","INFO:tensorflow:Step 7050 | Loss: 2.3654 | Spent: 21.7 secs | LR: 0.000238\n","INFO:tensorflow:Step 7100 | Loss: 2.3705 | Spent: 21.9 secs | LR: 0.000229\n","INFO:tensorflow:Step 7150 | Loss: 2.3750 | Spent: 20.8 secs | LR: 0.000220\n","INFO:tensorflow:Step 7200 | Loss: 2.3644 | Spent: 22.6 secs | LR: 0.000211\n","INFO:tensorflow:Step 7250 | Loss: 2.3655 | Spent: 21.6 secs | LR: 0.000202\n","INFO:tensorflow:Step 7300 | Loss: 2.3649 | Spent: 21.5 secs | LR: 0.000193\n","INFO:tensorflow:Step 7350 | Loss: 2.3631 | Spent: 21.3 secs | LR: 0.000184\n","INFO:tensorflow:Step 7400 | Loss: 2.3650 | Spent: 21.0 secs | LR: 0.000175\n","INFO:tensorflow:Step 7450 | Loss: 2.3583 | Spent: 22.2 secs | LR: 0.000166\n","INFO:tensorflow:Step 7500 | Loss: 2.3868 | Spent: 21.8 secs | LR: 0.000157\n","INFO:tensorflow:Step 7550 | Loss: 2.3644 | Spent: 21.4 secs | LR: 0.000148\n","INFO:tensorflow:Step 7600 | Loss: 2.3678 | Spent: 21.7 secs | LR: 0.000139\n","INFO:tensorflow:Step 7650 | Loss: 2.3574 | Spent: 21.1 secs | LR: 0.000130\n","INFO:tensorflow:Step 7700 | Loss: 2.3606 | Spent: 20.9 secs | LR: 0.000121\n","INFO:tensorflow:Step 7750 | Loss: 2.3706 | Spent: 21.6 secs | LR: 0.000112\n","INFO:tensorflow:Step 7800 | Loss: 2.3841 | Spent: 21.6 secs | LR: 0.000103\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.726\n","INFO:tensorflow:Best EM: 0.726\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 7850 | Loss: 2.3640 | Spent: 172.9 secs | LR: 0.000106\n","INFO:tensorflow:Step 7900 | Loss: 2.3546 | Spent: 21.7 secs | LR: 0.000115\n","INFO:tensorflow:Step 7950 | Loss: 2.3635 | Spent: 22.0 secs | LR: 0.000124\n","INFO:tensorflow:Step 8000 | Loss: 2.3637 | Spent: 23.1 secs | LR: 0.000133\n","INFO:tensorflow:Step 8050 | Loss: 2.3567 | Spent: 21.0 secs | LR: 0.000142\n","INFO:tensorflow:Step 8100 | Loss: 2.3626 | Spent: 20.6 secs | LR: 0.000150\n","INFO:tensorflow:Step 8150 | Loss: 2.3608 | Spent: 22.3 secs | LR: 0.000159\n","INFO:tensorflow:Step 8200 | Loss: 2.3611 | Spent: 21.6 secs | LR: 0.000168\n","INFO:tensorflow:Step 8250 | Loss: 2.3544 | Spent: 21.8 secs | LR: 0.000177\n","INFO:tensorflow:Step 8300 | Loss: 2.3574 | Spent: 22.6 secs | LR: 0.000186\n","INFO:tensorflow:Step 8350 | Loss: 2.3683 | Spent: 21.9 secs | LR: 0.000195\n","INFO:tensorflow:Step 8400 | Loss: 2.3761 | Spent: 21.5 secs | LR: 0.000204\n","INFO:tensorflow:Step 8450 | Loss: 2.3753 | Spent: 21.8 secs | LR: 0.000213\n","INFO:tensorflow:Step 8500 | Loss: 2.3604 | Spent: 21.7 secs | LR: 0.000222\n","INFO:tensorflow:Step 8550 | Loss: 2.3604 | Spent: 21.8 secs | LR: 0.000231\n","INFO:tensorflow:Step 8600 | Loss: 2.3767 | Spent: 22.1 secs | LR: 0.000240\n","INFO:tensorflow:Step 8650 | Loss: 2.3629 | Spent: 22.3 secs | LR: 0.000249\n","INFO:tensorflow:Step 8700 | Loss: 2.3667 | Spent: 21.5 secs | LR: 0.000258\n","INFO:tensorflow:Step 8750 | Loss: 2.3709 | Spent: 21.7 secs | LR: 0.000267\n","INFO:tensorflow:Step 8800 | Loss: 2.3642 | Spent: 21.4 secs | LR: 0.000276\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.721\n","INFO:tensorflow:Best EM: 0.726\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 8850 | Loss: 2.3576 | Spent: 162.4 secs | LR: 0.000285\n","INFO:tensorflow:Step 8900 | Loss: 2.3654 | Spent: 20.9 secs | LR: 0.000294\n","INFO:tensorflow:Step 8950 | Loss: 2.3672 | Spent: 21.3 secs | LR: 0.000303\n","INFO:tensorflow:Step 9000 | Loss: 2.3783 | Spent: 22.1 secs | LR: 0.000312\n","INFO:tensorflow:Step 9050 | Loss: 2.3716 | Spent: 21.7 secs | LR: 0.000321\n","INFO:tensorflow:Step 9100 | Loss: 2.3679 | Spent: 22.0 secs | LR: 0.000330\n","INFO:tensorflow:Step 9150 | Loss: 2.3705 | Spent: 21.6 secs | LR: 0.000339\n","INFO:tensorflow:Step 9200 | Loss: 2.3747 | Spent: 21.2 secs | LR: 0.000347\n","INFO:tensorflow:Step 9250 | Loss: 2.3626 | Spent: 21.7 secs | LR: 0.000356\n","INFO:tensorflow:Step 9300 | Loss: 2.3647 | Spent: 21.1 secs | LR: 0.000365\n","INFO:tensorflow:Step 9350 | Loss: 2.3585 | Spent: 22.2 secs | LR: 0.000374\n","INFO:tensorflow:Step 9400 | Loss: 2.3777 | Spent: 21.3 secs | LR: 0.000383\n","INFO:tensorflow:Step 9450 | Loss: 2.3739 | Spent: 21.4 secs | LR: 0.000392\n","INFO:tensorflow:Step 9500 | Loss: 2.3725 | Spent: 21.6 secs | LR: 0.000401\n","INFO:tensorflow:Step 9550 | Loss: 2.3680 | Spent: 21.8 secs | LR: 0.000410\n","INFO:tensorflow:Step 9600 | Loss: 2.3950 | Spent: 21.6 secs | LR: 0.000419\n","INFO:tensorflow:Step 9650 | Loss: 2.3702 | Spent: 22.0 secs | LR: 0.000428\n","INFO:tensorflow:Step 9700 | Loss: 2.3903 | Spent: 21.5 secs | LR: 0.000437\n","INFO:tensorflow:Step 9750 | Loss: 2.3611 | Spent: 22.0 secs | LR: 0.000446\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.722\n","INFO:tensorflow:Best EM: 0.726\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 9800 | Loss: 2.3725 | Spent: 171.5 secs | LR: 0.000455\n","INFO:tensorflow:Step 9850 | Loss: 2.3643 | Spent: 21.4 secs | LR: 0.000464\n","INFO:tensorflow:Step 9900 | Loss: 2.3555 | Spent: 20.9 secs | LR: 0.000473\n","INFO:tensorflow:Step 9950 | Loss: 2.3612 | Spent: 21.3 secs | LR: 0.000482\n","INFO:tensorflow:Step 10000 | Loss: 2.3669 | Spent: 20.4 secs | LR: 0.000491\n","INFO:tensorflow:Step 10050 | Loss: 2.3561 | Spent: 21.8 secs | LR: 0.000500\n","INFO:tensorflow:Step 10100 | Loss: 2.3544 | Spent: 22.9 secs | LR: 0.000509\n","INFO:tensorflow:Step 10150 | Loss: 2.3689 | Spent: 20.8 secs | LR: 0.000518\n","INFO:tensorflow:Step 10200 | Loss: 2.3546 | Spent: 21.3 secs | LR: 0.000527\n","INFO:tensorflow:Step 10250 | Loss: 2.3792 | Spent: 21.9 secs | LR: 0.000536\n","INFO:tensorflow:Step 10300 | Loss: 2.3799 | Spent: 21.7 secs | LR: 0.000544\n","INFO:tensorflow:Step 10350 | Loss: 2.3635 | Spent: 22.1 secs | LR: 0.000553\n","INFO:tensorflow:Step 10400 | Loss: 2.3839 | Spent: 21.8 secs | LR: 0.000562\n","INFO:tensorflow:Step 10450 | Loss: 2.3639 | Spent: 21.8 secs | LR: 0.000571\n","INFO:tensorflow:Step 10500 | Loss: 2.3655 | Spent: 21.6 secs | LR: 0.000580\n","INFO:tensorflow:Step 10550 | Loss: 2.3791 | Spent: 21.8 secs | LR: 0.000589\n","INFO:tensorflow:Step 10600 | Loss: 2.3723 | Spent: 21.6 secs | LR: 0.000598\n","INFO:tensorflow:Step 10650 | Loss: 2.3640 | Spent: 21.7 secs | LR: 0.000607\n","INFO:tensorflow:Step 10700 | Loss: 2.3625 | Spent: 21.9 secs | LR: 0.000616\n","INFO:tensorflow:Step 10750 | Loss: 2.3605 | Spent: 21.7 secs | LR: 0.000625\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.720\n","INFO:tensorflow:Best EM: 0.726\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 10800 | Loss: 2.3594 | Spent: 164.1 secs | LR: 0.000634\n","INFO:tensorflow:Step 10850 | Loss: 2.3799 | Spent: 22.7 secs | LR: 0.000643\n","INFO:tensorflow:Step 10900 | Loss: 2.3680 | Spent: 21.4 secs | LR: 0.000652\n","INFO:tensorflow:Step 10950 | Loss: 2.3587 | Spent: 21.0 secs | LR: 0.000661\n","INFO:tensorflow:Step 11000 | Loss: 2.3693 | Spent: 20.6 secs | LR: 0.000670\n","INFO:tensorflow:Step 11050 | Loss: 2.3692 | Spent: 21.6 secs | LR: 0.000679\n","INFO:tensorflow:Step 11100 | Loss: 2.3616 | Spent: 21.4 secs | LR: 0.000688\n","INFO:tensorflow:Step 11150 | Loss: 2.3684 | Spent: 20.8 secs | LR: 0.000697\n","INFO:tensorflow:Step 11200 | Loss: 2.3755 | Spent: 22.1 secs | LR: 0.000706\n","INFO:tensorflow:Step 11250 | Loss: 2.3682 | Spent: 21.6 secs | LR: 0.000715\n","INFO:tensorflow:Step 11300 | Loss: 2.3738 | Spent: 22.4 secs | LR: 0.000724\n","INFO:tensorflow:Step 11350 | Loss: 2.3685 | Spent: 21.7 secs | LR: 0.000732\n","INFO:tensorflow:Step 11400 | Loss: 2.3585 | Spent: 21.2 secs | LR: 0.000741\n","INFO:tensorflow:Step 11450 | Loss: 2.3918 | Spent: 21.0 secs | LR: 0.000750\n","INFO:tensorflow:Step 11500 | Loss: 2.3774 | Spent: 22.0 secs | LR: 0.000759\n","INFO:tensorflow:Step 11550 | Loss: 2.3639 | Spent: 22.0 secs | LR: 0.000768\n","INFO:tensorflow:Step 11600 | Loss: 2.3854 | Spent: 21.5 secs | LR: 0.000777\n","INFO:tensorflow:Step 11650 | Loss: 2.3732 | Spent: 21.1 secs | LR: 0.000786\n","INFO:tensorflow:Step 11700 | Loss: 2.3879 | Spent: 21.2 secs | LR: 0.000795\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.713\n","INFO:tensorflow:Best EM: 0.726\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 11750 | Loss: 2.3613 | Spent: 192.1 secs | LR: 0.000796\n","INFO:tensorflow:Step 11800 | Loss: 2.3682 | Spent: 21.9 secs | LR: 0.000787\n","INFO:tensorflow:Step 11850 | Loss: 2.3796 | Spent: 21.6 secs | LR: 0.000778\n","INFO:tensorflow:Step 11900 | Loss: 2.3652 | Spent: 21.6 secs | LR: 0.000769\n","INFO:tensorflow:Step 11950 | Loss: 2.3821 | Spent: 21.1 secs | LR: 0.000760\n","INFO:tensorflow:Step 12000 | Loss: 2.3810 | Spent: 22.2 secs | LR: 0.000751\n","INFO:tensorflow:Step 12050 | Loss: 2.3626 | Spent: 20.9 secs | LR: 0.000742\n","INFO:tensorflow:Step 12100 | Loss: 2.3781 | Spent: 21.7 secs | LR: 0.000733\n","INFO:tensorflow:Step 12150 | Loss: 2.3632 | Spent: 21.5 secs | LR: 0.000724\n","INFO:tensorflow:Step 12200 | Loss: 2.3675 | Spent: 23.4 secs | LR: 0.000715\n","INFO:tensorflow:Step 12250 | Loss: 2.3582 | Spent: 22.1 secs | LR: 0.000706\n","INFO:tensorflow:Step 12300 | Loss: 2.3562 | Spent: 22.1 secs | LR: 0.000697\n","INFO:tensorflow:Step 12350 | Loss: 2.3680 | Spent: 22.4 secs | LR: 0.000688\n","INFO:tensorflow:Step 12400 | Loss: 2.3626 | Spent: 22.2 secs | LR: 0.000679\n","INFO:tensorflow:Step 12450 | Loss: 2.3629 | Spent: 21.6 secs | LR: 0.000671\n","INFO:tensorflow:Step 12500 | Loss: 2.3652 | Spent: 22.6 secs | LR: 0.000662\n","INFO:tensorflow:Step 12550 | Loss: 2.3505 | Spent: 22.0 secs | LR: 0.000653\n","INFO:tensorflow:Step 12600 | Loss: 2.3679 | Spent: 21.3 secs | LR: 0.000644\n","INFO:tensorflow:Step 12650 | Loss: 2.3644 | Spent: 22.5 secs | LR: 0.000635\n","INFO:tensorflow:Step 12700 | Loss: 2.3703 | Spent: 21.8 secs | LR: 0.000626\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.714\n","INFO:tensorflow:Best EM: 0.726\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 12750 | Loss: 2.3572 | Spent: 159.4 secs | LR: 0.000617\n","INFO:tensorflow:Step 12800 | Loss: 2.3613 | Spent: 21.9 secs | LR: 0.000608\n","INFO:tensorflow:Step 12850 | Loss: 2.3635 | Spent: 22.5 secs | LR: 0.000599\n","INFO:tensorflow:Step 12900 | Loss: 2.3644 | Spent: 22.1 secs | LR: 0.000590\n","INFO:tensorflow:Step 12950 | Loss: 2.3551 | Spent: 21.7 secs | LR: 0.000581\n","INFO:tensorflow:Step 13000 | Loss: 2.3552 | Spent: 22.4 secs | LR: 0.000572\n","INFO:tensorflow:Step 13050 | Loss: 2.3766 | Spent: 21.4 secs | LR: 0.000563\n","INFO:tensorflow:Step 13100 | Loss: 2.3605 | Spent: 22.2 secs | LR: 0.000554\n","INFO:tensorflow:Step 13150 | Loss: 2.3654 | Spent: 21.4 secs | LR: 0.000545\n","INFO:tensorflow:Step 13200 | Loss: 2.3469 | Spent: 22.2 secs | LR: 0.000536\n","INFO:tensorflow:Step 13250 | Loss: 2.3492 | Spent: 22.7 secs | LR: 0.000527\n","INFO:tensorflow:Step 13300 | Loss: 2.3602 | Spent: 22.0 secs | LR: 0.000518\n","INFO:tensorflow:Step 13350 | Loss: 2.3575 | Spent: 21.7 secs | LR: 0.000509\n","INFO:tensorflow:Step 13400 | Loss: 2.3552 | Spent: 21.1 secs | LR: 0.000500\n","INFO:tensorflow:Step 13450 | Loss: 2.3526 | Spent: 21.0 secs | LR: 0.000491\n","INFO:tensorflow:Step 13500 | Loss: 2.3706 | Spent: 21.3 secs | LR: 0.000483\n","INFO:tensorflow:Step 13550 | Loss: 2.3526 | Spent: 21.1 secs | LR: 0.000474\n","INFO:tensorflow:Step 13600 | Loss: 2.3545 | Spent: 21.7 secs | LR: 0.000465\n","INFO:tensorflow:Step 13650 | Loss: 2.3587 | Spent: 22.1 secs | LR: 0.000456\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.730\n","INFO:tensorflow:Best EM: 0.730\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 13700 | Loss: 2.3505 | Spent: 170.6 secs | LR: 0.000447\n","INFO:tensorflow:Step 13750 | Loss: 2.3537 | Spent: 21.9 secs | LR: 0.000438\n","INFO:tensorflow:Step 13800 | Loss: 2.3565 | Spent: 21.2 secs | LR: 0.000429\n","INFO:tensorflow:Step 13850 | Loss: 2.3491 | Spent: 21.1 secs | LR: 0.000420\n","INFO:tensorflow:Step 13900 | Loss: 2.3498 | Spent: 21.5 secs | LR: 0.000411\n","INFO:tensorflow:Step 13950 | Loss: 2.3530 | Spent: 21.1 secs | LR: 0.000402\n","INFO:tensorflow:Step 14000 | Loss: 2.3548 | Spent: 21.7 secs | LR: 0.000393\n","INFO:tensorflow:Step 14050 | Loss: 2.3475 | Spent: 22.2 secs | LR: 0.000384\n","INFO:tensorflow:Step 14100 | Loss: 2.3486 | Spent: 21.6 secs | LR: 0.000375\n","INFO:tensorflow:Step 14150 | Loss: 2.3525 | Spent: 21.7 secs | LR: 0.000366\n","INFO:tensorflow:Step 14200 | Loss: 2.3498 | Spent: 21.8 secs | LR: 0.000357\n","INFO:tensorflow:Step 14250 | Loss: 2.3504 | Spent: 22.2 secs | LR: 0.000348\n","INFO:tensorflow:Step 14300 | Loss: 2.3509 | Spent: 21.8 secs | LR: 0.000339\n","INFO:tensorflow:Step 14350 | Loss: 2.3579 | Spent: 22.2 secs | LR: 0.000330\n","INFO:tensorflow:Step 14400 | Loss: 2.3542 | Spent: 21.7 secs | LR: 0.000321\n","INFO:tensorflow:Step 14450 | Loss: 2.3515 | Spent: 22.1 secs | LR: 0.000312\n","INFO:tensorflow:Step 14500 | Loss: 2.3592 | Spent: 22.3 secs | LR: 0.000303\n","INFO:tensorflow:Step 14550 | Loss: 2.3621 | Spent: 21.1 secs | LR: 0.000294\n","INFO:tensorflow:Step 14600 | Loss: 2.3452 | Spent: 22.6 secs | LR: 0.000286\n","INFO:tensorflow:Step 14650 | Loss: 2.3504 | Spent: 22.1 secs | LR: 0.000277\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.733\n","INFO:tensorflow:Best EM: 0.733\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 14700 | Loss: 2.3450 | Spent: 158.6 secs | LR: 0.000268\n","INFO:tensorflow:Step 14750 | Loss: 2.3443 | Spent: 22.2 secs | LR: 0.000259\n","INFO:tensorflow:Step 14800 | Loss: 2.3418 | Spent: 21.8 secs | LR: 0.000250\n","INFO:tensorflow:Step 14850 | Loss: 2.3424 | Spent: 23.0 secs | LR: 0.000241\n","INFO:tensorflow:Step 14900 | Loss: 2.3603 | Spent: 21.7 secs | LR: 0.000232\n","INFO:tensorflow:Step 14950 | Loss: 2.3416 | Spent: 21.7 secs | LR: 0.000223\n","INFO:tensorflow:Step 15000 | Loss: 2.3495 | Spent: 22.4 secs | LR: 0.000214\n","INFO:tensorflow:Step 15050 | Loss: 2.3448 | Spent: 22.6 secs | LR: 0.000205\n","INFO:tensorflow:Step 15100 | Loss: 2.3385 | Spent: 22.7 secs | LR: 0.000196\n","INFO:tensorflow:Step 15150 | Loss: 2.3420 | Spent: 21.5 secs | LR: 0.000187\n","INFO:tensorflow:Step 15200 | Loss: 2.3415 | Spent: 21.6 secs | LR: 0.000178\n","INFO:tensorflow:Step 15250 | Loss: 2.3433 | Spent: 21.6 secs | LR: 0.000169\n","INFO:tensorflow:Step 15300 | Loss: 2.3435 | Spent: 21.5 secs | LR: 0.000160\n","INFO:tensorflow:Step 15350 | Loss: 2.3530 | Spent: 21.1 secs | LR: 0.000151\n","INFO:tensorflow:Step 15400 | Loss: 2.3550 | Spent: 21.6 secs | LR: 0.000142\n","INFO:tensorflow:Step 15450 | Loss: 2.3419 | Spent: 21.2 secs | LR: 0.000133\n","INFO:tensorflow:Step 15500 | Loss: 2.3450 | Spent: 21.7 secs | LR: 0.000124\n","INFO:tensorflow:Step 15550 | Loss: 2.3479 | Spent: 22.8 secs | LR: 0.000115\n","INFO:tensorflow:Step 15600 | Loss: 2.3499 | Spent: 21.7 secs | LR: 0.000106\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.738\n","INFO:tensorflow:Best EM: 0.738\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 15650 | Loss: 2.3574 | Spent: 157.8 secs | LR: 0.000103\n","INFO:tensorflow:Step 15700 | Loss: 2.3475 | Spent: 22.4 secs | LR: 0.000111\n","INFO:tensorflow:Step 15750 | Loss: 2.3460 | Spent: 21.5 secs | LR: 0.000120\n","INFO:tensorflow:Step 15800 | Loss: 2.3434 | Spent: 22.5 secs | LR: 0.000129\n","INFO:tensorflow:Step 15850 | Loss: 2.3489 | Spent: 21.2 secs | LR: 0.000138\n","INFO:tensorflow:Step 15900 | Loss: 2.3509 | Spent: 22.0 secs | LR: 0.000147\n","INFO:tensorflow:Step 15950 | Loss: 2.3463 | Spent: 22.3 secs | LR: 0.000156\n","INFO:tensorflow:Step 16000 | Loss: 2.3425 | Spent: 21.1 secs | LR: 0.000165\n","INFO:tensorflow:Step 16050 | Loss: 2.3442 | Spent: 21.4 secs | LR: 0.000174\n","INFO:tensorflow:Step 16100 | Loss: 2.3482 | Spent: 21.7 secs | LR: 0.000183\n","INFO:tensorflow:Step 16150 | Loss: 2.3437 | Spent: 21.3 secs | LR: 0.000192\n","INFO:tensorflow:Step 16200 | Loss: 2.3424 | Spent: 22.9 secs | LR: 0.000201\n","INFO:tensorflow:Step 16250 | Loss: 2.3425 | Spent: 22.6 secs | LR: 0.000210\n","INFO:tensorflow:Step 16300 | Loss: 2.3450 | Spent: 22.3 secs | LR: 0.000219\n","INFO:tensorflow:Step 16350 | Loss: 2.3469 | Spent: 22.2 secs | LR: 0.000228\n","INFO:tensorflow:Step 16400 | Loss: 2.3524 | Spent: 21.8 secs | LR: 0.000237\n","INFO:tensorflow:Step 16450 | Loss: 2.3536 | Spent: 23.0 secs | LR: 0.000246\n","INFO:tensorflow:Step 16500 | Loss: 2.3401 | Spent: 22.2 secs | LR: 0.000255\n","INFO:tensorflow:Step 16550 | Loss: 2.3392 | Spent: 22.8 secs | LR: 0.000264\n","INFO:tensorflow:Step 16600 | Loss: 2.3427 | Spent: 21.9 secs | LR: 0.000273\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.733\n","INFO:tensorflow:Best EM: 0.738\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 16650 | Loss: 2.3453 | Spent: 159.9 secs | LR: 0.000282\n","INFO:tensorflow:Step 16700 | Loss: 2.3434 | Spent: 22.5 secs | LR: 0.000291\n","INFO:tensorflow:Step 16750 | Loss: 2.3421 | Spent: 22.3 secs | LR: 0.000299\n","INFO:tensorflow:Step 16800 | Loss: 2.3392 | Spent: 22.7 secs | LR: 0.000308\n","INFO:tensorflow:Step 16850 | Loss: 2.3532 | Spent: 22.4 secs | LR: 0.000317\n","INFO:tensorflow:Step 16900 | Loss: 2.3529 | Spent: 22.1 secs | LR: 0.000326\n","INFO:tensorflow:Step 16950 | Loss: 2.3489 | Spent: 22.3 secs | LR: 0.000335\n","INFO:tensorflow:Step 17000 | Loss: 2.3557 | Spent: 21.7 secs | LR: 0.000344\n","INFO:tensorflow:Step 17050 | Loss: 2.3409 | Spent: 23.6 secs | LR: 0.000353\n","INFO:tensorflow:Step 17100 | Loss: 2.3471 | Spent: 22.3 secs | LR: 0.000362\n","INFO:tensorflow:Step 17150 | Loss: 2.3491 | Spent: 21.6 secs | LR: 0.000371\n","INFO:tensorflow:Step 17200 | Loss: 2.3471 | Spent: 21.8 secs | LR: 0.000380\n","INFO:tensorflow:Step 17250 | Loss: 2.3474 | Spent: 22.4 secs | LR: 0.000389\n","INFO:tensorflow:Step 17300 | Loss: 2.3542 | Spent: 23.4 secs | LR: 0.000398\n","INFO:tensorflow:Step 17350 | Loss: 2.3399 | Spent: 22.1 secs | LR: 0.000407\n","INFO:tensorflow:Step 17400 | Loss: 2.3544 | Spent: 21.9 secs | LR: 0.000416\n","INFO:tensorflow:Step 17450 | Loss: 2.3479 | Spent: 23.5 secs | LR: 0.000425\n","INFO:tensorflow:Step 17500 | Loss: 2.3558 | Spent: 22.8 secs | LR: 0.000434\n","INFO:tensorflow:Step 17550 | Loss: 2.3483 | Spent: 22.1 secs | LR: 0.000443\n","INFO:tensorflow:Step 17600 | Loss: 2.3494 | Spent: 21.2 secs | LR: 0.000452\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.729\n","INFO:tensorflow:Best EM: 0.738\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 17650 | Loss: 2.3572 | Spent: 156.2 secs | LR: 0.000461\n","INFO:tensorflow:Step 17700 | Loss: 2.3419 | Spent: 22.8 secs | LR: 0.000470\n","INFO:tensorflow:Step 17750 | Loss: 2.3379 | Spent: 22.5 secs | LR: 0.000479\n","INFO:tensorflow:Step 17800 | Loss: 2.3517 | Spent: 21.8 secs | LR: 0.000488\n","INFO:tensorflow:Step 17850 | Loss: 2.3412 | Spent: 22.4 secs | LR: 0.000496\n","INFO:tensorflow:Step 17900 | Loss: 2.3535 | Spent: 21.5 secs | LR: 0.000505\n","INFO:tensorflow:Step 17950 | Loss: 2.3568 | Spent: 21.8 secs | LR: 0.000514\n","INFO:tensorflow:Step 18000 | Loss: 2.3485 | Spent: 21.6 secs | LR: 0.000523\n","INFO:tensorflow:Step 18050 | Loss: 2.3423 | Spent: 21.8 secs | LR: 0.000532\n","INFO:tensorflow:Step 18100 | Loss: 2.3596 | Spent: 22.0 secs | LR: 0.000541\n","INFO:tensorflow:Step 18150 | Loss: 2.3456 | Spent: 23.2 secs | LR: 0.000550\n","INFO:tensorflow:Step 18200 | Loss: 2.3495 | Spent: 22.6 secs | LR: 0.000559\n","INFO:tensorflow:Step 18250 | Loss: 2.3518 | Spent: 21.8 secs | LR: 0.000568\n","INFO:tensorflow:Step 18300 | Loss: 2.3460 | Spent: 21.8 secs | LR: 0.000577\n","INFO:tensorflow:Step 18350 | Loss: 2.3533 | Spent: 22.4 secs | LR: 0.000586\n","INFO:tensorflow:Step 18400 | Loss: 2.3497 | Spent: 21.7 secs | LR: 0.000595\n","INFO:tensorflow:Step 18450 | Loss: 2.3493 | Spent: 21.6 secs | LR: 0.000604\n","INFO:tensorflow:Step 18500 | Loss: 2.3615 | Spent: 21.7 secs | LR: 0.000613\n","INFO:tensorflow:Step 18550 | Loss: 2.3479 | Spent: 22.3 secs | LR: 0.000622\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.728\n","INFO:tensorflow:Best EM: 0.738\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 18600 | Loss: 2.3517 | Spent: 158.6 secs | LR: 0.000631\n","INFO:tensorflow:Step 18650 | Loss: 2.3439 | Spent: 21.3 secs | LR: 0.000640\n","INFO:tensorflow:Step 18700 | Loss: 2.3481 | Spent: 22.5 secs | LR: 0.000649\n","INFO:tensorflow:Step 18750 | Loss: 2.3564 | Spent: 22.9 secs | LR: 0.000658\n","INFO:tensorflow:Step 18800 | Loss: 2.3495 | Spent: 22.8 secs | LR: 0.000667\n","INFO:tensorflow:Step 18850 | Loss: 2.3544 | Spent: 22.4 secs | LR: 0.000676\n","INFO:tensorflow:Step 18900 | Loss: 2.3541 | Spent: 21.9 secs | LR: 0.000684\n","INFO:tensorflow:Step 18950 | Loss: 2.3481 | Spent: 23.1 secs | LR: 0.000693\n","INFO:tensorflow:Step 19000 | Loss: 2.3630 | Spent: 21.7 secs | LR: 0.000702\n","INFO:tensorflow:Step 19050 | Loss: 2.3536 | Spent: 22.5 secs | LR: 0.000711\n","INFO:tensorflow:Step 19100 | Loss: 2.3600 | Spent: 21.7 secs | LR: 0.000720\n","INFO:tensorflow:Step 19150 | Loss: 2.3518 | Spent: 21.8 secs | LR: 0.000729\n","INFO:tensorflow:Step 19200 | Loss: 2.3476 | Spent: 22.6 secs | LR: 0.000738\n","INFO:tensorflow:Step 19250 | Loss: 2.3622 | Spent: 23.3 secs | LR: 0.000747\n","INFO:tensorflow:Step 19300 | Loss: 2.3452 | Spent: 21.7 secs | LR: 0.000756\n","INFO:tensorflow:Step 19350 | Loss: 2.3476 | Spent: 22.6 secs | LR: 0.000765\n","INFO:tensorflow:Step 19400 | Loss: 2.3578 | Spent: 22.8 secs | LR: 0.000774\n","INFO:tensorflow:Step 19450 | Loss: 2.3528 | Spent: 22.3 secs | LR: 0.000783\n","INFO:tensorflow:Step 19500 | Loss: 2.3832 | Spent: 22.0 secs | LR: 0.000792\n","INFO:tensorflow:Step 19550 | Loss: 2.3656 | Spent: 21.8 secs | LR: 0.000799\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.726\n","INFO:tensorflow:Best EM: 0.738\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 19600 | Loss: 2.3545 | Spent: 165.4 secs | LR: 0.000790\n","INFO:tensorflow:Step 19650 | Loss: 2.3541 | Spent: 22.2 secs | LR: 0.000781\n","INFO:tensorflow:Step 19700 | Loss: 2.3456 | Spent: 22.3 secs | LR: 0.000772\n","INFO:tensorflow:Step 19750 | Loss: 2.3430 | Spent: 22.4 secs | LR: 0.000763\n","INFO:tensorflow:Step 19800 | Loss: 2.3525 | Spent: 21.9 secs | LR: 0.000754\n","INFO:tensorflow:Step 19850 | Loss: 2.3652 | Spent: 23.4 secs | LR: 0.000745\n","INFO:tensorflow:Step 19900 | Loss: 2.3428 | Spent: 21.7 secs | LR: 0.000736\n","INFO:tensorflow:Step 19950 | Loss: 2.3517 | Spent: 21.5 secs | LR: 0.000727\n","INFO:tensorflow:Step 20000 | Loss: 2.3524 | Spent: 22.0 secs | LR: 0.000719\n","INFO:tensorflow:Step 20050 | Loss: 2.3506 | Spent: 21.7 secs | LR: 0.000710\n","INFO:tensorflow:Step 20100 | Loss: 2.3525 | Spent: 22.0 secs | LR: 0.000701\n","INFO:tensorflow:Step 20150 | Loss: 2.3540 | Spent: 22.7 secs | LR: 0.000692\n","INFO:tensorflow:Step 20200 | Loss: 2.3496 | Spent: 21.4 secs | LR: 0.000683\n","INFO:tensorflow:Step 20250 | Loss: 2.3487 | Spent: 22.1 secs | LR: 0.000674\n","INFO:tensorflow:Step 20300 | Loss: 2.4059 | Spent: 22.1 secs | LR: 0.000665\n","INFO:tensorflow:Step 20350 | Loss: 2.3500 | Spent: 21.9 secs | LR: 0.000656\n","INFO:tensorflow:Step 20400 | Loss: 2.3514 | Spent: 22.4 secs | LR: 0.000647\n","INFO:tensorflow:Step 20450 | Loss: 2.3528 | Spent: 22.5 secs | LR: 0.000638\n","INFO:tensorflow:Step 20500 | Loss: 2.3516 | Spent: 21.7 secs | LR: 0.000629\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.725\n","INFO:tensorflow:Best EM: 0.738\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 20550 | Loss: 2.3398 | Spent: 159.5 secs | LR: 0.000620\n","INFO:tensorflow:Step 20600 | Loss: 2.3410 | Spent: 23.0 secs | LR: 0.000611\n","INFO:tensorflow:Step 20650 | Loss: 2.3458 | Spent: 22.5 secs | LR: 0.000602\n","INFO:tensorflow:Step 20700 | Loss: 2.3431 | Spent: 24.0 secs | LR: 0.000593\n","INFO:tensorflow:Step 20750 | Loss: 2.3490 | Spent: 23.0 secs | LR: 0.000584\n","INFO:tensorflow:Step 20800 | Loss: 2.3452 | Spent: 22.0 secs | LR: 0.000575\n","INFO:tensorflow:Step 20850 | Loss: 2.3443 | Spent: 21.6 secs | LR: 0.000566\n","INFO:tensorflow:Step 20900 | Loss: 2.3405 | Spent: 22.8 secs | LR: 0.000557\n","INFO:tensorflow:Step 20950 | Loss: 2.3539 | Spent: 22.2 secs | LR: 0.000548\n","INFO:tensorflow:Step 21000 | Loss: 2.3455 | Spent: 21.8 secs | LR: 0.000539\n","INFO:tensorflow:Step 21050 | Loss: 2.3505 | Spent: 22.6 secs | LR: 0.000530\n","INFO:tensorflow:Step 21100 | Loss: 2.3505 | Spent: 21.7 secs | LR: 0.000522\n","INFO:tensorflow:Step 21150 | Loss: 2.3414 | Spent: 22.2 secs | LR: 0.000513\n","INFO:tensorflow:Step 21200 | Loss: 2.3399 | Spent: 22.7 secs | LR: 0.000504\n","INFO:tensorflow:Step 21250 | Loss: 2.3528 | Spent: 22.1 secs | LR: 0.000495\n","INFO:tensorflow:Step 21300 | Loss: 2.3585 | Spent: 22.5 secs | LR: 0.000486\n","INFO:tensorflow:Step 21350 | Loss: 2.3402 | Spent: 22.6 secs | LR: 0.000477\n","INFO:tensorflow:Step 21400 | Loss: 2.3403 | Spent: 21.6 secs | LR: 0.000468\n","INFO:tensorflow:Step 21450 | Loss: 2.3368 | Spent: 21.3 secs | LR: 0.000459\n","INFO:tensorflow:Step 21500 | Loss: 2.3371 | Spent: 22.2 secs | LR: 0.000450\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.739\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 21550 | Loss: 2.3376 | Spent: 164.7 secs | LR: 0.000441\n","INFO:tensorflow:Step 21600 | Loss: 2.3420 | Spent: 22.6 secs | LR: 0.000432\n","INFO:tensorflow:Step 21650 | Loss: 2.3418 | Spent: 22.3 secs | LR: 0.000423\n","INFO:tensorflow:Step 21700 | Loss: 2.3402 | Spent: 22.5 secs | LR: 0.000414\n","INFO:tensorflow:Step 21750 | Loss: 2.3399 | Spent: 22.6 secs | LR: 0.000405\n","INFO:tensorflow:Step 21800 | Loss: 2.3361 | Spent: 22.8 secs | LR: 0.000396\n","INFO:tensorflow:Step 21850 | Loss: 2.3401 | Spent: 22.7 secs | LR: 0.000387\n","INFO:tensorflow:Step 21900 | Loss: 2.3418 | Spent: 22.4 secs | LR: 0.000378\n","INFO:tensorflow:Step 21950 | Loss: 2.3419 | Spent: 22.0 secs | LR: 0.000369\n","INFO:tensorflow:Step 22000 | Loss: 2.3419 | Spent: 23.0 secs | LR: 0.000360\n","INFO:tensorflow:Step 22050 | Loss: 2.3403 | Spent: 22.3 secs | LR: 0.000351\n","INFO:tensorflow:Step 22100 | Loss: 2.3387 | Spent: 22.9 secs | LR: 0.000342\n","INFO:tensorflow:Step 22150 | Loss: 2.3366 | Spent: 21.9 secs | LR: 0.000334\n","INFO:tensorflow:Step 22200 | Loss: 2.3387 | Spent: 21.8 secs | LR: 0.000325\n","INFO:tensorflow:Step 22250 | Loss: 2.3400 | Spent: 22.2 secs | LR: 0.000316\n","INFO:tensorflow:Step 22300 | Loss: 2.3344 | Spent: 22.2 secs | LR: 0.000307\n","INFO:tensorflow:Step 22350 | Loss: 2.3407 | Spent: 22.1 secs | LR: 0.000298\n","INFO:tensorflow:Step 22400 | Loss: 2.3405 | Spent: 22.2 secs | LR: 0.000289\n","INFO:tensorflow:Step 22450 | Loss: 2.3392 | Spent: 22.3 secs | LR: 0.000280\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.734\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 22500 | Loss: 2.3356 | Spent: 154.0 secs | LR: 0.000271\n","INFO:tensorflow:Step 22550 | Loss: 2.3418 | Spent: 22.5 secs | LR: 0.000262\n","INFO:tensorflow:Step 22600 | Loss: 2.3375 | Spent: 22.1 secs | LR: 0.000253\n","INFO:tensorflow:Step 22650 | Loss: 2.3432 | Spent: 22.2 secs | LR: 0.000244\n","INFO:tensorflow:Step 22700 | Loss: 2.3390 | Spent: 21.9 secs | LR: 0.000235\n","INFO:tensorflow:Step 22750 | Loss: 2.3358 | Spent: 22.4 secs | LR: 0.000226\n","INFO:tensorflow:Step 22800 | Loss: 2.3388 | Spent: 22.3 secs | LR: 0.000217\n","INFO:tensorflow:Step 22850 | Loss: 2.3314 | Spent: 22.3 secs | LR: 0.000208\n","INFO:tensorflow:Step 22900 | Loss: 2.3336 | Spent: 21.9 secs | LR: 0.000199\n","INFO:tensorflow:Step 22950 | Loss: 2.3353 | Spent: 22.4 secs | LR: 0.000190\n","INFO:tensorflow:Step 23000 | Loss: 2.3493 | Spent: 22.9 secs | LR: 0.000181\n","INFO:tensorflow:Step 23050 | Loss: 2.3366 | Spent: 21.4 secs | LR: 0.000172\n","INFO:tensorflow:Step 23100 | Loss: 2.3404 | Spent: 21.8 secs | LR: 0.000163\n","INFO:tensorflow:Step 23150 | Loss: 2.3370 | Spent: 21.8 secs | LR: 0.000154\n","INFO:tensorflow:Step 23200 | Loss: 2.3403 | Spent: 21.9 secs | LR: 0.000145\n","INFO:tensorflow:Step 23250 | Loss: 2.3389 | Spent: 21.8 secs | LR: 0.000137\n","INFO:tensorflow:Step 23300 | Loss: 2.3346 | Spent: 21.9 secs | LR: 0.000128\n","INFO:tensorflow:Step 23350 | Loss: 2.3345 | Spent: 22.7 secs | LR: 0.000119\n","INFO:tensorflow:Step 23400 | Loss: 2.3424 | Spent: 22.0 secs | LR: 0.000110\n","INFO:tensorflow:Step 23450 | Loss: 2.3360 | Spent: 22.1 secs | LR: 0.000101\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.738\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 23500 | Loss: 2.3378 | Spent: 150.7 secs | LR: 0.000108\n","INFO:tensorflow:Step 23550 | Loss: 2.3369 | Spent: 22.0 secs | LR: 0.000117\n","INFO:tensorflow:Step 23600 | Loss: 2.3362 | Spent: 22.0 secs | LR: 0.000126\n","INFO:tensorflow:Step 23650 | Loss: 2.3429 | Spent: 22.0 secs | LR: 0.000135\n","INFO:tensorflow:Step 23700 | Loss: 2.3388 | Spent: 21.5 secs | LR: 0.000144\n","INFO:tensorflow:Step 23750 | Loss: 2.3360 | Spent: 21.9 secs | LR: 0.000153\n","INFO:tensorflow:Step 23800 | Loss: 2.3344 | Spent: 22.9 secs | LR: 0.000162\n","INFO:tensorflow:Step 23850 | Loss: 2.3370 | Spent: 21.3 secs | LR: 0.000171\n","INFO:tensorflow:Step 23900 | Loss: 2.3337 | Spent: 23.1 secs | LR: 0.000180\n","INFO:tensorflow:Step 23950 | Loss: 2.3398 | Spent: 22.8 secs | LR: 0.000189\n","INFO:tensorflow:Step 24000 | Loss: 2.3348 | Spent: 21.8 secs | LR: 0.000198\n","INFO:tensorflow:Step 24050 | Loss: 2.3387 | Spent: 22.3 secs | LR: 0.000207\n","INFO:tensorflow:Step 24100 | Loss: 2.3370 | Spent: 22.0 secs | LR: 0.000216\n","INFO:tensorflow:Step 24150 | Loss: 2.3374 | Spent: 22.1 secs | LR: 0.000225\n","INFO:tensorflow:Step 24200 | Loss: 2.3334 | Spent: 22.5 secs | LR: 0.000234\n","INFO:tensorflow:Step 24250 | Loss: 2.3347 | Spent: 22.4 secs | LR: 0.000243\n","INFO:tensorflow:Step 24300 | Loss: 2.3430 | Spent: 21.6 secs | LR: 0.000251\n","INFO:tensorflow:Step 24350 | Loss: 2.3339 | Spent: 22.2 secs | LR: 0.000260\n","INFO:tensorflow:Step 24400 | Loss: 2.3337 | Spent: 22.7 secs | LR: 0.000269\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.732\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 24450 | Loss: 2.3329 | Spent: 153.1 secs | LR: 0.000278\n","INFO:tensorflow:Step 24500 | Loss: 2.3375 | Spent: 22.3 secs | LR: 0.000287\n","INFO:tensorflow:Step 24550 | Loss: 2.3377 | Spent: 22.6 secs | LR: 0.000296\n","INFO:tensorflow:Step 24600 | Loss: 2.3356 | Spent: 22.1 secs | LR: 0.000305\n","INFO:tensorflow:Step 24650 | Loss: 2.3364 | Spent: 23.5 secs | LR: 0.000314\n","INFO:tensorflow:Step 24700 | Loss: 2.3335 | Spent: 23.0 secs | LR: 0.000323\n","INFO:tensorflow:Step 24750 | Loss: 2.3449 | Spent: 22.0 secs | LR: 0.000332\n","INFO:tensorflow:Step 24800 | Loss: 2.3380 | Spent: 22.3 secs | LR: 0.000341\n","INFO:tensorflow:Step 24850 | Loss: 2.3355 | Spent: 21.9 secs | LR: 0.000350\n","INFO:tensorflow:Step 24900 | Loss: 2.3349 | Spent: 22.3 secs | LR: 0.000359\n","INFO:tensorflow:Step 24950 | Loss: 2.3361 | Spent: 21.7 secs | LR: 0.000368\n","INFO:tensorflow:Step 25000 | Loss: 2.3373 | Spent: 21.8 secs | LR: 0.000377\n","INFO:tensorflow:Step 25050 | Loss: 2.3392 | Spent: 23.1 secs | LR: 0.000386\n","INFO:tensorflow:Step 25100 | Loss: 2.3354 | Spent: 22.5 secs | LR: 0.000395\n","INFO:tensorflow:Step 25150 | Loss: 2.3417 | Spent: 22.5 secs | LR: 0.000404\n","INFO:tensorflow:Step 25200 | Loss: 2.3418 | Spent: 22.7 secs | LR: 0.000413\n","INFO:tensorflow:Step 25250 | Loss: 2.3349 | Spent: 21.8 secs | LR: 0.000422\n","INFO:tensorflow:Step 25300 | Loss: 2.3388 | Spent: 22.0 secs | LR: 0.000431\n","INFO:tensorflow:Step 25350 | Loss: 2.3390 | Spent: 22.7 secs | LR: 0.000440\n","INFO:tensorflow:Step 25400 | Loss: 2.3412 | Spent: 21.8 secs | LR: 0.000448\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.729\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 25450 | Loss: 2.3357 | Spent: 155.7 secs | LR: 0.000457\n","INFO:tensorflow:Step 25500 | Loss: 2.3372 | Spent: 22.5 secs | LR: 0.000466\n","INFO:tensorflow:Step 25550 | Loss: 2.3388 | Spent: 22.1 secs | LR: 0.000475\n","INFO:tensorflow:Step 25600 | Loss: 2.3363 | Spent: 21.7 secs | LR: 0.000484\n","INFO:tensorflow:Step 25650 | Loss: 2.3372 | Spent: 23.4 secs | LR: 0.000493\n","INFO:tensorflow:Step 25700 | Loss: 2.3380 | Spent: 21.9 secs | LR: 0.000502\n","INFO:tensorflow:Step 25750 | Loss: 2.3380 | Spent: 22.0 secs | LR: 0.000511\n","INFO:tensorflow:Step 25800 | Loss: 2.3342 | Spent: 23.0 secs | LR: 0.000520\n","INFO:tensorflow:Step 25850 | Loss: 2.3352 | Spent: 22.6 secs | LR: 0.000529\n","INFO:tensorflow:Step 25900 | Loss: 2.3395 | Spent: 22.3 secs | LR: 0.000538\n","INFO:tensorflow:Step 25950 | Loss: 2.3356 | Spent: 22.1 secs | LR: 0.000547\n","INFO:tensorflow:Step 26000 | Loss: 2.3408 | Spent: 23.3 secs | LR: 0.000556\n","INFO:tensorflow:Step 26050 | Loss: 2.3397 | Spent: 22.6 secs | LR: 0.000565\n","INFO:tensorflow:Step 26100 | Loss: 2.3405 | Spent: 22.1 secs | LR: 0.000574\n","INFO:tensorflow:Step 26150 | Loss: 2.3414 | Spent: 23.3 secs | LR: 0.000583\n","INFO:tensorflow:Step 26200 | Loss: 2.3367 | Spent: 22.3 secs | LR: 0.000592\n","INFO:tensorflow:Step 26250 | Loss: 2.3443 | Spent: 22.3 secs | LR: 0.000601\n","INFO:tensorflow:Step 26300 | Loss: 2.3425 | Spent: 22.1 secs | LR: 0.000610\n","INFO:tensorflow:Step 26350 | Loss: 2.3422 | Spent: 22.0 secs | LR: 0.000619\n","INFO:tensorflow:Step 26400 | Loss: 2.3432 | Spent: 22.6 secs | LR: 0.000628\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.731\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 26450 | Loss: 2.3442 | Spent: 157.7 secs | LR: 0.000637\n","INFO:tensorflow:Step 26500 | Loss: 2.3373 | Spent: 22.6 secs | LR: 0.000645\n","INFO:tensorflow:Step 26550 | Loss: 2.3385 | Spent: 23.5 secs | LR: 0.000654\n","INFO:tensorflow:Step 26600 | Loss: 2.3437 | Spent: 22.7 secs | LR: 0.000663\n","INFO:tensorflow:Step 26650 | Loss: 2.3439 | Spent: 22.8 secs | LR: 0.000672\n","INFO:tensorflow:Step 26700 | Loss: 2.3554 | Spent: 23.2 secs | LR: 0.000681\n","INFO:tensorflow:Step 26750 | Loss: 2.3445 | Spent: 23.3 secs | LR: 0.000690\n","INFO:tensorflow:Step 26800 | Loss: 2.3425 | Spent: 22.1 secs | LR: 0.000699\n","INFO:tensorflow:Step 26850 | Loss: 2.3430 | Spent: 22.4 secs | LR: 0.000708\n","INFO:tensorflow:Step 26900 | Loss: 2.3518 | Spent: 22.8 secs | LR: 0.000717\n","INFO:tensorflow:Step 26950 | Loss: 2.3383 | Spent: 22.4 secs | LR: 0.000726\n","INFO:tensorflow:Step 27000 | Loss: 2.3414 | Spent: 22.9 secs | LR: 0.000735\n","INFO:tensorflow:Step 27050 | Loss: 2.3470 | Spent: 23.3 secs | LR: 0.000744\n","INFO:tensorflow:Step 27100 | Loss: 2.3449 | Spent: 22.3 secs | LR: 0.000753\n","INFO:tensorflow:Step 27150 | Loss: 2.3452 | Spent: 22.2 secs | LR: 0.000762\n","INFO:tensorflow:Step 27200 | Loss: 2.3538 | Spent: 23.1 secs | LR: 0.000771\n","INFO:tensorflow:Step 27250 | Loss: 2.3454 | Spent: 23.2 secs | LR: 0.000780\n","INFO:tensorflow:Step 27300 | Loss: 2.3504 | Spent: 22.5 secs | LR: 0.000789\n","INFO:tensorflow:Step 27350 | Loss: 2.3515 | Spent: 22.5 secs | LR: 0.000798\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.725\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 27400 | Loss: 2.3426 | Spent: 159.3 secs | LR: 0.000793\n","INFO:tensorflow:Step 27450 | Loss: 2.3437 | Spent: 22.7 secs | LR: 0.000784\n","INFO:tensorflow:Step 27500 | Loss: 2.3439 | Spent: 22.1 secs | LR: 0.000775\n","INFO:tensorflow:Step 27550 | Loss: 2.3649 | Spent: 23.2 secs | LR: 0.000767\n","INFO:tensorflow:Step 27600 | Loss: 2.3438 | Spent: 23.7 secs | LR: 0.000758\n","INFO:tensorflow:Step 27650 | Loss: 2.3433 | Spent: 22.1 secs | LR: 0.000749\n","INFO:tensorflow:Step 27700 | Loss: 2.3414 | Spent: 22.5 secs | LR: 0.000740\n","INFO:tensorflow:Step 27750 | Loss: 2.3534 | Spent: 22.9 secs | LR: 0.000731\n","INFO:tensorflow:Step 27800 | Loss: 2.3407 | Spent: 22.8 secs | LR: 0.000722\n","INFO:tensorflow:Step 27850 | Loss: 2.3433 | Spent: 23.2 secs | LR: 0.000713\n","INFO:tensorflow:Step 27900 | Loss: 2.3434 | Spent: 22.5 secs | LR: 0.000704\n","INFO:tensorflow:Step 27950 | Loss: 2.3485 | Spent: 23.3 secs | LR: 0.000695\n","INFO:tensorflow:Step 28000 | Loss: 2.3453 | Spent: 22.8 secs | LR: 0.000686\n","INFO:tensorflow:Step 28050 | Loss: 2.3467 | Spent: 22.4 secs | LR: 0.000677\n","INFO:tensorflow:Step 28100 | Loss: 2.3427 | Spent: 22.7 secs | LR: 0.000668\n","INFO:tensorflow:Step 28150 | Loss: 2.3397 | Spent: 22.9 secs | LR: 0.000659\n","INFO:tensorflow:Step 28200 | Loss: 2.3461 | Spent: 23.2 secs | LR: 0.000650\n","INFO:tensorflow:Step 28250 | Loss: 2.3431 | Spent: 23.0 secs | LR: 0.000641\n","INFO:tensorflow:Step 28300 | Loss: 2.3413 | Spent: 22.9 secs | LR: 0.000632\n","INFO:tensorflow:Step 28350 | Loss: 2.3382 | Spent: 23.1 secs | LR: 0.000623\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.731\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 28400 | Loss: 2.3399 | Spent: 158.6 secs | LR: 0.000614\n","INFO:tensorflow:Step 28450 | Loss: 2.3443 | Spent: 22.9 secs | LR: 0.000605\n","INFO:tensorflow:Step 28500 | Loss: 2.3422 | Spent: 23.1 secs | LR: 0.000596\n","INFO:tensorflow:Step 28550 | Loss: 2.3349 | Spent: 23.0 secs | LR: 0.000587\n","INFO:tensorflow:Step 28600 | Loss: 2.3427 | Spent: 22.8 secs | LR: 0.000578\n","INFO:tensorflow:Step 28650 | Loss: 2.3489 | Spent: 23.4 secs | LR: 0.000570\n","INFO:tensorflow:Step 28700 | Loss: 2.3425 | Spent: 23.2 secs | LR: 0.000561\n","INFO:tensorflow:Step 28750 | Loss: 2.3391 | Spent: 22.6 secs | LR: 0.000552\n","INFO:tensorflow:Step 28800 | Loss: 2.3405 | Spent: 22.9 secs | LR: 0.000543\n","INFO:tensorflow:Step 28850 | Loss: 2.3340 | Spent: 22.5 secs | LR: 0.000534\n","INFO:tensorflow:Step 28900 | Loss: 2.3353 | Spent: 22.7 secs | LR: 0.000525\n","INFO:tensorflow:Step 28950 | Loss: 2.3369 | Spent: 23.5 secs | LR: 0.000516\n","INFO:tensorflow:Step 29000 | Loss: 2.3363 | Spent: 22.1 secs | LR: 0.000507\n","INFO:tensorflow:Step 29050 | Loss: 2.3354 | Spent: 23.6 secs | LR: 0.000498\n","INFO:tensorflow:Step 29100 | Loss: 2.3464 | Spent: 23.0 secs | LR: 0.000489\n","INFO:tensorflow:Step 29150 | Loss: 2.3397 | Spent: 23.4 secs | LR: 0.000480\n","INFO:tensorflow:Step 29200 | Loss: 2.3335 | Spent: 22.6 secs | LR: 0.000471\n","INFO:tensorflow:Step 29250 | Loss: 2.3402 | Spent: 23.1 secs | LR: 0.000462\n","INFO:tensorflow:Step 29300 | Loss: 2.3340 | Spent: 23.2 secs | LR: 0.000453\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.731\n","INFO:tensorflow:Best EM: 0.739\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 29350 | Loss: 2.3740 | Spent: 153.7 secs | LR: 0.000444\n","INFO:tensorflow:Step 29400 | Loss: 2.3379 | Spent: 22.7 secs | LR: 0.000435\n","INFO:tensorflow:Step 29450 | Loss: 2.3344 | Spent: 22.5 secs | LR: 0.000426\n","INFO:tensorflow:Step 29500 | Loss: 2.3476 | Spent: 22.6 secs | LR: 0.000417\n","INFO:tensorflow:Step 29550 | Loss: 2.3415 | Spent: 23.1 secs | LR: 0.000408\n","INFO:tensorflow:Step 29600 | Loss: 2.3376 | Spent: 22.4 secs | LR: 0.000399\n","INFO:tensorflow:Step 29650 | Loss: 2.3343 | Spent: 22.5 secs | LR: 0.000390\n","INFO:tensorflow:Step 29700 | Loss: 2.3353 | Spent: 22.8 secs | LR: 0.000382\n","INFO:tensorflow:Step 29750 | Loss: 2.3311 | Spent: 24.1 secs | LR: 0.000373\n","INFO:tensorflow:Step 29800 | Loss: 2.3309 | Spent: 22.5 secs | LR: 0.000364\n","INFO:tensorflow:Step 29850 | Loss: 2.3321 | Spent: 22.9 secs | LR: 0.000355\n","INFO:tensorflow:Step 29900 | Loss: 2.3363 | Spent: 23.5 secs | LR: 0.000346\n","INFO:tensorflow:Step 29950 | Loss: 2.3347 | Spent: 23.3 secs | LR: 0.000337\n","INFO:tensorflow:Step 30000 | Loss: 2.3359 | Spent: 23.1 secs | LR: 0.000328\n","INFO:tensorflow:Step 30050 | Loss: 2.3348 | Spent: 21.9 secs | LR: 0.000319\n","INFO:tensorflow:Step 30100 | Loss: 2.3334 | Spent: 23.0 secs | LR: 0.000310\n","INFO:tensorflow:Step 30150 | Loss: 2.3346 | Spent: 22.3 secs | LR: 0.000301\n","INFO:tensorflow:Step 30200 | Loss: 2.3333 | Spent: 22.0 secs | LR: 0.000292\n","INFO:tensorflow:Step 30250 | Loss: 2.3368 | Spent: 23.8 secs | LR: 0.000283\n","INFO:tensorflow:Step 30300 | Loss: 2.3286 | Spent: 22.5 secs | LR: 0.000274\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.741\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 30350 | Loss: 2.3349 | Spent: 156.1 secs | LR: 0.000265\n","INFO:tensorflow:Step 30400 | Loss: 2.3348 | Spent: 23.3 secs | LR: 0.000256\n","INFO:tensorflow:Step 30450 | Loss: 2.3333 | Spent: 22.3 secs | LR: 0.000247\n","INFO:tensorflow:Step 30500 | Loss: 2.3336 | Spent: 23.5 secs | LR: 0.000238\n","INFO:tensorflow:Step 30550 | Loss: 2.3358 | Spent: 23.3 secs | LR: 0.000229\n","INFO:tensorflow:Step 30600 | Loss: 2.3300 | Spent: 22.7 secs | LR: 0.000220\n","INFO:tensorflow:Step 30650 | Loss: 2.3308 | Spent: 22.1 secs | LR: 0.000211\n","INFO:tensorflow:Step 30700 | Loss: 2.3300 | Spent: 23.2 secs | LR: 0.000202\n","INFO:tensorflow:Step 30750 | Loss: 2.3321 | Spent: 22.9 secs | LR: 0.000193\n","INFO:tensorflow:Step 30800 | Loss: 2.3343 | Spent: 24.4 secs | LR: 0.000185\n","INFO:tensorflow:Step 30850 | Loss: 2.3316 | Spent: 23.8 secs | LR: 0.000176\n","INFO:tensorflow:Step 30900 | Loss: 2.3343 | Spent: 22.1 secs | LR: 0.000167\n","INFO:tensorflow:Step 30950 | Loss: 2.3294 | Spent: 23.4 secs | LR: 0.000158\n","INFO:tensorflow:Step 31000 | Loss: 2.3369 | Spent: 23.1 secs | LR: 0.000149\n","INFO:tensorflow:Step 31050 | Loss: 2.3313 | Spent: 23.3 secs | LR: 0.000140\n","INFO:tensorflow:Step 31100 | Loss: 2.3278 | Spent: 23.0 secs | LR: 0.000131\n","INFO:tensorflow:Step 31150 | Loss: 2.3332 | Spent: 23.1 secs | LR: 0.000122\n","INFO:tensorflow:Step 31200 | Loss: 2.3308 | Spent: 24.0 secs | LR: 0.000113\n","INFO:tensorflow:Step 31250 | Loss: 2.3311 | Spent: 23.7 secs | LR: 0.000104\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.740\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 31300 | Loss: 2.3282 | Spent: 154.0 secs | LR: 0.000105\n","INFO:tensorflow:Step 31350 | Loss: 2.3337 | Spent: 22.4 secs | LR: 0.000114\n","INFO:tensorflow:Step 31400 | Loss: 2.3301 | Spent: 23.5 secs | LR: 0.000123\n","INFO:tensorflow:Step 31450 | Loss: 2.3276 | Spent: 24.6 secs | LR: 0.000132\n","INFO:tensorflow:Step 31500 | Loss: 2.3359 | Spent: 23.6 secs | LR: 0.000141\n","INFO:tensorflow:Step 31550 | Loss: 2.3345 | Spent: 22.6 secs | LR: 0.000150\n","INFO:tensorflow:Step 31600 | Loss: 2.3303 | Spent: 22.5 secs | LR: 0.000159\n","INFO:tensorflow:Step 31650 | Loss: 2.3276 | Spent: 23.5 secs | LR: 0.000168\n","INFO:tensorflow:Step 31700 | Loss: 2.3278 | Spent: 23.2 secs | LR: 0.000177\n","INFO:tensorflow:Step 31750 | Loss: 2.3294 | Spent: 22.6 secs | LR: 0.000186\n","INFO:tensorflow:Step 31800 | Loss: 2.3313 | Spent: 23.0 secs | LR: 0.000195\n","INFO:tensorflow:Step 31850 | Loss: 2.3419 | Spent: 23.2 secs | LR: 0.000204\n","INFO:tensorflow:Step 31900 | Loss: 2.3319 | Spent: 23.0 secs | LR: 0.000212\n","INFO:tensorflow:Step 31950 | Loss: 2.3283 | Spent: 23.7 secs | LR: 0.000221\n","INFO:tensorflow:Step 32000 | Loss: 2.3305 | Spent: 23.6 secs | LR: 0.000230\n","INFO:tensorflow:Step 32050 | Loss: 2.3301 | Spent: 22.8 secs | LR: 0.000239\n","INFO:tensorflow:Step 32100 | Loss: 2.3325 | Spent: 22.2 secs | LR: 0.000248\n","INFO:tensorflow:Step 32150 | Loss: 2.3292 | Spent: 23.1 secs | LR: 0.000257\n","INFO:tensorflow:Step 32200 | Loss: 2.3334 | Spent: 23.3 secs | LR: 0.000266\n","INFO:tensorflow:Step 32250 | Loss: 2.3330 | Spent: 22.4 secs | LR: 0.000275\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] playing [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________________|_____________________________________________________ \n"," | | | | | sl:location \n"," | | | | | | \n"," | | | | | in:get_location \n"," | | | | | ________________|_______________ \n"," | | | | sl:category_even sl:search_radius sl:location_user\n"," | | | | t | | \n"," | | | | ______________|__________ | | \n","what times are playing the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.739\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 32300 | Loss: 2.3300 | Spent: 157.7 secs | LR: 0.000284\n","INFO:tensorflow:Step 32350 | Loss: 2.3300 | Spent: 23.0 secs | LR: 0.000293\n","INFO:tensorflow:Step 32400 | Loss: 2.3379 | Spent: 22.9 secs | LR: 0.000302\n","INFO:tensorflow:Step 32450 | Loss: 2.3320 | Spent: 23.5 secs | LR: 0.000311\n","INFO:tensorflow:Step 32500 | Loss: 2.3351 | Spent: 23.5 secs | LR: 0.000320\n","INFO:tensorflow:Step 32550 | Loss: 2.3292 | Spent: 23.0 secs | LR: 0.000329\n","INFO:tensorflow:Step 32600 | Loss: 2.3307 | Spent: 23.7 secs | LR: 0.000338\n","INFO:tensorflow:Step 32650 | Loss: 2.3324 | Spent: 24.2 secs | LR: 0.000347\n","INFO:tensorflow:Step 32700 | Loss: 2.3329 | Spent: 24.0 secs | LR: 0.000356\n","INFO:tensorflow:Step 32750 | Loss: 2.3301 | Spent: 22.9 secs | LR: 0.000365\n","INFO:tensorflow:Step 32800 | Loss: 2.3305 | Spent: 24.4 secs | LR: 0.000374\n","INFO:tensorflow:Step 32850 | Loss: 2.3292 | Spent: 23.2 secs | LR: 0.000383\n","INFO:tensorflow:Step 32900 | Loss: 2.3361 | Spent: 23.5 secs | LR: 0.000392\n","INFO:tensorflow:Step 32950 | Loss: 2.3401 | Spent: 23.1 secs | LR: 0.000400\n","INFO:tensorflow:Step 33000 | Loss: 2.3346 | Spent: 23.1 secs | LR: 0.000409\n","INFO:tensorflow:Step 33050 | Loss: 2.3351 | Spent: 23.4 secs | LR: 0.000418\n","INFO:tensorflow:Step 33100 | Loss: 2.3338 | Spent: 23.1 secs | LR: 0.000427\n","INFO:tensorflow:Step 33150 | Loss: 2.3337 | Spent: 22.7 secs | LR: 0.000436\n","INFO:tensorflow:Step 33200 | Loss: 2.3308 | Spent: 23.2 secs | LR: 0.000445\n","INFO:tensorflow:Step 33250 | Loss: 2.3481 | Spent: 23.2 secs | LR: 0.000454\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.737\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 33300 | Loss: 2.3318 | Spent: 158.2 secs | LR: 0.000463\n","INFO:tensorflow:Step 33350 | Loss: 2.3334 | Spent: 23.0 secs | LR: 0.000472\n","INFO:tensorflow:Step 33400 | Loss: 2.3347 | Spent: 25.0 secs | LR: 0.000481\n","INFO:tensorflow:Step 33450 | Loss: 2.3332 | Spent: 23.0 secs | LR: 0.000490\n","INFO:tensorflow:Step 33500 | Loss: 2.3403 | Spent: 23.0 secs | LR: 0.000499\n","INFO:tensorflow:Step 33550 | Loss: 2.3322 | Spent: 22.9 secs | LR: 0.000508\n","INFO:tensorflow:Step 33600 | Loss: 2.3374 | Spent: 22.8 secs | LR: 0.000517\n","INFO:tensorflow:Step 33650 | Loss: 2.3491 | Spent: 23.4 secs | LR: 0.000526\n","INFO:tensorflow:Step 33700 | Loss: 2.3412 | Spent: 23.0 secs | LR: 0.000535\n","INFO:tensorflow:Step 33750 | Loss: 2.3361 | Spent: 22.7 secs | LR: 0.000544\n","INFO:tensorflow:Step 33800 | Loss: 2.3330 | Spent: 23.9 secs | LR: 0.000553\n","INFO:tensorflow:Step 33850 | Loss: 2.3371 | Spent: 21.7 secs | LR: 0.000562\n","INFO:tensorflow:Step 33900 | Loss: 2.3820 | Spent: 23.0 secs | LR: 0.000571\n","INFO:tensorflow:Step 33950 | Loss: 2.3458 | Spent: 23.3 secs | LR: 0.000580\n","INFO:tensorflow:Step 34000 | Loss: 2.3398 | Spent: 23.0 secs | LR: 0.000589\n","INFO:tensorflow:Step 34050 | Loss: 2.3409 | Spent: 22.9 secs | LR: 0.000597\n","INFO:tensorflow:Step 34100 | Loss: 2.3424 | Spent: 22.6 secs | LR: 0.000606\n","INFO:tensorflow:Step 34150 | Loss: 2.3365 | Spent: 22.9 secs | LR: 0.000615\n","INFO:tensorflow:Step 34200 | Loss: 2.3378 | Spent: 22.6 secs | LR: 0.000624\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," __________________|_____________________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | ______________|__________ | | \n","what times are the nutcracker show near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.732\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 34250 | Loss: 2.3366 | Spent: 151.8 secs | LR: 0.000633\n","INFO:tensorflow:Step 34300 | Loss: 2.3325 | Spent: 22.5 secs | LR: 0.000642\n","INFO:tensorflow:Step 34350 | Loss: 2.3344 | Spent: 24.1 secs | LR: 0.000651\n","INFO:tensorflow:Step 34400 | Loss: 2.3351 | Spent: 23.4 secs | LR: 0.000660\n","INFO:tensorflow:Step 34450 | Loss: 2.3384 | Spent: 23.2 secs | LR: 0.000669\n","INFO:tensorflow:Step 34500 | Loss: 2.3378 | Spent: 23.0 secs | LR: 0.000678\n","INFO:tensorflow:Step 34550 | Loss: 2.3416 | Spent: 23.2 secs | LR: 0.000687\n","INFO:tensorflow:Step 34600 | Loss: 2.3338 | Spent: 22.6 secs | LR: 0.000696\n","INFO:tensorflow:Step 34650 | Loss: 2.3415 | Spent: 22.9 secs | LR: 0.000705\n","INFO:tensorflow:Step 34700 | Loss: 2.3371 | Spent: 23.7 secs | LR: 0.000714\n","INFO:tensorflow:Step 34750 | Loss: 2.3423 | Spent: 22.8 secs | LR: 0.000723\n","INFO:tensorflow:Step 34800 | Loss: 2.3426 | Spent: 23.4 secs | LR: 0.000732\n","INFO:tensorflow:Step 34850 | Loss: 2.3388 | Spent: 22.9 secs | LR: 0.000741\n","INFO:tensorflow:Step 34900 | Loss: 2.3311 | Spent: 22.2 secs | LR: 0.000750\n","INFO:tensorflow:Step 34950 | Loss: 2.3349 | Spent: 23.0 secs | LR: 0.000759\n","INFO:tensorflow:Step 35000 | Loss: 2.3482 | Spent: 22.9 secs | LR: 0.000768\n","INFO:tensorflow:Step 35050 | Loss: 2.3513 | Spent: 23.1 secs | LR: 0.000777\n","INFO:tensorflow:Step 35100 | Loss: 2.3489 | Spent: 23.2 secs | LR: 0.000785\n","INFO:tensorflow:Step 35150 | Loss: 2.3371 | Spent: 23.0 secs | LR: 0.000794\n","INFO:tensorflow:Step 35200 | Loss: 2.3359 | Spent: 22.2 secs | LR: 0.000797\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.719\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 35250 | Loss: 2.3403 | Spent: 158.7 secs | LR: 0.000788\n","INFO:tensorflow:Step 35300 | Loss: 2.3427 | Spent: 22.6 secs | LR: 0.000779\n","INFO:tensorflow:Step 35350 | Loss: 2.3353 | Spent: 23.1 secs | LR: 0.000770\n","INFO:tensorflow:Step 35400 | Loss: 2.3339 | Spent: 23.8 secs | LR: 0.000761\n","INFO:tensorflow:Step 35450 | Loss: 2.3375 | Spent: 23.1 secs | LR: 0.000752\n","INFO:tensorflow:Step 35500 | Loss: 2.3421 | Spent: 22.5 secs | LR: 0.000743\n","INFO:tensorflow:Step 35550 | Loss: 2.3435 | Spent: 23.1 secs | LR: 0.000734\n","INFO:tensorflow:Step 35600 | Loss: 2.3417 | Spent: 22.6 secs | LR: 0.000725\n","INFO:tensorflow:Step 35650 | Loss: 2.3349 | Spent: 23.1 secs | LR: 0.000716\n","INFO:tensorflow:Step 35700 | Loss: 2.3389 | Spent: 23.0 secs | LR: 0.000707\n","INFO:tensorflow:Step 35750 | Loss: 2.3530 | Spent: 22.4 secs | LR: 0.000698\n","INFO:tensorflow:Step 35800 | Loss: 2.3400 | Spent: 22.7 secs | LR: 0.000689\n","INFO:tensorflow:Step 35850 | Loss: 2.3420 | Spent: 23.0 secs | LR: 0.000680\n","INFO:tensorflow:Step 35900 | Loss: 2.3419 | Spent: 22.5 secs | LR: 0.000671\n","INFO:tensorflow:Step 35950 | Loss: 2.3410 | Spent: 22.9 secs | LR: 0.000662\n","INFO:tensorflow:Step 36000 | Loss: 2.3415 | Spent: 22.2 secs | LR: 0.000653\n","INFO:tensorflow:Step 36050 | Loss: 2.3354 | Spent: 22.6 secs | LR: 0.000644\n","INFO:tensorflow:Step 36100 | Loss: 2.3426 | Spent: 23.3 secs | LR: 0.000635\n","INFO:tensorflow:Step 36150 | Loss: 2.3425 | Spent: 24.0 secs | LR: 0.000626\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.730\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 36200 | Loss: 2.3410 | Spent: 155.8 secs | LR: 0.000618\n","INFO:tensorflow:Step 36250 | Loss: 2.3355 | Spent: 22.9 secs | LR: 0.000609\n","INFO:tensorflow:Step 36300 | Loss: 2.3355 | Spent: 23.3 secs | LR: 0.000600\n","INFO:tensorflow:Step 36350 | Loss: 2.3323 | Spent: 23.0 secs | LR: 0.000591\n","INFO:tensorflow:Step 36400 | Loss: 2.3339 | Spent: 23.6 secs | LR: 0.000582\n","INFO:tensorflow:Step 36450 | Loss: 2.3373 | Spent: 24.0 secs | LR: 0.000573\n","INFO:tensorflow:Step 36500 | Loss: 2.3351 | Spent: 22.6 secs | LR: 0.000564\n","INFO:tensorflow:Step 36550 | Loss: 2.3382 | Spent: 23.6 secs | LR: 0.000555\n","INFO:tensorflow:Step 36600 | Loss: 2.3375 | Spent: 23.1 secs | LR: 0.000546\n","INFO:tensorflow:Step 36650 | Loss: 2.3329 | Spent: 22.6 secs | LR: 0.000537\n","INFO:tensorflow:Step 36700 | Loss: 2.3345 | Spent: 23.1 secs | LR: 0.000528\n","INFO:tensorflow:Step 36750 | Loss: 2.3377 | Spent: 23.1 secs | LR: 0.000519\n","INFO:tensorflow:Step 36800 | Loss: 2.3348 | Spent: 23.2 secs | LR: 0.000510\n","INFO:tensorflow:Step 36850 | Loss: 2.3351 | Spent: 23.1 secs | LR: 0.000501\n","INFO:tensorflow:Step 36900 | Loss: 2.3373 | Spent: 23.4 secs | LR: 0.000492\n","INFO:tensorflow:Step 36950 | Loss: 2.3374 | Spent: 22.1 secs | LR: 0.000483\n","INFO:tensorflow:Step 37000 | Loss: 2.3331 | Spent: 23.3 secs | LR: 0.000474\n","INFO:tensorflow:Step 37050 | Loss: 2.3340 | Spent: 22.6 secs | LR: 0.000465\n","INFO:tensorflow:Step 37100 | Loss: 2.3386 | Spent: 23.2 secs | LR: 0.000456\n","INFO:tensorflow:Step 37150 | Loss: 2.3305 | Spent: 24.1 secs | LR: 0.000447\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.733\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 37200 | Loss: 2.3301 | Spent: 152.7 secs | LR: 0.000438\n","INFO:tensorflow:Step 37250 | Loss: 2.3364 | Spent: 24.0 secs | LR: 0.000429\n","INFO:tensorflow:Step 37300 | Loss: 2.3302 | Spent: 23.5 secs | LR: 0.000421\n","INFO:tensorflow:Step 37350 | Loss: 2.3286 | Spent: 23.5 secs | LR: 0.000412\n","INFO:tensorflow:Step 37400 | Loss: 2.3316 | Spent: 23.5 secs | LR: 0.000403\n","INFO:tensorflow:Step 37450 | Loss: 2.3333 | Spent: 23.0 secs | LR: 0.000394\n","INFO:tensorflow:Step 37500 | Loss: 2.3306 | Spent: 23.0 secs | LR: 0.000385\n","INFO:tensorflow:Step 37550 | Loss: 2.3317 | Spent: 24.4 secs | LR: 0.000376\n","INFO:tensorflow:Step 37600 | Loss: 2.3386 | Spent: 22.6 secs | LR: 0.000367\n","INFO:tensorflow:Step 37650 | Loss: 2.3306 | Spent: 23.3 secs | LR: 0.000358\n","INFO:tensorflow:Step 37700 | Loss: 2.3301 | Spent: 22.8 secs | LR: 0.000349\n","INFO:tensorflow:Step 37750 | Loss: 2.3325 | Spent: 22.7 secs | LR: 0.000340\n","INFO:tensorflow:Step 37800 | Loss: 2.3349 | Spent: 23.4 secs | LR: 0.000331\n","INFO:tensorflow:Step 37850 | Loss: 2.3361 | Spent: 22.7 secs | LR: 0.000322\n","INFO:tensorflow:Step 37900 | Loss: 2.3285 | Spent: 24.3 secs | LR: 0.000313\n","INFO:tensorflow:Step 37950 | Loss: 2.3278 | Spent: 23.6 secs | LR: 0.000304\n","INFO:tensorflow:Step 38000 | Loss: 2.3335 | Spent: 23.5 secs | LR: 0.000295\n","INFO:tensorflow:Step 38050 | Loss: 2.3281 | Spent: 23.2 secs | LR: 0.000286\n","INFO:tensorflow:Step 38100 | Loss: 2.3307 | Spent: 22.4 secs | LR: 0.000277\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.738\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 38150 | Loss: 2.3287 | Spent: 157.5 secs | LR: 0.000268\n","INFO:tensorflow:Step 38200 | Loss: 2.3280 | Spent: 23.3 secs | LR: 0.000259\n","INFO:tensorflow:Step 38250 | Loss: 2.3263 | Spent: 22.7 secs | LR: 0.000250\n","INFO:tensorflow:Step 38300 | Loss: 2.3310 | Spent: 23.4 secs | LR: 0.000241\n","INFO:tensorflow:Step 38350 | Loss: 2.3324 | Spent: 23.0 secs | LR: 0.000233\n","INFO:tensorflow:Step 38400 | Loss: 2.3282 | Spent: 22.6 secs | LR: 0.000224\n","INFO:tensorflow:Step 38450 | Loss: 2.3307 | Spent: 23.0 secs | LR: 0.000215\n","INFO:tensorflow:Step 38500 | Loss: 2.3300 | Spent: 23.4 secs | LR: 0.000206\n","INFO:tensorflow:Step 38550 | Loss: 2.3284 | Spent: 23.8 secs | LR: 0.000197\n","INFO:tensorflow:Step 38600 | Loss: 2.3304 | Spent: 23.7 secs | LR: 0.000188\n","INFO:tensorflow:Step 38650 | Loss: 2.3269 | Spent: 23.6 secs | LR: 0.000179\n","INFO:tensorflow:Step 38700 | Loss: 2.3315 | Spent: 23.4 secs | LR: 0.000170\n","INFO:tensorflow:Step 38750 | Loss: 2.3367 | Spent: 23.8 secs | LR: 0.000161\n","INFO:tensorflow:Step 38800 | Loss: 2.3274 | Spent: 24.2 secs | LR: 0.000152\n","INFO:tensorflow:Step 38850 | Loss: 2.3285 | Spent: 23.5 secs | LR: 0.000143\n","INFO:tensorflow:Step 38900 | Loss: 2.3301 | Spent: 23.4 secs | LR: 0.000134\n","INFO:tensorflow:Step 38950 | Loss: 2.3281 | Spent: 23.4 secs | LR: 0.000125\n","INFO:tensorflow:Step 39000 | Loss: 2.3265 | Spent: 23.4 secs | LR: 0.000116\n","INFO:tensorflow:Step 39050 | Loss: 2.3299 | Spent: 22.9 secs | LR: 0.000107\n","INFO:tensorflow:Step 39100 | Loss: 2.3322 | Spent: 23.1 secs | LR: 0.000102\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.739\n","INFO:tensorflow:Best EM: 0.741\n","Reading ../data/train.tsv\n","INFO:tensorflow:Step 39150 | Loss: 2.3285 | Spent: 157.1 secs | LR: 0.000111\n","INFO:tensorflow:Step 39200 | Loss: 2.3291 | Spent: 23.3 secs | LR: 0.000120\n","INFO:tensorflow:Step 39250 | Loss: 2.3259 | Spent: 23.7 secs | LR: 0.000129\n","INFO:tensorflow:Step 39300 | Loss: 2.3323 | Spent: 24.2 secs | LR: 0.000138\n","INFO:tensorflow:Step 39350 | Loss: 2.3267 | Spent: 25.4 secs | LR: 0.000147\n","INFO:tensorflow:Step 39400 | Loss: 2.3278 | Spent: 23.0 secs | LR: 0.000156\n","INFO:tensorflow:Step 39450 | Loss: 2.3282 | Spent: 22.8 secs | LR: 0.000164\n","INFO:tensorflow:Step 39500 | Loss: 2.3286 | Spent: 23.1 secs | LR: 0.000173\n","INFO:tensorflow:Step 39550 | Loss: 2.3268 | Spent: 23.8 secs | LR: 0.000182\n","INFO:tensorflow:Step 39600 | Loss: 2.3280 | Spent: 23.2 secs | LR: 0.000191\n","INFO:tensorflow:Step 39650 | Loss: 2.3312 | Spent: 23.1 secs | LR: 0.000200\n","INFO:tensorflow:Step 39700 | Loss: 2.3282 | Spent: 22.6 secs | LR: 0.000209\n","INFO:tensorflow:Step 39750 | Loss: 2.3274 | Spent: 24.0 secs | LR: 0.000218\n","INFO:tensorflow:Step 39800 | Loss: 2.3281 | Spent: 23.8 secs | LR: 0.000227\n","INFO:tensorflow:Step 39850 | Loss: 2.3264 | Spent: 23.3 secs | LR: 0.000236\n","INFO:tensorflow:Step 39900 | Loss: 2.3290 | Spent: 23.0 secs | LR: 0.000245\n","INFO:tensorflow:Step 39950 | Loss: 2.3304 | Spent: 23.7 secs | LR: 0.000254\n","INFO:tensorflow:Step 40000 | Loss: 2.3301 | Spent: 22.7 secs | LR: 0.000263\n","INFO:tensorflow:Step 40050 | Loss: 2.3290 | Spent: 23.3 secs | LR: 0.000272\n","------------\n","minimal test\n","utterance: what times are the nutcracker show playing near me\n","parsed: [ in:get_event what times are [ sl:category_event the nutcracker show playing ] [ sl:location [ in:get_location [ sl:search_radius near ] [ sl:location_user me ] ] ] ]\n","\n"," in:get_event \n"," ________________________|______________________________________________ \n"," | | | | sl:location \n"," | | | | | \n"," | | | | in:get_location \n"," | | | | ________________|_______________ \n"," | | | sl:category_even sl:search_radius sl:location_user\n"," | | | t | | \n"," | | | _________|_________________ | | \n","what times are the nutcracker show playing near me \n","\n","------------\n","Reading ../data/test.tsv\n","INFO:tensorflow:Evaluation: Testing EM: 0.737\n","INFO:tensorflow:Best EM: 0.741\n","10 times not improve the best result, therefore stop training\n"],"name":"stdout"}]}]}