Ecosyste.ms: Awesome

An open API service indexing awesome lists of open source software.

https://github.com/lucidrains/etsformer-pytorch

Implementation of ETSformer, a state-of-the-art time-series Transformer, in Pytorch

artificial-intelligence deep-learning exponential-smoothing time-series transformers

## ETSformer - Pytorch

Implementation of ETSformer, a state-of-the-art time-series Transformer, in Pytorch

## Install

```bash
$ pip install etsformer-pytorch
```

## Usage

```python
import torch
from etsformer_pytorch import ETSFormer

model = ETSFormer(
    time_features = 4,
    model_dim = 512,          # in paper they use 512
    embed_kernel_size = 3,    # kernel size for 1d conv for input embedding
    layers = 2,               # number of encoder and corresponding decoder layers
    heads = 8,                # number of exponential smoothing attention heads
    K = 4,                    # num frequencies with highest amplitude to keep (attend to)
    dropout = 0.2             # dropout (in paper they did 0.2)
)

timeseries = torch.randn(1, 1024, 4)

pred = model(timeseries, num_steps_forecast = 32) # (1, 32, 4) - (batch, num steps forecast, num time features)
```
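
The snippet above only runs a forward pass. As a minimal training sketch (not part of the upstream README), the forecaster can be fit with a plain MSE objective on a held-out horizon; the Adam optimizer, learning rate, and the 1024-step context / 32-step target split below are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from etsformer_pytorch import ETSFormer

# hypothetical training step (not from the upstream README):
# fit the forecaster with an MSE loss on the last 32 steps of each series
model = ETSFormer(
    time_features = 4,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)   # assumed optimizer and learning rate

series = torch.randn(8, 1056, 4)                      # toy batch: 8 series, 1056 steps, 4 features
context, target = series[:, :1024], series[:, 1024:]  # condition on 1024 steps, forecast the last 32

pred = model(context, num_steps_forecast = 32)        # (8, 32, 4)
loss = F.mse_loss(pred, target)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```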

To use ETSFormer for classification, wrap it in `ClassificationWrapper`, which applies cross-attention pooling over all latents and the level output:

```python
import torch
from etsformer_pytorch import ETSFormer, ClassificationWrapper

etsformer = ETSFormer(
    time_features = 1,
    model_dim = 512,
    embed_kernel_size = 3,
    layers = 2,
    heads = 8,
    K = 4,
    dropout = 0.2
)

adapter = ClassificationWrapper(
    etsformer = etsformer,
    dim_head = 32,
    heads = 16,
    dropout = 0.2,
    level_kernel_size = 5,
    num_classes = 10
)

timeseries = torch.randn(1, 1024)

logits = adapter(timeseries) # (1, 10)
```
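
As a follow-on sketch (again not from the upstream README), the classification adapter can be trained with a standard cross-entropy loss. The snippet below continues from the `adapter` instance defined above; the optimizer settings and the toy integer labels are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# hypothetical training step for the adapter defined above
# (assumptions: Adam optimizer, cross-entropy objective, integer class labels)
optimizer = torch.optim.Adam(adapter.parameters(), lr = 1e-4)

timeseries = torch.randn(16, 1024)     # toy batch of 16 univariate series (time_features = 1)
labels = torch.randint(0, 10, (16,))   # toy labels for the 10 classes

logits = adapter(timeseries)           # (16, 10)
loss = F.cross_entropy(logits, labels)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```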

## Citation

```bibtex
@misc{woo2022etsformer,
    title = {ETSformer: Exponential Smoothing Transformers for Time-series Forecasting},
    author = {Gerald Woo and Chenghao Liu and Doyen Sahoo and Akshat Kumar and Steven Hoi},
    year = {2022},
    eprint = {2202.01381},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
```