https://github.com/csinva/max-activation-interpretation-pytorch
Code for creating maximal activation images (like Deep Dream) in pytorch with various regularizations / losses.
deep-dream deep-learning interpretability maximal-activation neural-network optimization pytorch regularization total-variation visualization-tools
- Host: GitHub
- URL: https://github.com/csinva/max-activation-interpretation-pytorch
- Owner: csinva
- Created: 2019-04-05T01:31:26.000Z (about 6 years ago)
- Default Branch: master
- Last Pushed: 2020-01-31T00:00:15.000Z (over 5 years ago)
- Last Synced: 2024-10-28T11:35:37.674Z (7 months ago)
- Topics: deep-dream, deep-learning, interpretability, maximal-activation, neural-network, optimization, pytorch, regularization, total-variation, visualization-tools
- Language: Jupyter Notebook
- Size: 188 KB
- Stars: 4
- Watchers: 3
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: readme.md
# maximal activation
- maximal activation is a simple technique that optimizes a model's input to maximize a chosen output response
- the code here (in `max_act.py`) gives a simple pytorch implementation of this technique
- the code also includes simple regularizers (e.g. total variation) for this method
- one example maximizes the class "peacock" for AlexNet (example image in the repository readme)
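The idea in the bullets above can be sketched as a short from-scratch loop: gradient-ascend the target logit while penalizing total variation. This is a minimal illustration with a toy model and hypothetical helper names, not the repo's `max_act.py` implementation:

```python
import torch
import torch.nn as nn

def total_variation(im):
    # sum of absolute differences between neighboring pixels (smoothness prior)
    tv_h = (im[..., 1:, :] - im[..., :-1, :]).abs().sum()
    tv_w = (im[..., :, 1:] - im[..., :, :-1]).abs().sum()
    return tv_h + tv_w

def maximize_activation(model, im_shape, class_num,
                        lr=0.1, num_iters=100, lambda_tv=1e-3):
    # start from a blank image and optimize it directly
    im = torch.zeros(im_shape, requires_grad=True)
    opt = torch.optim.SGD([im], lr=lr)
    losses = []
    for _ in range(num_iters):
        opt.zero_grad()
        out = model(im)
        # maximize the target logit (minimize its negative), regularize with TV
        loss = -out[0, class_num] + lambda_tv * total_variation(im)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return im.detach(), losses

# toy stand-in model; the repo targets real networks like AlexNet
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
im, losses = maximize_activation(model, (1, 1, 28, 28), class_num=5)
```

The loss should decrease as the input drifts toward whatever pattern most excites the chosen logit; swapping in a pretrained classifier yields Deep-Dream-style images.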
## sample usage
- install with `pip install git+https://github.com/csinva/max-activation-interpretation-pytorch`
```python
import sys
sys.path.append('../max_act')  # only needed if running from the repo instead of pip
import matplotlib.pyplot as plt
import torch

from max_act import maximize_im, maximize_im_simple
import visualize_ims as viz

device = 'cuda'
model = model.to(device)  # `model` is your trained network, e.g. AlexNet

class_num = 5
im_shape = (1, 1, 28, 28)  # (1, 3, 224, 224) for imagenet
im = torch.zeros(im_shape, requires_grad=True, device=device)
ims_opt, losses = maximize_im_simple(model, im, class_num=class_num, lr=1e-5,
                                     num_iters=int(1e3), lambda_tv=1e-1, lambda_pnorm=1e-1)

# visualize the optimized images and the loss curve
viz.show(ims_opt[::2])
plt.show()
plt.plot(losses)
plt.show()
```