Lambda Networks implemented in PyTorch
- Host: GitHub
- URL: https://github.com/titu1994/lambda_networks_pt
- Owner: titu1994
- License: mit
- Created: 2021-02-21T21:41:44.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2021-02-22T06:53:10.000Z (over 4 years ago)
- Last Synced: 2025-03-25T05:34:09.742Z (7 months ago)
- Language: Python
- Size: 37.1 KB
- Stars: 13
- Watchers: 3
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Lambda Networks in PyTorch 
Lambda Networks from the paper [LambdaNetworks: Modeling Long-Range Interactions Without Attention](https://arxiv.org/abs/2102.08602) by Bello et al.
Adds support for Rank 3 and Rank 5 tensor inputs, as well as an initial implementation of ResNet-D with Lambda Convolutions. Code is adapted from the `lucidrains` implementation: https://github.com/lucidrains/lambda-networks
TODO:
- Add support for ResNet-RS

# Installation
```bash
pip install --upgrade git+https://github.com/titu1994/lambda_networks_pt.git
```

# Usage
The package contains the following modules:
- `lambda_module_1d.py`: Implements Lambda Network block for Rank 3 input data (B, C, T)
- `lambda_module_2d.py`: Implements Lambda Network block for Rank 4 input data (B, C, H, W)
- `lambda_module_3d.py`: Implements Lambda Network block for Rank 5 input data (B, C, D, H, W)
- `lambda_resnets.py`: Implements Lambda ResNets (using ResNet-D, not ResNet-RS!) for Rank 4 input data (B, C, H, W)

```python
import lambda_networks

# 1D Block, 2D Block or 3D Block
module = lambda_networks.LambdaLayer1D(
    dim=32,
    dim_out=64,
    m=None,            # Use a positive integer for Global Context using the Lambda Layer. Represents "m" in the paper.
    r=None,            # Use a positive integer for Local Context using the Lambda Convolution. Represents "r" in the paper.
    dim_k=16,          # Dimension of key/query.
    dim_intra=1,       # Intra-dimension "u" in the paper.
    heads=4,           # Number of heads. Represents "h" in the paper.
    implementation=0,  # Defaults to 0, which implements the paper version of the n-D Lambda using an (n+1)-D Convolution.
)

# Lambda ResNet-D
model = lambda_networks.resnet_18(
    lambda_m=False,    # bool. Whether global context should be used. If True, `input_size` must also be passed to compute the global context size per block.
    lambda_r=None,     # Optional[int]. If passed, computes Local Context using the Lambda Convolution.
    lambda_k=16,       # Dimension of key/query.
    lambda_u=1,        # Intra-dimension "u" in the paper.
    lambda_heads=4,    # Number of heads. Represents "h" in the paper.
    input_size=None,   # Optional[int]. Height and Width of the image. Must be passed if `lambda_m` is True.
)
```
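A minimal forward-pass sketch using the layers above is shown below. It is not taken from the repository: the choice of `r=7` for local context, the dummy tensor shapes, and the expected output shapes are assumptions based on the argument descriptions in the snippets above.

```python
import torch
import lambda_networks

# Rank 3 (B, C, T) lambda layer with local context (assumed receptive field r=7).
layer_1d = lambda_networks.LambdaLayer1D(
    dim=32, dim_out=64, r=7, dim_k=16, dim_intra=1, heads=4,
)
x = torch.randn(2, 32, 128)        # (B=2, C=32, T=128)
y = layer_1d(x)                    # expected: (2, 64, 128), channels mapped to dim_out

# Lambda ResNet-D with local context only, so `input_size` is not needed.
model = lambda_networks.resnet_18(lambda_r=7, lambda_k=16, lambda_u=1, lambda_heads=4)
img = torch.randn(2, 3, 224, 224)  # (B, C, H, W)
logits = model(img)
print(y.shape, logits.shape)
```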
# Tests
Tests perform CPU-only checks if no GPU is available. If GPUs are present, all tests also run once on `cuda:0`.
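As an illustration only (not code from the repository's test suite), a device-parametrized test in this spirit might look like the following sketch; the `LambdaLayer1D` arguments and the asserted output shape are assumptions.

```python
import pytest
import torch
import lambda_networks

# Always test on CPU; additionally test on cuda:0 when a GPU is available.
DEVICES = ["cpu"] + (["cuda:0"] if torch.cuda.is_available() else [])

@pytest.mark.parametrize("device", DEVICES)
def test_lambda_layer_1d_output_shape(device):
    # Hypothetical check: a local-context layer maps (B, dim, T) -> (B, dim_out, T).
    layer = lambda_networks.LambdaLayer1D(
        dim=32, dim_out=64, r=7, dim_k=16, dim_intra=1, heads=4,
    ).to(device)
    x = torch.randn(2, 32, 128, device=device)
    out = layer(x)
    assert out.shape == (2, 64, 128)
```

Run the full suite with: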
```bash
pytest tests/
```

# Requirements
- pytorch >= 1.7.1. Older versions might work but are not tested.
- einops: required for LambdaLayer2D/3D and Lambda ResNet-D; not required for LambdaLayer1D.