https://github.com/idiap/bayesian-recurrence
A Bayesian Interpretation of Recurrence in Neural Networks
- Host: GitHub
- URL: https://github.com/idiap/bayesian-recurrence
- Owner: idiap
- Created: 2022-07-18T06:42:58.000Z (almost 4 years ago)
- Default Branch: main
- Last Pushed: 2023-10-18T17:43:22.000Z (over 2 years ago)
- Last Synced: 2025-04-12T03:27:09.674Z (about 1 year ago)
- Language: Python
- Size: 13.7 KB
- Stars: 6
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSES/BSD-3-Clause.txt
- Citation: CITATION.cff
README
# A Bayesian Interpretation of Recurrence in Neural Networks
This repository contains PyTorch implementations of the Bayesian recurrent units (BRUs)
defined in the following papers by A. Bittar and P. Garner:
- [A Bayesian Interpretation of the Light Gated Recurrent Unit](https://rc.signalprocessingsociety.org/conferences/icassp-2021/SPSICASSP21VID0356.html?source=IBP), ICASSP 2021.
- [Bayesian Recurrent Units and the Forward-Backward Algorithm](https://arxiv.org/abs/2207.10486), INTERSPEECH 2022.
Contact: abittar@idiap.ch
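The second paper's title relates recurrent units to the forward-backward algorithm. As background only, here is a minimal, generic sketch of the textbook forward-backward algorithm for a single binary hidden variable that follows a two-state Markov chain. It is not the BRU code in this repository, and it does not reproduce the papers' equations: the function names `forward_backward` and `scaled_likelihoods`, the parameters `p_stay_on`, `p_turn_on` and `prior_on`, and the use of per-step log-likelihood ratios as input are all illustrative assumptions. The forward pass alone yields filtered probabilities P(h_t = 1 | x_1..x_t), computed causally like a unidirectional recurrence; combining it with the backward pass yields smoothed probabilities that use the whole sequence.

```python
import math
from typing import List, Sequence, Tuple

# Illustrative sketch only, not the liBRU implementation in this repository.


def scaled_likelihoods(llr: float) -> List[float]:
    # Return [P(x | h=0), P(x | h=1)] up to a common scale factor, so that
    # their ratio is exp(llr); the scale cancels when we normalise.
    p_on = 1.0 / (1.0 + math.exp(-llr))
    return [1.0 - p_on, p_on]


def forward_backward(
    llrs: Sequence[float],
    p_stay_on: float,
    p_turn_on: float,
    prior_on: float = 0.5,
) -> Tuple[List[float], List[float]]:
    """Filtered and smoothed P(h_t = 1) for a hypothetical two-state chain.

    llrs[t]   : log P(x_t | h_t=1) - log P(x_t | h_t=0)  (assumed given)
    p_stay_on : P(h_t = 1 | h_{t-1} = 1)
    p_turn_on : P(h_t = 1 | h_{t-1} = 0)
    prior_on  : P(h_1 = 1)
    """
    # Transition matrix A[i][j] = P(h_t = j | h_{t-1} = i).
    A = [[1.0 - p_turn_on, p_turn_on],
         [1.0 - p_stay_on, p_stay_on]]
    liks = [scaled_likelihoods(e) for e in llrs]
    T = len(liks)

    # Forward pass: alpha[t][j] = P(h_t = j | x_1..x_t).
    alpha: List[List[float]] = []
    pred = [1.0 - prior_on, prior_on]
    for t in range(T):
        if t > 0:
            # Predict: propagate the previous posterior through the chain.
            pred = [sum(alpha[-1][i] * A[i][j] for i in range(2)) for j in range(2)]
        # Update: multiply by the likelihood and renormalise (Bayes' rule).
        unnorm = [pred[j] * liks[t][j] for j in range(2)]
        z = sum(unnorm)
        alpha.append([u / z for u in unnorm])

    # Backward pass: beta[t][i] proportional to P(x_{t+1}..x_T | h_t = i).
    beta = [[1.0, 1.0] for _ in range(T)]
    for t in range(T - 2, -1, -1):
        unnorm = [sum(A[i][j] * liks[t + 1][j] * beta[t + 1][j] for j in range(2))
                  for i in range(2)]
        z = sum(unnorm)
        beta[t] = [u / z for u in unnorm]

    # Combine: P(h_t = j | x_1..x_T) is proportional to alpha[t][j] * beta[t][j].
    filtered = [a[1] for a in alpha]
    smoothed = []
    for a, b in zip(alpha, beta):
        g0, g1 = a[0] * b[0], a[1] * b[1]
        smoothed.append(g1 / (g0 + g1))
    return filtered, smoothed


filtered, smoothed = forward_backward([0.0, 3.0], p_stay_on=0.9, p_turn_on=0.1)
print(filtered, smoothed)
```

In this example the chain is "sticky" (`p_stay_on=0.9`), so the strong evidence at the second step raises the smoothed probability for the first step above its filtered value of 0.5, while the filtered and smoothed values coincide at the last step.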
## Installation
```bash
git clone https://github.com/idiap/bayesian-recurrence.git
cd bayesian-recurrence
pip install -r requirements.txt
python setup.py install
```
## Usage
After installation, the Bayesian recurrent units can be imported as Python modules
and combined into networks that are used like any other PyTorch module. For example:
```python
import torch

from bayesian_recurrence.libru import liBRU

# Build a random input of shape (batch_size, nb_steps, nb_inputs),
# with values drawn uniformly from [0, 1)
batch_size = 4
nb_steps = 100
nb_inputs = 20
x = torch.rand(batch_size, nb_steps, nb_inputs)

# Define a liBRU network; layer_sizes lists the number of units in each layer
net = liBRU(
    nb_inputs,
    layer_sizes=[128, 128, 10],
    bidirectional=True,
    hidden_type='probs',
    normalization='batchnorm',
    use_bias=False,
    dropout=0.0,
)

# Pass the input tensor through the network
y = net(x)
```