{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Image Models (timm)\n", "\n", "> `timm` is a deep-learning library created by [Ross Wightman](https://twitter.com/wightmanr). It is a collection of state-of-the-art (SOTA) computer vision models, layers, utilities, optimizers, schedulers, data loaders and augmentations, along with training and validation scripts that can reproduce ImageNet training results.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "pip install timm\n", "```\n", "\n", "Or, for an editable install:\n", "\n", "```\n", "git clone https://github.com/rwightman/pytorch-image-models\n", "cd pytorch-image-models && pip install -e .\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to use" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create a model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import timm\n", "import torch\n", "\n", "# Create a ResNet-34 with randomly initialized weights\n", "model = timm.create_model('resnet34')\n", "# Run a dummy batch of one 3-channel 224x224 image through the model\n", "x = torch.randn(1, 3, 224, 224)\n", "model(x).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is that simple to create a model with `timm`. The `create_model` function is a factory method that can create any of the over 300 models available in the `timm` library." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create a model with pretrained weights, simply pass in `pretrained=True`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth\" to /Users/amanarora/.cache/torch/hub/checkpoints/resnet34-43635321.pth\n" ] } ], "source": [ "# The weights are downloaded on first use and cached locally\n", "pretrained_resnet_34 = timm.create_model('resnet34', pretrained=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create a model with a custom number of classes, pass in `num_classes`, set to the number of output classes you need." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 10])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import timm\n", "import torch\n", "\n", "# Create a ResNet-34 whose classifier outputs 10 classes\n", "model = timm.create_model('resnet34', num_classes=10)\n", "x = torch.randn(1, 3, 224, 224)\n", "model(x).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### List Models with Pretrained Weights\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`timm.list_models()` returns the complete list of models available in `timm`. To list only the models that have pretrained weights, pass `pretrained=True` to `list_models`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(271,\n", " ['adv_inception_v3',\n", " 'cspdarknet53',\n", " 'cspresnet50',\n", " 'cspresnext50',\n", " 'densenet121'])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Only list models that have pretrained weights\n", "avail_pretrained_models = timm.list_models(pretrained=True)\n", "len(avail_pretrained_models), avail_pretrained_models[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At the time of writing, there are a total of **271** models with pretrained weights available in `timm`!"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Search for Model Architectures by Wildcard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is also possible to search for model architectures with a wildcard pattern, as shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['densenet121',\n", " 'densenet121d',\n", " 'densenet161',\n", " 'densenet169',\n", " 'densenet201',\n", " 'densenet264',\n", " 'densenet264d_iabn',\n", " 'densenetblur121d',\n", " 'tv_densenet121']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# '*' matches any sequence of characters, so this returns every model name containing 'densenet'\n", "all_densenet_models = timm.list_models('*densenet*')\n", "all_densenet_models" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }