{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predict and Extract Features with Pre-trained Models\n",
"\n",
"This tutorial will work through how to use pre-trained models for predicting and feature extraction.\n",
"\n",
"## Download pre-trained models\n",
"\n",
"A model often contains two parts, the `.json` file specifying the neural network structure, and the `.params` file containing the binary parameters. The name convention is `name-symbol.json` and `name-epoch.params`, where `name` is the model name, and `epoch` is the epoch number. \n",
"\n",
"\n",
"Here we download a pre-trained Resnet 50-layer model on Imagenet. Other models are available at http://data.mxnet.io/models/"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import os, urllib\n",
"def download(url):\n",
" filename = url.split(\"/\")[-1]\n",
" if not os.path.exists(filename):\n",
" urllib.urlretrieve(url, filename)\n",
"def get_model(prefix, epoch):\n",
" download(prefix+'-symbol.json')\n",
" download(prefix+'-%04d.params' % (epoch,))\n",
"\n",
"get_model('http://data.mxnet.io/models/imagenet/resnet/50-layers/resnet-50', 0)"
]
},
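{
"cell_type": "markdown",
"metadata": {},
"source": [
"The naming convention described above can be sketched in a few lines. The helper name `checkpoint_filenames` is hypothetical and not part of MXNet; it only mirrors the string formatting used by `get_model`, with the epoch zero-padded to four digits.\n",
"\n",
"```python\n",
"def checkpoint_filenames(prefix, epoch):\n",
"    # Hypothetical helper: build the two file names that make up a checkpoint.\n",
"    symbol_file = '%s-symbol.json' % prefix\n",
"    params_file = '%s-%04d.params' % (prefix, epoch)  # epoch padded to 4 digits\n",
"    return symbol_file, params_file\n",
"\n",
"print(checkpoint_filenames('resnet-50', 0))\n",
"```\n",
"\n",
"For the model downloaded here, this yields `resnet-50-symbol.json` and `resnet-50-0000.params`, the two files saved in the current directory."
]
},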
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Initialization\n",
"\n",
"We first load the model into memory with `load_checkpoint`. It returns the symbol (see [symbol.ipynb](../basic/symbol.ipynb)) definition of the neural network, and parameters. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import mxnet as mx\n",
"sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50', 0)"
]
},
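{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of how one saved dictionary can yield two parameter dictionaries, the sketch below assumes that each key in the `.params` file carries an `arg:` or `aux:` prefix. That layout is an assumption about the checkpoint format, `split_params` is a hypothetical helper, and plain strings stand in for the ndarray values.\n",
"\n",
"```python\n",
"def split_params(save_dict):\n",
"    # Hypothetical helper: separate 'arg:name' and 'aux:name' entries.\n",
"    arg_params, aux_params = {}, {}\n",
"    for key, value in save_dict.items():\n",
"        kind, name = key.split(':', 1)\n",
"        if kind == 'arg':\n",
"            arg_params[name] = value\n",
"        elif kind == 'aux':\n",
"            aux_params[name] = value\n",
"    return arg_params, aux_params\n",
"\n",
"# Placeholder values; in a real checkpoint these would be ndarrays.\n",
"saved = {'arg:fc1_weight': 'w', 'arg:fc1_bias': 'b', 'aux:bn0_moving_mean': 'm'}\n",
"args, auxs = split_params(saved)\n",
"print(sorted(args), sorted(auxs))\n",
"```\n",
"\n",
"In practice `load_checkpoint` handles this step, so the helper is only meant to show where the two dictionaries come from."
]
},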
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can visualize the neural network by `mx.viz.plot_network`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mx.viz.plot_network(sym)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both argument parameters and auxiliary parameters (e.g mean/std in batch normalization layer) are stored as a dictionary of string name and ndarray value (see [ndarray.ipynb](../basic/ndarray.ipynb)). The arguments contain \n",
"consist of weight and bias. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'bn0_beta': ,\n",
" 'bn0_gamma': ,\n",
" 'bn1_beta': ,\n",
" 'bn1_gamma': ,\n",
" 'bn_data_beta': ,\n",
" 'bn_data_gamma': ,\n",
" 'conv0_weight': ,\n",
" 'fc1_bias': ,\n",
" 'fc1_weight': ,\n",
" 'stage1_unit1_bn1_beta': ,\n",
" 'stage1_unit1_bn1_gamma': ,\n",
" 'stage1_unit1_bn2_beta': ,\n",
" 'stage1_unit1_bn2_gamma': ,\n",
" 'stage1_unit1_bn3_beta': ,\n",
" 'stage1_unit1_bn3_gamma': ,\n",
" 'stage1_unit1_conv1_weight': ,\n",
" 'stage1_unit1_conv2_weight': ,\n",
" 'stage1_unit1_conv3_weight': ,\n",
" 'stage1_unit1_sc_weight': ,\n",
" 'stage1_unit2_bn1_beta': ,\n",
" 'stage1_unit2_bn1_gamma': ,\n",
" 'stage1_unit2_bn2_beta': ,\n",
" 'stage1_unit2_bn2_gamma': ,\n",
" 'stage1_unit2_bn3_beta': ,\n",
" 'stage1_unit2_bn3_gamma': ,\n",
" 'stage1_unit2_conv1_weight': ,\n",
" 'stage1_unit2_conv2_weight': ,\n",
" 'stage1_unit2_conv3_weight': ,\n",
" 'stage1_unit3_bn1_beta': ,\n",
" 'stage1_unit3_bn1_gamma': ,\n",
" 'stage1_unit3_bn2_beta': ,\n",
" 'stage1_unit3_bn2_gamma': ,\n",
" 'stage1_unit3_bn3_beta': ,\n",
" 'stage1_unit3_bn3_gamma': ,\n",
" 'stage1_unit3_conv1_weight': ,\n",
" 'stage1_unit3_conv2_weight': ,\n",
" 'stage1_unit3_conv3_weight': ,\n",
" 'stage2_unit1_bn1_beta': ,\n",
" 'stage2_unit1_bn1_gamma': ,\n",
" 'stage2_unit1_bn2_beta': ,\n",
" 'stage2_unit1_bn2_gamma': ,\n",
" 'stage2_unit1_bn3_beta': ,\n",
" 'stage2_unit1_bn3_gamma': ,\n",
" 'stage2_unit1_conv1_weight': ,\n",
" 'stage2_unit1_conv2_weight': ,\n",
" 'stage2_unit1_conv3_weight': ,\n",
" 'stage2_unit1_sc_weight': ,\n",
" 'stage2_unit2_bn1_beta': ,\n",
" 'stage2_unit2_bn1_gamma': ,\n",
" 'stage2_unit2_bn2_beta': ,\n",
" 'stage2_unit2_bn2_gamma': ,\n",
" 'stage2_unit2_bn3_beta': ,\n",
" 'stage2_unit2_bn3_gamma': ,\n",
" 'stage2_unit2_conv1_weight': ,\n",
" 'stage2_unit2_conv2_weight': ,\n",
" 'stage2_unit2_conv3_weight': ,\n",
" 'stage2_unit3_bn1_beta': ,\n",
" 'stage2_unit3_bn1_gamma': ,\n",
" 'stage2_unit3_bn2_beta': ,\n",
" 'stage2_unit3_bn2_gamma': ,\n",
" 'stage2_unit3_bn3_beta': ,\n",
" 'stage2_unit3_bn3_gamma': ,\n",
" 'stage2_unit3_conv1_weight': ,\n",
" 'stage2_unit3_conv2_weight': ,\n",
" 'stage2_unit3_conv3_weight': ,\n",
" 'stage2_unit4_bn1_beta': ,\n",
" 'stage2_unit4_bn1_gamma': ,\n",
" 'stage2_unit4_bn2_beta': ,\n",
" 'stage2_unit4_bn2_gamma': ,\n",
" 'stage2_unit4_bn3_beta': ,\n",
" 'stage2_unit4_bn3_gamma': ,\n",
" 'stage2_unit4_conv1_weight': ,\n",
" 'stage2_unit4_conv2_weight': ,\n",
" 'stage2_unit4_conv3_weight': ,\n",
" 'stage3_unit1_bn1_beta': ,\n",
" 'stage3_unit1_bn1_gamma': ,\n",
" 'stage3_unit1_bn2_beta': ,\n",
" 'stage3_unit1_bn2_gamma': ,\n",
" 'stage3_unit1_bn3_beta': ,\n",
" 'stage3_unit1_bn3_gamma': ,\n",
" 'stage3_unit1_conv1_weight': ,\n",
" 'stage3_unit1_conv2_weight': ,\n",
" 'stage3_unit1_conv3_weight': ,\n",
" 'stage3_unit1_sc_weight': ,\n",
" 'stage3_unit2_bn1_beta': ,\n",
" 'stage3_unit2_bn1_gamma': ,\n",
" 'stage3_unit2_bn2_beta': ,\n",
" 'stage3_unit2_bn2_gamma': ,\n",
" 'stage3_unit2_bn3_beta': ,\n",
" 'stage3_unit2_bn3_gamma': ,\n",
" 'stage3_unit2_conv1_weight': ,\n",
" 'stage3_unit2_conv2_weight': ,\n",
" 'stage3_unit2_conv3_weight': ,\n",
" 'stage3_unit3_bn1_beta': ,\n",
" 'stage3_unit3_bn1_gamma': ,\n",
" 'stage3_unit3_bn2_beta': ,\n",
" 'stage3_unit3_bn2_gamma': ,\n",
" 'stage3_unit3_bn3_beta': ,\n",
" 'stage3_unit3_bn3_gamma': ,\n",
" 'stage3_unit3_conv1_weight': ,\n",
" 'stage3_unit3_conv2_weight': ,\n",
" 'stage3_unit3_conv3_weight': ,\n",
" 'stage3_unit4_bn1_beta': ,\n",
" 'stage3_unit4_bn1_gamma': ,\n",
" 'stage3_unit4_bn2_beta': ,\n",
" 'stage3_unit4_bn2_gamma': ,\n",
" 'stage3_unit4_bn3_beta': ,\n",
" 'stage3_unit4_bn3_gamma': ,\n",
" 'stage3_unit4_conv1_weight': ,\n",
" 'stage3_unit4_conv2_weight': ,\n",
" 'stage3_unit4_conv3_weight': ,\n",
" 'stage3_unit5_bn1_beta': ,\n",
" 'stage3_unit5_bn1_gamma': ,\n",
" 'stage3_unit5_bn2_beta': ,\n",
" 'stage3_unit5_bn2_gamma': ,\n",
" 'stage3_unit5_bn3_beta': ,\n",
" 'stage3_unit5_bn3_gamma': ,\n",
" 'stage3_unit5_conv1_weight': ,\n",
" 'stage3_unit5_conv2_weight': ,\n",
" 'stage3_unit5_conv3_weight': ,\n",
" 'stage3_unit6_bn1_beta': ,\n",
" 'stage3_unit6_bn1_gamma': ,\n",
" 'stage3_unit6_bn2_beta': ,\n",
" 'stage3_unit6_bn2_gamma': ,\n",
" 'stage3_unit6_bn3_beta': ,\n",
" 'stage3_unit6_bn3_gamma': ,\n",
" 'stage3_unit6_conv1_weight': ,\n",
" 'stage3_unit6_conv2_weight': ,\n",
" 'stage3_unit6_conv3_weight': ,\n",
" 'stage4_unit1_bn1_beta': ,\n",
" 'stage4_unit1_bn1_gamma': ,\n",
" 'stage4_unit1_bn2_beta': ,\n",
" 'stage4_unit1_bn2_gamma': ,\n",
" 'stage4_unit1_bn3_beta': ,\n",
" 'stage4_unit1_bn3_gamma': ,\n",
" 'stage4_unit1_conv1_weight': ,\n",
" 'stage4_unit1_conv2_weight': ,\n",
" 'stage4_unit1_conv3_weight': ,\n",
" 'stage4_unit1_sc_weight': ,\n",
" 'stage4_unit2_bn1_beta': ,\n",
" 'stage4_unit2_bn1_gamma': ,\n",
" 'stage4_unit2_bn2_beta': ,\n",
" 'stage4_unit2_bn2_gamma': ,\n",
" 'stage4_unit2_bn3_beta': ,\n",
" 'stage4_unit2_bn3_gamma': ,\n",
" 'stage4_unit2_conv1_weight': ,\n",
" 'stage4_unit2_conv2_weight': ,\n",
" 'stage4_unit2_conv3_weight': ,\n",
" 'stage4_unit3_bn1_beta': ,\n",
" 'stage4_unit3_bn1_gamma': ,\n",
" 'stage4_unit3_bn2_beta': ,\n",
" 'stage4_unit3_bn2_gamma': ,\n",
" 'stage4_unit3_bn3_beta': ,\n",
" 'stage4_unit3_bn3_gamma': ,\n",
" 'stage4_unit3_conv1_weight': ,\n",
" 'stage4_unit3_conv2_weight': ,\n",
" 'stage4_unit3_conv3_weight': }"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"arg_params"
]
},
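{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the listing above is long, it can help to summarize the parameters by kind. The sketch below is one possible way to do that, assuming the text after the last underscore in each name identifies the kind of parameter; `count_by_suffix` is a hypothetical helper, and the short list of names is copied from the output above.\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"def count_by_suffix(names):\n",
"    # Hypothetical helper: count names by the text after the last underscore.\n",
"    return Counter(name.rsplit('_', 1)[-1] for name in names)\n",
"\n",
"# A few names copied from the listing above.\n",
"names = ['bn0_beta', 'bn0_gamma', 'conv0_weight', 'fc1_bias', 'fc1_weight']\n",
"print(count_by_suffix(names))\n",
"```\n",
"\n",
"With the loaded model, the same helper could be called as `count_by_suffix(arg_params.keys())`."
]
},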
{
"cell_type": "markdown",
"metadata": {},
"source": [
"while auxiliaries contains the the mean and std for the batch normalization layers. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'bn0_moving_mean': ,\n",
" 'bn0_moving_var': ,\n",
" 'bn1_moving_mean': ,\n",
" 'bn1_moving_var': ,\n",
" 'bn_data_moving_mean': ,\n",
" 'bn_data_moving_var': ,\n",
" 'stage1_unit1_bn1_moving_mean': ,\n",
" 'stage1_unit1_bn1_moving_var': ,\n",
" 'stage1_unit1_bn2_moving_mean': ,\n",
" 'stage1_unit1_bn2_moving_var': ,\n",
" 'stage1_unit1_bn3_moving_mean': ,\n",
" 'stage1_unit1_bn3_moving_var': ,\n",
" 'stage1_unit2_bn1_moving_mean': ,\n",
" 'stage1_unit2_bn1_moving_var': ,\n",
" 'stage1_unit2_bn2_moving_mean': ,\n",
" 'stage1_unit2_bn2_moving_var': ,\n",
" 'stage1_unit2_bn3_moving_mean': ,\n",
" 'stage1_unit2_bn3_moving_var': ,\n",
" 'stage1_unit3_bn1_moving_mean': ,\n",
" 'stage1_unit3_bn1_moving_var': ,\n",
" 'stage1_unit3_bn2_moving_mean': ,\n",
" 'stage1_unit3_bn2_moving_var': ,\n",
" 'stage1_unit3_bn3_moving_mean': ,\n",
" 'stage1_unit3_bn3_moving_var': ,\n",
" 'stage2_unit1_bn1_moving_mean': ,\n",
" 'stage2_unit1_bn1_moving_var': ,\n",
" 'stage2_unit1_bn2_moving_mean': ,\n",
" 'stage2_unit1_bn2_moving_var': ,\n",
" 'stage2_unit1_bn3_moving_mean': ,\n",
" 'stage2_unit1_bn3_moving_var': ,\n",
" 'stage2_unit2_bn1_moving_mean': ,\n",
" 'stage2_unit2_bn1_moving_var': ,\n",
" 'stage2_unit2_bn2_moving_mean': ,\n",
" 'stage2_unit2_bn2_moving_var': ,\n",
" 'stage2_unit2_bn3_moving_mean': ,\n",
" 'stage2_unit2_bn3_moving_var': ,\n",
" 'stage2_unit3_bn1_moving_mean': ,\n",
" 'stage2_unit3_bn1_moving_var': ,\n",
" 'stage2_unit3_bn2_moving_mean': ,\n",
" 'stage2_unit3_bn2_moving_var': ,\n",
" 'stage2_unit3_bn3_moving_mean': ,\n",
" 'stage2_unit3_bn3_moving_var': ,\n",
" 'stage2_unit4_bn1_moving_mean': ,\n",
" 'stage2_unit4_bn1_moving_var': ,\n",
" 'stage2_unit4_bn2_moving_mean': ,\n",
" 'stage2_unit4_bn2_moving_var': ,\n",
" 'stage2_unit4_bn3_moving_mean': ,\n",
" 'stage2_unit4_bn3_moving_var': ,\n",
" 'stage3_unit1_bn1_moving_mean': ,\n",
" 'stage3_unit1_bn1_moving_var': ,\n",
" 'stage3_unit1_bn2_moving_mean': ,\n",
" 'stage3_unit1_bn2_moving_var': ,\n",
" 'stage3_unit1_bn3_moving_mean': ,\n",
" 'stage3_unit1_bn3_moving_var': ,\n",
" 'stage3_unit2_bn1_moving_mean': ,\n",
" 'stage3_unit2_bn1_moving_var': ,\n",
" 'stage3_unit2_bn2_moving_mean': ,\n",
" 'stage3_unit2_bn2_moving_var': ,\n",
" 'stage3_unit2_bn3_moving_mean': ,\n",
" 'stage3_unit2_bn3_moving_var': ,\n",
" 'stage3_unit3_bn1_moving_mean': ,\n",
" 'stage3_unit3_bn1_moving_var': ,\n",
" 'stage3_unit3_bn2_moving_mean': ,\n",
" 'stage3_unit3_bn2_moving_var': ,\n",
" 'stage3_unit3_bn3_moving_mean': ,\n",
" 'stage3_unit3_bn3_moving_var': ,\n",
" 'stage3_unit4_bn1_moving_mean': ,\n",
" 'stage3_unit4_bn1_moving_var': ,\n",
" 'stage3_unit4_bn2_moving_mean': ,\n",
" 'stage3_unit4_bn2_moving_var': ,\n",
" 'stage3_unit4_bn3_moving_mean': ,\n",
" 'stage3_unit4_bn3_moving_var': ,\n",
" 'stage3_unit5_bn1_moving_mean': ,\n",
" 'stage3_unit5_bn1_moving_var': ,\n",
" 'stage3_unit5_bn2_moving_mean': ,\n",
" 'stage3_unit5_bn2_moving_var': ,\n",
" 'stage3_unit5_bn3_moving_mean': ,\n",
" 'stage3_unit5_bn3_moving_var': ,\n",
" 'stage3_unit6_bn1_moving_mean': ,\n",
" 'stage3_unit6_bn1_moving_var': ,\n",
" 'stage3_unit6_bn2_moving_mean': ,\n",
" 'stage3_unit6_bn2_moving_var': ,\n",
" 'stage3_unit6_bn3_moving_mean': ,\n",
" 'stage3_unit6_bn3_moving_var': ,\n",
" 'stage4_unit1_bn1_moving_mean': ,\n",
" 'stage4_unit1_bn1_moving_var': ,\n",
" 'stage4_unit1_bn2_moving_mean': ,\n",
" 'stage4_unit1_bn2_moving_var': ,\n",
" 'stage4_unit1_bn3_moving_mean': ,\n",
" 'stage4_unit1_bn3_moving_var': ,\n",
" 'stage4_unit2_bn1_moving_mean': ,\n",
" 'stage4_unit2_bn1_moving_var': ,\n",
" 'stage4_unit2_bn2_moving_mean': ,\n",
" 'stage4_unit2_bn2_moving_var': ,\n",
" 'stage4_unit2_bn3_moving_mean': ,\n",
" 'stage4_unit2_bn3_moving_var': ,\n",
" 'stage4_unit3_bn1_moving_mean':