https://github.com/lucidrains/siren-pytorch
PyTorch implementation of SIREN, from the paper "Implicit Neural Representations with Periodic Activation Functions"
activation-functions artificial-intelligence deep-learning
- Host: GitHub
- URL: https://github.com/lucidrains/siren-pytorch
- Owner: lucidrains
- License: mit
- Created: 2020-06-19T19:39:13.000Z (about 6 years ago)
- Default Branch: master
- Last Pushed: 2023-07-28T14:59:08.000Z (almost 3 years ago)
- Last Synced: 2025-04-04T01:05:24.246Z (about 1 year ago)
- Topics: activation-functions, artificial-intelligence, deep-learning
- Language: Python
- Homepage:
- Size: 15.6 KB
- Stars: 482
- Watchers: 12
- Forks: 51
- Open Issues: 4
Metadata Files:
- Readme: README.md
- License: LICENSE
## SIREN in Pytorch
[PyPI version](https://badge.fury.io/py/siren-pytorch)
PyTorch implementation of SIREN, from the paper "Implicit Neural Representations with Periodic Activation Functions"
## Install
```bash
$ pip install siren-pytorch
```
## Usage
A SIREN-based multi-layer neural network
```python
import torch
from torch import nn
from siren_pytorch import SirenNet
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3) <- rgb value
```
A single SIREN layer
```python
import torch
from siren_pytorch import Siren
neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)
```
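To make the layer's computation concrete, here is a dependency-free sketch of a single SIREN layer, `sin(w0 * (W x + b))`, with weights drawn from the uniform bounds described in the SIREN paper. The helper names `init_siren_weights` and `siren_layer` are hypothetical and written for illustration only; the library's internals may differ in detail.

```python
import math
import random

def init_siren_weights(dim_in, dim_out, w0=1.0, c=6.0, is_first=False, seed=0):
    """Sample a (dim_out x dim_in) weight matrix (illustrative, not library code)."""
    rng = random.Random(seed)
    # Paper's scheme: first layer U(-1/dim_in, 1/dim_in),
    # later layers U(-sqrt(c/dim_in)/w0, +sqrt(c/dim_in)/w0)
    bound = (1.0 / dim_in) if is_first else (math.sqrt(c / dim_in) / w0)
    return [[rng.uniform(-bound, bound) for _ in range(dim_in)] for _ in range(dim_out)]

def siren_layer(x, weights, bias, w0=1.0):
    """Compute sin(w0 * (W x + b)) for a single input vector x."""
    out = []
    for row, b in zip(weights, bias):
        pre = sum(w * xi for w, xi in zip(row, x)) + b
        out.append(math.sin(w0 * pre))
    return out

weights = init_siren_weights(dim_in=3, dim_out=4, w0=30.0, is_first=True)
hidden = siren_layer([0.1, -0.2, 0.3], weights, bias=[0.0] * 4, w0=30.0)
```

Every output lies in `[-1, 1]` because it passes through a sine, which is part of why the initialization bounds depend on `w0` and the input width.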
Sine activation (a thin wrapper around `torch.sin` that scales its input by `w0`)
```python
import torch
from siren_pytorch import Sine
act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
```
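Assuming `Sine(w0)` computes `sin(w0 * x)` elementwise, as in the SIREN paper, the effect of `w0` can be sketched on plain Python lists. `sine_activation` is a hypothetical helper for illustration and is not part of `siren_pytorch`.

```python
import math

def sine_activation(values, w0=1.0):
    """Apply sin(w0 * x) to each element (illustrative helper, not library code)."""
    return [math.sin(w0 * x) for x in values]

# w0 scales the input frequency before the sine is taken
low_freq = sine_activation([0.0, math.pi / 2], w0=1.0)
high_freq = sine_activation([0.0, math.pi / 2], w0=30.0)
```

A larger `w0` makes the activation oscillate faster over the same input range, which is why it is exposed as a hyperparameter.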
`SirenWrapper` trains a given `SirenNet` on a single image of a specified height and width, and can then generate that image from the trained network.
```python
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()
# after much training ...
# simply invoke the wrapper without passing in anything
pred_img = wrapper() # (1, 3, 256, 256)
```
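Conceptually, fitting an image with a SIREN means regressing pixel values from pixel coordinates. The sketch below assumes coordinates are normalized to `[-1, 1]` along each axis and that the training loss is a mean squared error. Both are assumptions about what the wrapper does internally, and `make_grid` and `mse` are hypothetical helpers written for illustration.

```python
def make_grid(height, width):
    """Return (y, x) coordinates in [-1, 1] for each pixel, in row-major order."""
    def linspace(n):
        # n evenly spaced points from -1 to 1 (a single point maps to 0)
        return [0.0] if n == 1 else [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    ys, xs = linspace(height), linspace(width)
    return [(y, x) for y in ys for x in xs]

def mse(pred, target):
    """Mean squared error between two equal-length lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

coords = make_grid(4, 4)              # one coordinate pair per pixel
loss = mse([0.5, 0.5], [0.0, 1.0])    # toy prediction vs. target
```

Because the grid depends only on the image size, generating the image after training needs no input beyond the fixed coordinates.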
## Modulation with Latent Code
Mehta et al. (2021, cited below) propose conditioning a SIREN on a latent code by passing the latent vector through a modulator feedforward network, where each modulator layer's hidden state is multiplied elementwise with the output of the corresponding SIREN layer.
To use this, set the extra keyword `latent_dim` on the `SirenWrapper` and pass the latent vector when calling it.
```python
import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper
net = SirenNet(
    dim_in = 2,                        # input dimension, ex. 2d coor
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, ex. rgb value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)
loss = wrapper(img, latent = latent)
loss.backward()

# after much training ...
# invoke the wrapper with only the latent (no image)

pred_img = wrapper(latent = latent) # (1, 3, 256, 256)
```
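The modulation described in this section can be sketched without any dependencies: a modulator layer maps the latent to a hidden vector, and that vector scales the sine activations of the matching SIREN layer elementwise. The ReLU modulator follows the description in Mehta et al.; treating it as what the library does internally is an assumption, and `relu_layer` and `modulated_sine` are hypothetical helpers.

```python
import math

def relu_layer(z, weights, bias):
    """Modulator layer: ReLU(W z + b) for one latent vector z (illustrative)."""
    return [max(0.0, sum(w * zi for w, zi in zip(row, z)) + b)
            for row, b in zip(weights, bias)]

def modulated_sine(pre_activation, modulation, w0=1.0):
    """Scale each sine activation by the matching modulator hidden unit."""
    return [m * math.sin(w0 * p) for p, m in zip(pre_activation, modulation)]

# toy example: 2-dim latent, 2 hidden units, identity modulator weights
latent = [1.0, -1.0]
mod = relu_layer(latent, weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0])
out = modulated_sine([math.pi / 2, math.pi / 2], mod)
```

In this toy case the second modulator unit is zeroed by the ReLU, so the latent switches the second SIREN unit off entirely while leaving the first one unchanged.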
## Citations
```bibtex
@misc{sitzmann2020implicit,
title = {Implicit Neural Representations with Periodic Activation Functions},
author = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
year = {2020},
eprint = {2006.09661},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
```bibtex
@misc{mehta2021modulated,
title = {Modulated Periodic Activations for Generalizable Local Functional Representations},
author = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
year = {2021},
eprint = {2104.03960},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```