{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import glob\n",
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "import keras.backend as K\n",
    "from keras.preprocessing import image\n",
    "from keras.applications.vgg16 import VGG16\n",
    "from keras.applications.vgg16 import preprocess_input as vgg_preprocess_input\n",
    "\n",
    "from helpers import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get VGG activations for images of different categories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 224, 224, 3)       0         \n",
      "_________________________________________________________________\n",
      "block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      \n",
      "_________________________________________________________________\n",
      "block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     \n",
      "_________________________________________________________________\n",
      "block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         \n",
      "_________________________________________________________________\n",
      "block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     \n",
      "_________________________________________________________________\n",
      "block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    \n",
      "_________________________________________________________________\n",
      "block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         \n",
      "_________________________________________________________________\n",
      "block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    \n",
      "_________________________________________________________________\n",
      "block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    \n",
      "_________________________________________________________________\n",
      "block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    \n",
      "_________________________________________________________________\n",
      "block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         \n",
      "_________________________________________________________________\n",
      "block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   \n",
      "_________________________________________________________________\n",
      "block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         \n",
      "_________________________________________________________________\n",
      "block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   \n",
      "_________________________________________________________________\n",
      "block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         \n",
      "_________________________________________________________________\n",
      "flatten (Flatten)            (None, 25088)             0         \n",
      "_________________________________________________________________\n",
      "fc1 (Dense)                  (None, 4096)              102764544 \n",
      "_________________________________________________________________\n",
      "fc2 (Dense)                  (None, 4096)              16781312  \n",
      "_________________________________________________________________\n",
      "predictions (Dense)          (None, 1000)              4097000   \n",
      "=================================================================\n",
      "Total params: 138,357,544\n",
      "Trainable params: 138,357,544\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Load a model\n",
    "model = VGG16(weights='imagenet', include_top=True)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['block3_conv1', 'block3_conv2', 'block3_conv3', 'block4_conv1', 'block4_conv2', 'block4_conv3', 'block5_conv1', 'block5_conv2', 'block5_conv3', 'flatten', 'fc1', 'fc2']\n"
     ]
    }
   ],
   "source": [
    "LAYERS = [l.name for l in model.layers[7:-1] if 'pool' not in l.name]  # layer of interest\n",
    "print(LAYERS)\n",
    "CAT_FLD = 'images/'        # folder with images separated by category into subfolders\n",
    "ACT_FLD = 'activations/'\n",
    "SEL_FLD = 'selectivity'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def vgg_image(path, target_size=(224, 224)):\n",
    "    \"\"\"Tranform image for vgg with keras module\"\"\"\n",
    "    img = image.load_img(path, target_size=target_size)\n",
    "    x = image.img_to_array(img)\n",
    "    x = np.expand_dims(img, axis=0).astype(np.float)\n",
    "    x = vgg_preprocess_input(x)\n",
    "    return x\n",
    "\n",
    "\n",
    "def vgg_images_from_filenames(filenames):\n",
    "    \"\"\"Load and preprocess images from filenames\"\"\"\n",
    "    images = [vgg_image(f) for f in filenames]\n",
    "    images = [x for x in images if x is not None]\n",
    "    images = np.squeeze(np.vstack(images))\n",
    "    return images\n",
    "\n",
    "\n",
    "def get_activations(model, images, layers_names):\n",
    "    \"\"\"Get activations for the input from specified layer\"\"\"\n",
    "    \n",
    "    inp = model.input                                          \n",
    "    names_outputs = [(l.name, l.output) for l in model.layers if l.name in layers_names] \n",
    "    names = [x[0] for x in names_outputs]\n",
    "    outputs = [x[1] for x in names_outputs]\n",
    "    functor = K.function([inp] + [K.learning_phase()], outputs) \n",
    "    \n",
    "    return dict(zip(names, functor([images])))\n",
    "\n",
    "def get_activations_batches(model, filenames, layers, output_fld, img_loader_func, \n",
    "                            batch_size=128, crop_model_=True, **kwargs): \n",
    "    \"\"\"Split input into batches and get activations\"\"\"\n",
    "    \n",
    "    file_gen = batch_generator(filenames, batch_size, equal_size=False)\n",
    "    \n",
    "    for l in layers:\n",
    "        l_fld = os.path.join(output_fld, l)\n",
    "        if not os.path.exists(l_fld):\n",
    "            os.makedirs(l_fld)\n",
    "    \n",
    "    # get activations\n",
    "    for batch_n, batch_files in enumerate(file_gen):\n",
    "        print(\"batch\", batch_n)\n",
    "        images = img_loader_func(batch_files,  **kwargs)\n",
    "        activations = get_activations(model, images, layers)\n",
    "        \n",
    "        for l_name, v in activations.items():\n",
    "            for path, l_acts in zip(batch_files, v):\n",
    "                name = os.path.basename(path).split('.')[0]            \n",
    "                new_path = os.path.join(output_fld, l_name, name + '.npy')\n",
    "                np.save(new_path, l_acts)\n",
    "\n",
    "    print(\"All activations saved to folder\", output_fld)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In category FURNITURE found 301 files\n",
      "Getting activations for files...\n",
      "batch 0\n",
      "batch 1\n",
      "batch 2\n",
      "All activations saved to folder activations/furniture\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for cat in os.listdir(CAT_FLD):\n",
    "    if cat == 'furniture':\n",
    "        cat_files = glob.glob(os.path.join(CAT_FLD, cat, '*'))\n",
    "        print('In category {} found {} files'.format(cat.upper(), len(cat_files)))\n",
    "        print('Getting activations for files...')\n",
    "\n",
    "        output_cat_fld = os.path.join(ACT_FLD, cat)\n",
    "        get_activations_batches(model, cat_files, LAYERS, output_fld=output_cat_fld, \n",
    "                                img_loader_func=vgg_images_from_filenames, batch_size=128)\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Find category selectivity for filters based on retrieved activations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cat_mask(categories, cat):\n",
    "    return np.array([x == cat for x in categories])\n",
    "    \n",
    "def category_selectivity_index(activations, cat_mask, axis=0):\n",
    "    \n",
    "    act_norm = norm_values(activations)\n",
    "    cat = act_norm[cat_mask]    \n",
    "    noncat = act_norm[~cat_mask]\n",
    "    selectivity = np.sum(cat, axis)/cat.shape[axis] - np.sum(noncat, axis)/noncat.shape[axis]\n",
    "    return selectivity\n",
    "\n",
    "def cat_selective_func(func, activations, labels, **kwargs):\n",
    "    \n",
    "    cat_selectivity = []\n",
    "    \n",
    "    for l in np.unique(labels):\n",
    "        category_mask = get_cat_mask(labels, l)\n",
    "        selectivity = func(activations, category_mask, **kwargs)\n",
    "        cat_selectivity.append([selectivity])\n",
    "    cat_selectivity = np.vstack(cat_selectivity)    \n",
    "    return cat_selectivity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load all activations\n",
    "\n",
    "def load_act_from_fld(fld, layer=''):\n",
    "    \n",
    "    label_dict = {}\n",
    "    activations = []\n",
    "    labels = []\n",
    "    \n",
    "    for (n, cat) in enumerate(os.listdir(fld)):\n",
    "        label_dict[n] = cat\n",
    "        cat_fld = os.path.join(fld, cat, layer)\n",
    "        np_files = glob.glob(cat_fld + '/*.np[yz]')\n",
    "        if len(np_files) > 0:\n",
    "            labels.extend([n] * len(np_files))\n",
    "            cat_activations = []\n",
    "            print('Loading {} files from {}'.format(len(np_files), cat_fld))\n",
    "            for (i, f) in enumerate(np_files):\n",
    "                activations.append(np.load(f))\n",
    "\n",
    "    return np.asarray(activations), labels, label_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 301 files from activations/furniture/block3_conv1\n",
      "Loading 301 files from activations/dog/block3_conv1\n",
      "Loading 301 files from activations/iseeface/block3_conv1\n",
      "Loading 301 files from activations/cat/block3_conv1\n",
      "Loading 301 files from activations/portrait/block3_conv1\n",
      "Loading 301 files from activations/car/block3_conv1\n",
      "Loading 301 files from activations/furniture/block3_conv2\n",
      "Loading 301 files from activations/dog/block3_conv2\n",
      "Loading 301 files from activations/iseeface/block3_conv2\n",
      "Loading 301 files from activations/cat/block3_conv2\n",
      "Loading 301 files from activations/portrait/block3_conv2\n",
      "Loading 301 files from activations/car/block3_conv2\n",
      "Loading 301 files from activations/furniture/block3_conv3\n",
      "Loading 301 files from activations/dog/block3_conv3\n",
      "Loading 301 files from activations/iseeface/block3_conv3\n",
      "Loading 301 files from activations/cat/block3_conv3\n",
      "Loading 301 files from activations/portrait/block3_conv3\n",
      "Loading 301 files from activations/car/block3_conv3\n",
      "Loading 301 files from activations/furniture/block4_conv1\n",
      "Loading 301 files from activations/dog/block4_conv1\n",
      "Loading 301 files from activations/iseeface/block4_conv1\n",
      "Loading 301 files from activations/cat/block4_conv1\n",
      "Loading 301 files from activations/portrait/block4_conv1\n",
      "Loading 301 files from activations/car/block4_conv1\n",
      "Loading 301 files from activations/furniture/block4_conv2\n",
      "Loading 301 files from activations/dog/block4_conv2\n",
      "Loading 301 files from activations/iseeface/block4_conv2\n",
      "Loading 301 files from activations/cat/block4_conv2\n",
      "Loading 301 files from activations/portrait/block4_conv2\n",
      "Loading 301 files from activations/car/block4_conv2\n",
      "Loading 301 files from activations/furniture/block4_conv3\n",
      "Loading 301 files from activations/dog/block4_conv3\n",
      "Loading 301 files from activations/iseeface/block4_conv3\n",
      "Loading 301 files from activations/cat/block4_conv3\n",
      "Loading 301 files from activations/portrait/block4_conv3\n",
      "Loading 301 files from activations/car/block4_conv3\n",
      "Loading 301 files from activations/furniture/block5_conv1\n",
      "Loading 301 files from activations/dog/block5_conv1\n",
      "Loading 301 files from activations/iseeface/block5_conv1\n",
      "Loading 301 files from activations/cat/block5_conv1\n",
      "Loading 301 files from activations/portrait/block5_conv1\n",
      "Loading 301 files from activations/car/block5_conv1\n",
      "Loading 301 files from activations/furniture/block5_conv2\n",
      "Loading 301 files from activations/dog/block5_conv2\n",
      "Loading 301 files from activations/iseeface/block5_conv2\n",
      "Loading 301 files from activations/cat/block5_conv2\n",
      "Loading 301 files from activations/portrait/block5_conv2\n",
      "Loading 301 files from activations/car/block5_conv2\n",
      "Loading 301 files from activations/furniture/block5_conv3\n",
      "Loading 301 files from activations/dog/block5_conv3\n",
      "Loading 301 files from activations/iseeface/block5_conv3\n",
      "Loading 301 files from activations/cat/block5_conv3\n",
      "Loading 301 files from activations/portrait/block5_conv3\n",
      "Loading 301 files from activations/car/block5_conv3\n"
     ]
    }
   ],
   "source": [
    "ACTIVATIONS = defaultdict(lambda : defaultdict())\n",
    "\n",
    "for l in LAYERS:\n",
    "    acts, labels, label_dict = load_act_from_fld(ACT_FLD, layer=l)\n",
    "    ACTIVATIONS[l]['activations'] = acts\n",
    "    ACTIVATIONS[l]['label_dict'] = label_dict\n",
    "    ACTIVATIONS[l]['labels'] = labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "block3_conv1\n",
      "block5_conv3\n",
      "block4_conv1\n",
      "block3_conv2\n",
      "block4_conv2\n",
      "block5_conv2\n",
      "block4_conv3\n",
      "block3_conv3\n",
      "block5_conv1\n"
     ]
    }
   ],
   "source": [
    "# Find selectivity for filters in different layers\n",
    "\n",
    "for l, v in ACTIVATIONS.items():  \n",
    "    print(l)\n",
    "    cat_sel = cat_selective_func(category_selectivity_index, v['activations'], v['labels'])\n",
    "    # select by mean filter activation\n",
    "    cat_sel_filter = np.mean(cat_sel, axis=(1, 2))\n",
    "\n",
    "    # # select by activation of filter's cantral unit\n",
    "    # central_unit = np.ceil(cat_sel.shape[1]/2).astype(int)\n",
    "    # cat_sel_filter = cat_sel[:, central_unit, central_unit, :]\n",
    "    ACTIVATIONS[l]['filter_selectivity'] = cat_sel_filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "block3_conv1\n",
      "block5_conv3\n",
      "block4_conv1\n",
      "block3_conv2\n",
      "block4_conv2\n",
      "block5_conv2\n",
      "block4_conv3\n",
      "block3_conv3\n",
      "block5_conv1\n"
     ]
    }
   ],
   "source": [
    "# Convert selectivity index into dataframe\n",
    "\n",
    "if not os.path.exists(SEL_FLD):\n",
    "    os.makedirs(SEL_FLD)\n",
    "        \n",
    "for l, v in ACTIVATIONS.items():  \n",
    "    print(l)\n",
    "    \n",
    "    df = pd.DataFrame(v['filter_selectivity'].T)\n",
    "    df['filter'] = df.index\n",
    "    df = df.melt(id_vars=df.columns[-1], value_vars=df.columns[:-1], \n",
    "                 var_name='category', value_name='selectivity')\n",
    "\n",
    "    df['label'] = df['category'].map(v['label_dict'])\n",
    "    df['layer'] = l\n",
    "    df.to_csv(os.path.join(SEL_FLD, '{}.csv'.format(l)), index=False)\n",
    "    df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional\n",
    "# Join all files into one dataframe\n",
    "df = pd.concat((pd.read_csv(f) for f in os.path.join(SEL_FLD, \"*.csv\")))\n",
    "df.to_csv('selectivity_vgg16.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}