{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 5.0 MNIST\n", "We are getting ~80% correct predictions. This might be because the ten digit signatures are insufficient to capture the different variations of each digit. We might try to define the different flavors of each digit using a round of k-means clustering and labeling in the training dataset, then a \"narrow\" signature definition, then prediction on the test dataset using this narrow set of signatures. Visualizing these digits as images would also be useful. We can also carry over the similarity score as a value-based category or use it to generate a ROC curve." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "from clustergrammer2 import net\n", "df = {}" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "import clustergrammer_groupby as cby\n", "from copy import deepcopy\n", "import random\n", "random.seed(99)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(784, 70000)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.load_file('../data/big_data/MNIST_row_labels.txt')\n", "df['mnist'] = net.export_df()\n", "df['mnist'].shape" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "('Zero-0', 'Digit: Zero')\n" ] } ], "source": [ "cols = df['mnist'].columns.tolist()\n", "new_cols = [(x, 'Digit: ' + x.split('-')[0]) for x in cols]\n", "df['mnist-cat'] = deepcopy(df['mnist'])\n", "df['mnist-cat'].columns = new_cols\n", "print(new_cols[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Make Train and Predict" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, 
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(784, 35000) (784, 35000)\n" ] } ], "source": [ "cols = df['mnist-cat'].columns.tolist()\n", "random.shuffle(cols)\n", "df['mnist-train'] = df['mnist-cat'][cols[:35000]]\n", "df['mnist-pred'] = df['mnist-cat'][cols[35000:]]\n", "print(df['mnist-train'].shape, df['mnist-pred'].shape)\n", "\n", "net.load_df(df['mnist-train'])\n", "net.normalize(axis='row', norm_type='zscore')\n", "df['mnist-train-z'] = net.export_df()\n", "\n", "net.load_df(df['mnist-pred'])\n", "net.normalize(axis='row', norm_type='zscore')\n", "df['mnist-pred-z'] = net.export_df()" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "def set_cat_colors(cat_color, axis, cat_index, cat_title=False):\n", " for inst_ct in cat_color:\n", " if cat_title != False:\n", " cat_name = cat_title + ': ' + inst_ct\n", " else:\n", " cat_name = inst_ct\n", " \n", " inst_color = cat_color[inst_ct]\n", " net.set_cat_color(axis=axis, cat_index=cat_index, cat_name=cat_name, inst_color=inst_color)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Make Signatures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Narrow Digit Signatures" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " (276, 10)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in greater\n", " return (self.a < x) & (x < self.b)\n", "/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in less\n", " return (self.a < x) & (x < self.b)\n", "/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1821: 
RuntimeWarning: invalid value encountered in less_equal\n", " cond2 = cond0 & (x <= self.a)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "-z (276, 10)\n" ] } ], "source": [ "pval_cutoff = 0.00001\n", "num_top_dims = 50\n", "for inst_norm in ['', '-z']:\n", " df['sig' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(\n", " df['mnist-train' + inst_norm],\n", " 'Digit', num_top_dims=num_top_dims)\n", " print(inst_norm, df['sig' + inst_norm].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Narrow Signatures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Make Broad DataFrames to Construct Narrow Signatures\n", "At the coarse grained level we appear to be able to distinguish\n", "* Three-Five-Eight\n", "* One\n", "* Zero-Two-Six\n", "* Four-Seven-Nine" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "coarse_digits = {}\n", "coarse_digits['One'] = ['One']\n", "coarse_digits['Three-Five-Eight'] = ['Three', 'Five', 'Eight']\n", "coarse_digits['Four-Seven-Nine'] = ['Four', 'Seven', 'Nine']\n", "coarse_digits['Zero-Two-Six'] = ['Zero', 'Two', 'Six']" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(784, 35000) (784, 35000)\n", "One (784, 3926)\n", "Three-Five-Eight (784, 10101)\n", "Four-Seven-Nine (784, 10575)\n", "Zero-Two-Six (784, 10398)\n" ] } ], "source": [ "cols = df['mnist-cat'].columns.tolist()\n", "\n", "random.shuffle(cols)\n", "\n", "df['mnist-train'] = df['mnist-cat'][cols[:35000]]\n", "df['mnist-pred'] = df['mnist-cat'][cols[35000:]]\n", "print(df['mnist-train'].shape, df['mnist-pred'].shape)\n", "\n", "for inst_group in coarse_digits:\n", " cols = df['mnist-train']\n", " keep_cols = [x for x in cols if x[1].split(': ')[1] in coarse_digits[inst_group]]\n", " df[inst_group] = df['mnist-train'][keep_cols]\n", " print(inst_group, 
df[inst_group].shape)\n", " \n", " net.load_df(df[inst_group])\n", " net.normalize(axis='row', norm_type='zscore')\n", " df[inst_group + '-z'] = net.export_df()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Make Broad Subset Signatures" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Three-Five-Eight\n", "Three-Five-Eight (103, 3)\n", "Three-Five-Eight\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in greater\n", " return (self.a < x) & (x < self.b)\n", "/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:879: RuntimeWarning: invalid value encountered in less\n", " return (self.a < x) & (x < self.b)\n", "/Users/nickfernandez/anaconda3/envs/py36lab/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1821: RuntimeWarning: invalid value encountered in less_equal\n", " cond2 = cond0 & (x <= self.a)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Three-Five-Eight-z (103, 3)\n", "Four-Seven-Nine\n", "Four-Seven-Nine (99, 3)\n", "Four-Seven-Nine\n", "Four-Seven-Nine-z (99, 3)\n", "Zero-Two-Six\n", "Zero-Two-Six (130, 3)\n", "Zero-Two-Six\n", "Zero-Two-Six-z (130, 3)\n" ] } ], "source": [ "# Generate Signatures\n", "pval_cutoff = 1e-10\n", "num_top_dims=50\n", "\n", "for inst_group in coarse_digits:\n", "\n", " for inst_norm in ['', '-z']:\n", " \n", " if '-' in inst_group:\n", " print(inst_group)\n", " df['sig-' + inst_group + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(\n", " df[inst_group + inst_norm], 'Digit', pval_cutoff=pval_cutoff, num_top_dims=num_top_dims)\n", "\n", " print(inst_group + inst_norm, df['sig-' + inst_group + inst_norm].shape) " ] }, { "cell_type": "code", 
"execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(99, 3)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['sig-Four-Seven-Nine'].shape" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'Eight': '#393b79',\n", " 'Five': '#ff7f0e',\n", " 'Four': 'blue',\n", " 'Nine': 'grey',\n", " 'One': 'black',\n", " 'Seven': 'red',\n", " 'Six': '#FFDB58',\n", " 'Three': '#e377c2',\n", " 'Two': '#2ca02c',\n", " 'Zero': 'yellow'}" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.load_df(df['sig'])\n", "net.cluster()\n", "tmp_cat_color = deepcopy(net.viz['cat_colors']['col']['cat-0'])\n", "cat_color = {}\n", "for inst_key in tmp_cat_color:\n", " cat_color[inst_key.split(': ')[1]] = tmp_cat_color[inst_key]\n", " \n", "cat_color['Zero'] = 'yellow'\n", "cat_color['Four'] = 'blue'\n", "cat_color['Seven'] = 'red'\n", "cat_color['Nine'] = 'grey'\n", "cat_color['One'] = 'black'\n", "\n", "set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Digit')\n", "cat_color " ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "df2cc8ad6383447fa9af3a8acc97849f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "ExampleWidget(network='{\"row_nodes\": [{\"name\": \"pos_10-10\", \"ini\": 276, \"clust\": 234, \"rank\": 218, \"rankvar\": …" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "net.load_df(df['sig'])\n", "net.widget()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict Digit Type Using Signatures" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predict using Narrow Signatures" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predict: 0.816485714286\n" ] 
}, { "data": { "text/plain": [ "One 0.933435\n", "Zero 0.901810\n", "Six 0.884370\n", "Seven 0.833700\n", "Three 0.812517\n", "Four 0.803198\n", "Two 0.789207\n", "Nine 0.780338\n", "Eight 0.754257\n", "Five 0.643282\n", "dtype: float64" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Predict\n", "##################\n", "df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred'], df['sig'], truth_level=1,\n", " predict_level='Pred Digit', unknown_thresh=0.0)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('Predict: ', fraction_correct)\n", "ser_correct.sort_values(ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Artificially Broadening the Narrow Digits Improves Performance\n", "Score the narrow predictions at the broad-digit level by merging both the true and the predicted labels into the coarse groups." ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "merge_358 = ['Three', 'Five', 'Eight']\n", "merge_479 = ['Four', 'Seven', 'Nine']\n", "merge_026 = ['Zero', 'Two', 'Six']" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predict: 0.900114285714\n", "\n", "broad cell type: 0.900114285714 \n", "\n", "One 0.933435\n", "Four-Seven-Nine 0.910286\n", "Zero-Two-Six 0.906856\n", "Three-Five-Eight 0.869817\n", "dtype: float64\n" ] } ], "source": [ "y_broad = {}\n", "\n", "inst_true = []\n", "for inst_cat in y_info['true']:\n", " if inst_cat in merge_358:\n", " inst_cat = 'Three-Five-Eight'\n", "\n", " if inst_cat in merge_479:\n", " inst_cat = 'Four-Seven-Nine' \n", "\n", " if inst_cat in merge_026:\n", " inst_cat = 'Zero-Two-Six'\n", " \n", " inst_true.append(inst_cat)\n", "\n", "inst_pred = []\n", "for inst_cat in y_info['pred']:\n", " if inst_cat in merge_358:\n", " inst_cat = 'Three-Five-Eight'\n", "\n", " if inst_cat in merge_479:\n", " 
inst_cat = 'Four-Seven-Nine' \n", "\n", " if inst_cat in merge_026:\n", " inst_cat = 'Zero-Two-Six'\n", " \n", " inst_pred.append(inst_cat)\n", "\n", "y_broad['true'] = inst_true\n", "y_broad['pred'] = inst_pred\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_broad)\n", "print('Predict: ', fraction_correct)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_broad)\n", "print('\\nbroad cell type: ', fraction_correct, '\\n')\n", "print(ser_correct.sort_values(ascending=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict using Broad then Narrow" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "# Predict\n", "##################\n", "df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred'], df['sig'], truth_level=1,\n", " predict_level='Broad Digit', unknown_thresh=0.0)\n", "\n", "ini_broad_cols = df_pred_cat.columns.tolist()\n", "\n", "broad_cols = []\n", "for inst_col in ini_broad_cols:\n", " \n", " inst_cat = inst_col[2].split(': ')[1]\n", "\n", " broad_predict = inst_col[2]\n", " for inst_group in coarse_digits:\n", " if inst_cat in coarse_digits[inst_group]:\n", " broad_predict = 'Broad Digit: ' + inst_group\n", " \n", " broad_col = (inst_col[0], inst_col[1], broad_predict)\n", " broad_cols.append(broad_col)\n", " \n", "df['pred-broad'] = deepcopy(df['mnist-pred'])\n", "df['pred-broad'].columns = broad_cols\n" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "One (784, 4456)\n", "Three-Five-Eight (784, 9918)\n", "Four-Seven-Nine (784, 10312)\n", "Zero-Two-Six (784, 10314)\n" ] }, { "data": { "text/plain": [ "(784, 35000)" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Re-run prediction on individual broad digits if necessary\n", 
"df_list = []\n", "for inst_broad in coarse_digits:\n", " cols = df['pred-broad'].columns.tolist()\n", " keep_cols = [x for x in cols if x[2].split(': ')[1] == inst_broad]\n", "\n", " inst_df = df['pred-broad'][keep_cols]\n", " \n", " print(inst_broad, inst_df.shape) \n", " \n", " # run prediction if necessary\n", " if '-' in inst_broad:\n", " \n", " tmp_cols = inst_df.columns.tolist()\n", " # drop previous prediction \n", " new_cols = [(x[0], x[1]) for x in tmp_cols]\n", " inst_df.columns = new_cols\n", " \n", " # predict using narrow signature\n", " df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(inst_df, \n", " df['sig-' + inst_broad], truth_level=1, predict_level='Digit', unknown_thresh=0.0) \n", " \n", " inst_df.columns = df_pred_cat.columns.tolist()\n", " df_list.append(inst_df)\n", "\n", " else:\n", " tmp_cols = inst_df.columns.tolist()\n", " new_cols = [(x[0], x[1], x[2].replace('Broad Digit:', 'Digit:')) for x in tmp_cols]\n", " inst_df.columns = new_cols\n", " df_list.append(inst_df)\n", " \n", "df['pred-narrow'] = pd.concat(df_list, axis=1)\n", "df['pred-narrow'].shape\n" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Predict: 0.8172\n" ] }, { "data": { "text/plain": [ "One 0.933435\n", "Zero 0.922872\n", "Six 0.877687\n", "Seven 0.825165\n", "Four 0.815407\n", "Three 0.812238\n", "Nine 0.805134\n", "Two 0.777122\n", "Eight 0.751615\n", "Five 0.622612\n", "dtype: float64" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cols = df['pred-narrow'].columns.tolist()\n", "cols[0]\n", "\n", "y_info = {}\n", "y_info['true'] = [x[1].split(': ')[1] for x in cols]\n", "y_info['pred'] = [x[2].split(': ')[1] for x in cols]\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('Predict: ', fraction_correct)\n", "ser_correct.sort_values(ascending=False)" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df_pred_cat.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net.load_df(df_pred_cat)\n", "set_cat_colors(cat_color, axis='col', cat_index=2, cat_title='Pred Digit')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net.load_df(df_pred_cat)\n", "net.random_sample(axis='col', num_samples=2500, random_state=100)\n", "net.widget()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Z-scored Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Predict\n", "##################\n", "df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-z'], df['sig-z'], truth_level=1,\n", " predict_level='Pred Digit', unknown_thresh=0.0)\n", "\n", "df_sig_sim.columns = df_pred_cat.columns.tolist()\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('Predict: ', fraction_correct)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('\\nbroad cell type: ', fraction_correct, '\\n')\n", "print(ser_correct.sort_values(ascending=False))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "df_conf.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# net.load_df(df_conf)\n", "# net.widget()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df_sig_sim.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net.load_df(df_sig_sim)\n", "net.random_sample(axis='col', num_samples=2500, random_state=99)\n", "net.load_df(net.export_df().round(2))\n", "net.widget()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net.load_df(df_pred_cat)\n", "net.random_sample(axis='col', num_samples=2500, random_state=99)\n", "net.load_df(net.export_df().round(2))\n", "net.widget()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Broad Predictions\n", "## Make Broad Signature" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Merge Categories" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cols = df['mnist-cat'].columns.tolist()\n", "\n", "\n", "random.shuffle(cols)\n", "\n", "df['mnist-train'] = df['mnist-cat'][cols[:35000]]\n", "df['mnist-pred'] = df['mnist-cat'][cols[35000:]]\n", "print(df['mnist-train'].shape, df['mnist-pred'].shape)\n", "\n", "for inst_data in ['mnist-train', 'mnist-pred']:\n", " cols = df[inst_data]\n", " new_cols = []\n", " for inst_col in cols:\n", " inst_cat = inst_col[1].split(': ')[1]\n", " \n", " if inst_cat in merge_358:\n", " inst_cat = 'Three-Five-Eight'\n", " \n", " if inst_cat in merge_479:\n", " inst_cat = 'Four-Seven-Nine' \n", " \n", " if inst_cat in merge_026:\n", " inst_cat = 'Zero-Two-Six'\n", " \n", " new_col = (inst_col[0], 'Coarse: ' + inst_cat, inst_col[1])\n", " new_cols.append(new_col)\n", " \n", " df[inst_data + '-coarse'] = deepcopy(df[inst_data])\n", " df[inst_data + '-coarse'].columns = new_cols\n", " print(df[inst_data + 
'-coarse'].shape)\n", " \n", " net.load_df(df[inst_data + '-coarse'])\n", " net.normalize(axis='row', norm_type='zscore')\n", " df[inst_data + '-coarse-z'] = net.export_df()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Make Broad Signature" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pval_cutoff = 0.00001\n", "num_top_dims = 50\n", "for inst_norm in ['', '-z']:\n", " df['sig-broad' + inst_norm], keep_genes_dict, df_gene_pval, fold_info = cby.generate_signatures(\n", " df['mnist-train-coarse' + inst_norm],\n", " 'Coarse', num_top_dims=num_top_dims)\n", " print(inst_norm, df['sig-broad' + inst_norm].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Predict Broad then Narrow\n", "Need to predict broad digits, then separate each of the broad categories and predict using the narrow signature." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predict Broad Digits" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Predict\n", "##################\n", "df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-coarse'], df['sig-broad'], truth_level=1,\n", " predict_level='Pred Digit', unknown_thresh=0.0)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('Predict: ', fraction_correct)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('\\nbroad cell type: ', fraction_correct, '\\n')\n", "print(ser_correct.sort_values(ascending=False))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Predict\n", "##################\n", "df_pred_cat, df_sig_sim, y_info = cby.predict_cats_from_sigs(df['mnist-pred-358-z'], df['sig-z'], truth_level=1,\n", " predict_level='Pred Digit', unknown_thresh=0.0)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print('Predict: ', fraction_correct)\n", "\n", "df_conf, population, ser_correct, fraction_correct = cby.confusion_matrix_and_correct_series(y_info)\n", "print(ser_correct.sort_values(ascending=False))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df_pred_cat.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net.load_df(df_pred_cat)\n", "net.random_sample(axis='col', num_samples=2500, random_state=99)\n", "net.load_df(net.export_df().round(2))\n", "net.widget()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_color['Three-Five-Eight'] = 'red'\n", "cat_color['Four-Seven-Nine'] = 'blue'\n", "cat_color['Zero-Two-Six'] = 'yellow'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "net.load_df(df_pred_cat)\n", "set_cat_colors(cat_color, axis='col', cat_index=1, cat_title='Coarse')\n", "set_cat_colors(cat_color, axis='col', cat_index=3, cat_title='Pred Digit')" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "net.load_df(df_pred_cat)\n", "net.random_sample(axis='col', num_samples=2500, random_state=99)\n", "net.load_df(net.export_df().round(2))\n", "net.widget()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" } }, "nbformat": 4, "nbformat_minor": 2 }