SIREN in Pytorch

Pytorch implementation of SIREN - Implicit Neural Representations with Periodic Activation Functions

Install

$ pip install siren-pytorch

Usage

A SIREN-based multi-layer neural network

import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(
    dim_in = 2,                        # input dimension, e.g. a 2D coordinate
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, e.g. an RGB value
    num_layers = 5,                    # number of layers
    final_activation = nn.Sigmoid(),   # activation of the final layer (nn.Identity() for direct output)
    w0_initial = 30.                   # different signals may require a different omega_0 in the first layer - this is a hyperparameter
)

coor = torch.randn(1, 2)
net(coor) # (1, 3) <- RGB value
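
Since the network maps coordinates to values, you can also query it on a full pixel grid to render an image. A minimal sketch (the grid construction and the normalization to [-1, 1] are assumptions for illustration; the SirenWrapper below handles this for you):

import torch
from torch import nn
from siren_pytorch import SirenNet

net = SirenNet(dim_in = 2, dim_hidden = 256, dim_out = 3, num_layers = 5, final_activation = nn.Sigmoid(), w0_initial = 30.)

# build a (256 * 256, 2) grid of coordinates, normalized to [-1, 1] (assumed convention)
axis = torch.linspace(-1, 1, 256)
grid = torch.stack(torch.meshgrid(axis, axis, indexing = 'ij'), dim = -1).reshape(-1, 2)

rgb = net(grid)                   # (256 * 256, 3)
image = rgb.reshape(256, 256, 3)  # one RGB value per pixel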

One SIREN layer

import torch
from siren_pytorch import Siren

neuron = Siren(
    dim_in = 3,
    dim_out = 256
)

coor = torch.randn(1, 3)
neuron(coor) # (1, 256)

Sine activation (just a wrapper around torch.sin)

import torch
from siren_pytorch import Sine

act = Sine(1.)
coor = torch.randn(1, 2)
act(coor)
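
Based on the description above, Sine(w0) should behave like applying torch.sin(w0 * x). A quick sanity check (this equivalence is an assumption drawn from the description, not a documented API guarantee):

import torch
from siren_pytorch import Sine

w0 = 30.
act = Sine(w0)
x = torch.randn(1, 2)
assert torch.allclose(act(x), torch.sin(w0 * x))  # Sine(w0) applies sin(w0 * x)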

A wrapper that trains a given SirenNet on a specific image of the specified height and width, and can then be used to generate that image.

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, e.g. a 2D coordinate
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, e.g. an RGB value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require a different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    image_width = 256,
    image_height = 256
)

img = torch.randn(1, 3, 256, 256)
loss = wrapper(img)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything

pred_img = wrapper() # (1, 3, 256, 256)
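
A minimal training loop sketch using the wrapper and img from above (the optimizer choice, learning rate, and number of steps are illustrative assumptions, not prescriptions from the library):

opt = torch.optim.Adam(wrapper.parameters(), lr = 1e-4)  # assumed optimizer and learning rate

for _ in range(1000):          # "much training"
    opt.zero_grad()
    loss = wrapper(img)        # reconstruction loss against the target image
    loss.backward()
    opt.step()

pred_img = wrapper()           # (1, 3, 256, 256)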

Modulation with Latent Code

A follow-up paper (Mehta et al., 2021, cited below) proposes that the best way to condition a Siren with a latent code is to pass the latent vector through a modulator feedforward network, where each modulator layer's hidden state is elementwise multiplied with the hidden state of the corresponding Siren layer.
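
Conceptually, the modulator is a small feedforward network that maps the latent code to one modulation vector per Siren layer. A minimal sketch of the idea (a hypothetical ModulatorSketch with plain Linear + ReLU layers, not the library's internal implementation):

import torch
from torch import nn

class ModulatorSketch(nn.Module):
    def __init__(self, latent_dim, dim_hidden, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(latent_dim if i == 0 else dim_hidden, dim_hidden),
                nn.ReLU()
            ) for i in range(num_layers)
        ])

    def forward(self, latent):
        hiddens = []
        x = latent
        for layer in self.layers:
            x = layer(x)
            hiddens.append(x)   # one modulation vector per Siren layer
        return hiddens          # each is multiplied elementwise with the matching Siren hidden state

mods = ModulatorSketch(latent_dim = 512, dim_hidden = 256, num_layers = 5)(torch.randn(512))
# len(mods) == 5, each of shape (256,)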

You can use this simply by passing an extra latent_dim keyword argument to the SirenWrapper

import torch
from torch import nn
from siren_pytorch import SirenNet, SirenWrapper

net = SirenNet(
    dim_in = 2,                        # input dimension, e.g. a 2D coordinate
    dim_hidden = 256,                  # hidden dimension
    dim_out = 3,                       # output dimension, e.g. an RGB value
    num_layers = 5,                    # number of layers
    w0_initial = 30.                   # different signals may require a different omega_0 in the first layer - this is a hyperparameter
)

wrapper = SirenWrapper(
    net,
    latent_dim = 512,
    image_width = 256,
    image_height = 256
)

latent = nn.Parameter(torch.zeros(512).normal_(0, 1e-2))
img = torch.randn(1, 3, 256, 256)

loss = wrapper(img, latent = latent)
loss.backward()

# after much training ...
# simply invoke the wrapper without passing in anything

pred_img = wrapper(latent = latent) # (1, 3, 256, 256)
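
Since the latent code lives outside the wrapper, one way to train is to optimize it jointly with the network, reusing the wrapper, img, and latent from above (the joint setup, optimizer, and learning rate are illustrative assumptions):

opt = torch.optim.Adam(list(wrapper.parameters()) + [latent], lr = 1e-4)  # assumed settings

for _ in range(1000):
    opt.zero_grad()
    loss = wrapper(img, latent = latent)
    loss.backward()
    opt.step()

pred_img = wrapper(latent = latent)   # (1, 3, 256, 256)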

Citations

@misc{sitzmann2020implicit,
    title   = {Implicit Neural Representations with Periodic Activation Functions},
    author  = {Vincent Sitzmann and Julien N. P. Martel and Alexander W. Bergman and David B. Lindell and Gordon Wetzstein},
    year    = {2020},
    eprint  = {2006.09661},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mehta2021modulated,
    title   = {Modulated Periodic Activations for Generalizable Local Functional Representations}, 
    author  = {Ishit Mehta and Michaël Gharbi and Connelly Barnes and Eli Shechtman and Ravi Ramamoorthi and Manmohan Chandraker},
    year    = {2021},
    eprint  = {2104.03960},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}