https://github.com/francescosaveriozuppichini/mirror

Visualisation tool for CNNs in pytorch
deep-learning neural-network python pytorch visualis


{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mirror\n",
"## Pytorch CNN Visualisation Tool\n",
"*Francesco Saverio Zuppichini*\n",
"\n",
"This is a raw beta so expect lots of things to change and improve over time.\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/mirror.gif?raw=true)\n",
"\n",
"An interactive version of this tutorial can be [found here](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/README.ipynb)\n",
"\n",
"### Getting started\n",
"\n",
"To install mirror run\n",
"\n",
"```\n",
"pip install git+https://github.com/FrancescoSaverioZuppichini/mirror.git\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting Started\n",
"\n",
"A basic example"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" * Serving Flask app \"mirror.App\" (lazy loading)\n",
" * Environment: production\n",
" WARNING: Do not use the development server in a production environment.\n",
" Use a production WSGI server instead.\n",
" * Debug mode: off\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" * Running on http://0.0.0.0:5000/ (Press CTRL+C to quit)\n",
"127.0.0.1 - - [10/Sep/2019 21:10:50] \"GET /api/inputs HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [10/Sep/2019 21:10:50] \"GET /api/model HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [10/Sep/2019 21:10:50] \"GET /api/visualisation HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [10/Sep/2019 21:10:50] \"GET /api/model/image/-9223372036580704890/-9223372036580704890/2041159775650575674/%3Cbuilt-in%20function%20id%3E/0 HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [10/Sep/2019 21:10:50] \"GET /api/model/image/-9223372036580704890/-9223372036580704890/2041180666371503418/%3Cbuilt-in%20function%20id%3E/1 HTTP/1.1\" 200 -\n",
"127.0.0.1 - - [10/Sep/2019 21:10:51] \"GET /manifest.json HTTP/1.1\" 404 -\n"
]
}
],
"source": [
"from mirror import mirror\n",
"from mirror.visualisations.web import *\n",
"from PIL import Image\n",
"from torchvision.models import resnet101, resnet18, vgg16, alexnet\n",
"from torchvision.transforms import ToTensor, Resize, Compose\n",
"\n",
"# create a model\n",
"model = vgg16(pretrained=True)\n",
"# open some images\n",
"cat = Image.open(\"./cat.jpg\")\n",
"dog_and_cat = Image.open(\"./dog_and_cat.jpg\")\n",
"# resize the image and make it a tensor\n",
"to_input = Compose([Resize((224, 224)), ToTensor()])\n",
"# call mirror with the inputs and the model\n",
"mirror([to_input(cat), to_input(dog_and_cat)], model, visualisations=[BackProp, GradCam, DeepDream])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It will automatic open a new tab in your browser\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/mirror.png?raw=true)\n",
"\n",
"On the left you can see your model tree structure, by clicking on one layer all his children are showed. On the right there are the visualisation settings. You can select your input by clicking on the bottom tab.\n",
"\n",
"![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/mirror/master/resources/inputs.png)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Available Visualisations\n",
"All visualisation available for the web app are inside `.mirror.visualisations.web`.\n",
"### Weights\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/weights.png?raw=true)\n",
"### Deep Dream\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/deepdream.png?raw=true)\n",
"### Back Prop / Guide Back Prop\n",
"By clicking on the radio button 'guide', all the relus negative output will be set to zero producing a nicer looking image\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/backprop.png?raw=true)\n",
"## Grad Cam / Guide Grad Cam\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/grad_cam.png?raw=true)\n",
"\n",
"### Using with Tensors\n",
"If you want, you can use the vanilla version of each visualisation by importing them from `.mirror.visualisation.core`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from mirror.visualisations.core import GradCam\n",
"\n",
"# create a model\n",
"model = vgg16(pretrained=True)\n",
"# open some images\n",
"cat = Image.open(\"./cat.jpg\")\n",
"dog_and_cat = Image.open(\"./dog_and_cat.jpg\")\n",
"# resize the image and make it a tensor\n",
"to_input = Compose([Resize((224, 224)), ToTensor()])\n",
"\n",
"cam = GradCam(model, device='cpu')\n",
"cam(to_input(cat).unsqueeze(0), None) # will return the output image and some additional information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a Visualisation\n",
"To create a visualisation you first have to subclass the `Visualisation` class and override the`__call__` method to return an image and, if needed, additional informations. The following example creates a custom visualisation that just repeat the input `repeat` times. So"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from mirror.visualisations.core import Visualisation\n",
"\n",
"class RepeatInput(Visualisation):\n",
"\n",
" def __call__(self, inputs, layer, repeat=1):\n",
" return inputs.repeat(repeat, 1, 1, 1), None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This class repeats the input for `repeat` times.\n",
"### Connect to the web interface\n",
"To connect our fancy visualisation to the web interface, we have to create a `WebInterface`. Easily, we can use `WebInterface.from_visualisation` to generate the communication channel between our visualisation and the web app. \n",
"\n",
"It follows and example"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'RepeatInput' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mvisualisation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mWebInterface\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_visualisation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRepeatInput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Repeat'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mNameError\u001b[0m: name 'RepeatInput' is not defined"
]
}
],
"source": [
"from mirror.visualisations.web import WebInterface\n",
"from functools import partial\n",
"\n",
"params = {'repeat' : {\n",
" 'type' : 'slider',\n",
" 'min' : 1,\n",
" 'max' : 100,\n",
" 'value' : 2,\n",
" 'step': 1,\n",
" 'params': {}\n",
" }\n",
" }\n",
"\n",
"\n",
"visualisation = partial(WebInterface.from_visualisation, RepeatInput, params=params, name='Repeat')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we import `WebInterface` and `partial`. Then, we create a dictionary where each **they key is the visualisation parameter name**. In our example, `RepeatInput` takes a parameter called `repeat`, thus we have to define a dictionary `{ 'repeat' : { ... }' }`. \n",
"\n",
"The value of that dictionary is the configuration for one of the basic UI elements: *slider*, *textfield* and *radio*. \n",
"\n",
"The input is stored in the `value` slot.\n",
"\n",
"Then we call `WebInterface.from_visualisation` by passing the visualisation, the params and the name. We need to wrap this function using `partial` since `mirror` will need to dynamically pass some others parameters, the current layer and the input, at run time.\n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/repeat_slider.jpg?raw=true)\n",
"The final result is \n",
"\n",
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/repeat_example.jpg?raw=true)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/master/resources/dummy.jpg?raw=true)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Change the front-end\n",
"All the front-end is developed usin [React](https://reactjs.org/) and [Material-UI](https://material-ui.com/), two very known frameworks, making easier for anybody to contribuite.\n",
"\n",
"You can customise the front-end by changing the source code in `mirror/client`. After that, you need to build the react app and move the file to the server static folder.\n",
"\n",
"**I was not able to serve the static file directly from the /mirror/client/build folder** if you know how to do it any pull request is welcome :)\n",
"\n",
"```\n",
"cd ./mirror/mirror/client // assuming the root folder is called mirror\n",
"npm run build\n",
"```\n",
"Then you need to move the fiels from the `mirror/mirror/client/build` folder to `mirror/mirror`. You can remove all the files in `mirror/mirro/static`\n",
"```\n",
"mv ./build/static ../ && cp ./build/* ../static/\n",
"```\n",
"\n",
"### TODO\n",
"- [x] Cache reused layer \n",
"- [x] Make a generic abstraction of a visualisation in order to add more features \n",
"- [x] Add dropdown as parameter\n",
"- [x] Add text field\n",
"- [x] Support multiple inputs\n",
"- [x] Support multiple models\n",
"- Add all visualisation present here https://github.com/utkuozbulak/pytorch-cnn-visualizations\n",
" * [x] [Gradient visualization with vanilla backpropagation](#gradient-visualization)\n",
" * [x] [Gradient visualization with guided backpropagation](#gradient-visualization) [1]\n",
" * [x] [Gradient visualization with saliency maps](#gradient-visualization) [4]\n",
" * [x] [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]\n",
" * [x] [Deep dream](#deep-dream) [10]\n",
" * [ ] [Class specific image generation](#class-specific-image-generation) [4]\n",
" * [x] [Grad Cam](https://arxiv.org/abs/1610.02391)\n",
"\n",
"- [ ] Add a `output_transformation` params for each visualisation to allow better customisation \n",
"- [ ] Add a `input_transformation` params for each visualisation to allow better customisation "
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
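
The README's "Connect to the web interface" step wraps `WebInterface.from_visualisation` in `partial` because some arguments are only known at run time. The sketch below illustrates that idea with the standard library only. It does not import mirror: `from_visualisation_stub` is a hypothetical stand-in, and its signature (visualisation first, then a layer and an input supplied later) is an assumption made purely for illustration; the real method may take different arguments.

```python
from functools import partial


def from_visualisation_stub(visualisation, layer, inputs, params=None, name=None):
    """Hypothetical stand-in for WebInterface.from_visualisation.

    The real signature may differ; this only shows how arguments bound
    early by `partial` combine with arguments supplied later.
    """
    return {
        "visualisation": visualisation,
        "layer": layer,
        "inputs": inputs,
        "params": params,
        "name": name,
    }


# Same parameter dictionary as in the README example.
params = {
    "repeat": {
        "type": "slider",
        "min": 1,
        "max": 100,
        "value": 2,
        "step": 1,
        "params": {},
    }
}

# Bind what is known up front: the visualisation, its params and its name.
# A plain string stands in for the RepeatInput class.
deferred = partial(from_visualisation_stub, "RepeatInput", params=params, name="Repeat")

# Later, the caller supplies the remaining run-time arguments
# (placeholder strings here for the current layer and the input).
interface = deferred("conv1", "cat_tensor")

print(interface["name"], interface["layer"], interface["params"]["repeat"]["value"])
```

Under these assumptions, `deferred` plays the role of the `visualisation` object built in the README: it carries the fixed configuration and stays callable until the remaining arguments are available.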