https://github.com/lucidrains/logavgexp-torch
Implementation of LogAvgExp for Pytorch
- Host: GitHub
- URL: https://github.com/lucidrains/logavgexp-torch
- Owner: lucidrains
- License: mit
- Created: 2022-01-31T16:13:06.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2022-03-28T17:36:40.000Z (over 2 years ago)
- Last Synced: 2024-10-15T00:16:35.088Z (2 months ago)
- Topics: logsumexp, pytorch
- Language: Python
- Homepage:
- Size: 15.6 KB
- Stars: 32
- Watchers: 2
- Forks: 2
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
## LogAvgExp - Pytorch
Implementation of LogAvgExp for Pytorch
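For reference, this is the operator as defined in the paper cited at the end of this README (Lowe et al., 2021), written for n elements and a temperature t (presumably what the `temp` argument in the examples controls). It is a logsumexp of the temperature-scaled inputs, shifted by log n so that it averages instead of sums:

```math
\operatorname{LogAvgExp}_t(x_1, \dots, x_n) = t \log\left(\frac{1}{n} \sum_{i=1}^{n} \exp\left(\frac{x_i}{t}\right)\right) = t \operatorname{LogSumExp}\left(\frac{x}{t}\right) - t \log n
```

```math
\lim_{t \to 0^+} \operatorname{LogAvgExp}_t(x) = \max_i x_i, \qquad \lim_{t \to \infty} \operatorname{LogAvgExp}_t(x) = \frac{1}{n} \sum_{i=1}^{n} x_i
```

So the temperature interpolates between max pooling (small t) and average pooling (large t), while staying differentiable throughout.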
## Install
```bash
$ pip install logavgexp-pytorch
```

## Usage
```python
import torch
from logavgexp_pytorch import logavgexp

# basically it is an improved logsumexp (differentiable max)
# normalized for length

x = torch.arange(1000)
y = logavgexp(x, dim = 0, temp = 0.01) # ~998.8

# more than 1 dimension

x = torch.randn(1, 2048, 5)
y = logavgexp(x, dim = 1, temp = 0.2) # (1, 5)

# keep dimension

x = torch.randn(1, 2048, 5)
y = logavgexp(x, dim = 1, temp = 0.2, keepdim = True) # (1, 1, 5)

# masking (False for mask out with large negative value)

x = torch.randn(1, 2048, 5)
m = torch.randint(0, 2, (1, 2048, 1)).bool()

y = logavgexp(x, mask = m, dim = 1, temp = 0.2, keepdim = True) # (1, 1, 5)
```
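To make the computation concrete, here is a dependency-free scalar sketch of `temp * log(mean(exp(x / temp)))` over a 1-D list. `logavgexp_ref` is a hypothetical helper written for illustration, not the library's implementation. Its optional boolean mask keeps `True` entries and drops `False` ones, like the `mask` argument above; dropping masked entries from both the sum and the count n is an assumption about how the normalization should behave.

```python
import math
from typing import Optional, Sequence


def logavgexp_ref(
    xs: Sequence[float],
    temp: float = 0.01,
    mask: Optional[Sequence[bool]] = None,
) -> float:
    """Hypothetical reference: temp * log(mean(exp(x / temp))) over kept entries."""
    if mask is None:
        mask = [True] * len(xs)
    # Assumption: masked-out (False) entries are dropped from both the sum and the count n.
    kept = [x / temp for x, keep in zip(xs, mask) if keep]
    if not kept:
        raise ValueError("at least one element must be unmasked")
    # Subtract the max before exponentiating so math.exp cannot overflow.
    m = max(kept)
    total = sum(math.exp(v - m) for v in kept)
    # logsumexp(x / temp) - log(n), rescaled by temp.
    return temp * (m + math.log(total) - math.log(len(kept)))


xs = [1.0, 2.0, 3.0]
print(logavgexp_ref(xs, temp=0.01))   # ≈ 2.989, just below max(xs) = 3.0
print(logavgexp_ref(xs, temp=100.0))  # ≈ 2.003, just above mean(xs) = 2.0
print(logavgexp_ref(xs, temp=0.01, mask=[True, True, False]))  # ≈ 1.993
```

Subtracting the maximum is the usual logsumexp stabilization: without it, `math.exp(1000.0 / 0.01)` would overflow, whereas the shifted form only ever exponentiates values less than or equal to zero.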
With learned temperature
```python
# learned temperature

import torch
from torch import nn
from logavgexp_pytorch import logavgexp

# keep a handle on the raw (log-space) parameter so it can be passed to an optimizer
log_temp = nn.Parameter(torch.ones(1) * -5)
learned_temp = log_temp.exp().clamp(min = 1e-8) # make sure temperature can't hit 0

x = torch.randn(1, 2048, 5)
y = logavgexp(x, temp = learned_temp, dim = 1) # (1, 5)
```

Or you can use the `LogAvgExp` class to handle the learned temperature parameter.
```python
import torch
from logavgexp_pytorch import LogAvgExp

logavgexp = LogAvgExp(
    temp = 0.01,
    dim = 1,
    learned_temp = True
)

x = torch.randn(1, 2048, 5)
y = logavgexp(x) # (1, 5)
```

## LogAvgExp2D
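Judging by the `(2 x 2) pooling` comment and the 64 to 32 spatial shape change in the example, `LogAvgExp2D` appears to apply LogAvgExp over each pooling window, much as `nn.MaxPool2d` takes a max and `nn.AvgPool2d` takes a mean. The hypothetical `logavgexp_pool2d` below sketches that idea on a single 2-D grid of floats, assuming square windows and no padding; it illustrates the concept and is not the layer's actual code.

```python
import math
from typing import List


def logavgexp_pool2d(
    grid: List[List[float]],
    kernel: int = 2,
    stride: int = 2,
    temp: float = 0.01,
) -> List[List[float]]:
    """Hypothetical sketch: LogAvgExp over each kernel x kernel window, no padding."""
    rows, cols = len(grid), len(grid[0])
    out = []
    for i in range(0, rows - kernel + 1, stride):
        out_row = []
        for j in range(0, cols - kernel + 1, stride):
            # Gather the window and scale by the temperature.
            window = [
                grid[i + di][j + dj] / temp
                for di in range(kernel)
                for dj in range(kernel)
            ]
            # Stable log-mean-exp: subtract the max before exponentiating.
            m = max(window)
            total = sum(math.exp(v - m) for v in window)
            out_row.append(temp * (m + math.log(total) - math.log(len(window))))
        out.append(out_row)
    return out


grid = [
    [1.0, 2.0, 5.0, 6.0],
    [3.0, 4.0, 7.0, 8.0],
    [0.0, 0.0, 1.0, 1.0],
    [0.0, 0.0, 1.0, 1.0],
]
print(logavgexp_pool2d(grid))  # ≈ [[3.986, 7.986], [0.0, 1.0]]
```

At `temp = 0.01` each output sits just below its window's maximum; raising `temp` pulls it toward the window mean. The library's layer operates on batched 4-D image tensors instead: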
```python
import torch
from logavgexp_pytorch import LogAvgExp2D

logavgexp_pool = LogAvgExp2D((2, 2), stride = 2) # (2 x 2) pooling

img = torch.randn(1, 16, 64, 64)
out = logavgexp_pool(img) # (1, 16, 32, 32)
```

## Todo
- [ ] look into a [Triton](https://github.com/openai/triton) optimized version
## Citations
```bibtex
@misc{lowe2021logavgexp,
title = {LogAvgExp Provides a Principled and Performant Global Pooling Operator},
author = {Scott C. Lowe and Thomas Trappenberg and Sageev Oore},
year = {2021},
eprint = {2111.01742},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```