https://github.com/harvardnlp/pytorch-struct

Fast, general, and tested differentiable structured prediction in PyTorch

# Torch-Struct: Structured Prediction Library

![Tests](https://github.com/harvardnlp/pytorch-struct/workflows/Tests/badge.svg)
[![Coverage Status](https://coveralls.io/repos/github/harvardnlp/pytorch-struct/badge.svg?branch=master)](https://coveralls.io/github/harvardnlp/pytorch-struct?branch=master)



A library of tested GPU implementations of core structured prediction algorithms for deep learning applications.

* HMM / LinearChain-CRF
* HSMM / SemiMarkov-CRF
* Dependency Tree-CRF
* PCFG Binary Tree-CRF
* ...

Designed to be used as efficient batched layers in other PyTorch code.

[Tutorial paper](https://arxiv.org/abs/2002.00876) describing the methodology.

## Getting Started

```python
# Notebook cell syntax: the leading "!" runs the command in a shell.
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
```

```python
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt


def show(x):
    # Render a 2-D tensor as an image.
    plt.imshow(x.detach())
```

```python
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5)
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])
```

![png](README_files/README_4_0.png)

```python
# Compute marginals
show(dist.marginals[0])
```

![png](README_files/README_5_0.png)
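Marginals of a structured distribution equal the gradient of the log partition function with respect to the log potentials, which is what makes them a natural fit for a differentiable library. The sketch below checks that identity numerically on a tiny linear chain using brute-force enumeration. It is plain Python and independent of `torch_struct`; the `pot[n][a][b]` layout, the toy values, and the helper names are assumptions made for this illustration.

```python
import itertools
import math


def sequence_score(pot, seq):
    """Sum of the log potentials along one label sequence."""
    return sum(pot[n][seq[n]][seq[n + 1]] for n in range(len(pot)))


def log_partition(pot, num_labels):
    """Brute-force log partition: log-sum-exp over every label sequence."""
    length = len(pot) + 1
    scores = [
        sequence_score(pot, seq)
        for seq in itertools.product(range(num_labels), repeat=length)
    ]
    top = max(scores)
    return top + math.log(sum(math.exp(s - top) for s in scores))


def edge_marginal(pot, num_labels, n, a, b):
    """Probability that position n has label a and position n + 1 has label b."""
    length = len(pot) + 1
    log_z = log_partition(pot, num_labels)
    total = 0.0
    for seq in itertools.product(range(num_labels), repeat=length):
        if seq[n] == a and seq[n + 1] == b:
            total += math.exp(sequence_score(pot, seq) - log_z)
    return total


# Hypothetical toy potentials: 3 transitions (4 positions), 2 labels.
pot = [
    [[0.5, -1.0], [0.2, 0.3]],
    [[-0.4, 0.9], [1.1, 0.0]],
    [[0.0, 0.7], [-0.6, 0.4]],
]

# Central finite difference of the log partition with respect to pot[1][0][1].
eps = 1e-5
pot[1][0][1] += eps
up = log_partition(pot, 2)
pot[1][0][1] -= 2 * eps
down = log_partition(pot, 2)
pot[1][0][1] += eps  # restore the original value
gradient = (up - down) / (2 * eps)

marginal = edge_marginal(pot, 2, 1, 0, 1)
print(gradient, marginal)  # the two values should agree closely
```

Enumeration is exponential in the sequence length, so this is only a correctness check, not a way to compute marginals in practice.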

```python
# Compute argmax
show(dist.argmax.detach()[0])
```

![png](README_files/README_6_0.png)

```python
# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
# Log-probability of the highest-scoring structure (its score minus log_partition)
max_score = dist.log_prob(dist.argmax)
```

```python
# Compute samples
show(dist.sample((1,)).detach()[0, 0])
```

![png](README_files/README_8_0.png)
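For intuition on how exact sampling from a structured distribution can work, here is a minimal sketch of classical forward-filtering backward-sampling on a linear chain. This is a swapped-in textbook technique in plain Python, not the library's implementation, and it uses a chain instead of a dependency tree to stay short. The function name `sample_chain`, the `pot[n][a][b]` layout, and the toy values are assumptions.

```python
import math
import random


def log_sum_exp(values):
    """Numerically stable log of a sum of exponentials."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def sample_chain(pot, rng):
    """Draw one label sequence with probability proportional to exp(score).

    pot[n][a][b] is the log potential for label a at position n followed
    by label b at position n + 1 (a layout chosen for this sketch).
    """
    num_labels = len(pot[0])

    # Forward pass: alphas[n][b] is the log total weight of all prefixes
    # that end with label b at position n.
    alphas = [[0.0] * num_labels]
    for step in pot:
        prev = alphas[-1]
        alphas.append([
            log_sum_exp([prev[a] + step[a][b] for a in range(num_labels)])
            for b in range(num_labels)
        ])

    def draw(log_weights):
        # Sample an index in proportion to exp(log_weights).
        top = max(log_weights)
        weights = [math.exp(w - top) for w in log_weights]
        return rng.choices(range(num_labels), weights=weights)[0]

    # Backward pass: sample the last label, then each earlier label
    # conditioned on the label already chosen to its right.
    seq = [draw(alphas[-1])]
    for n in range(len(pot) - 1, -1, -1):
        nxt = seq[0]
        seq.insert(0, draw([alphas[n][a] + pot[n][a][nxt] for a in range(num_labels)]))
    return seq


# Hypothetical toy potentials: 3 transitions (4 positions), 2 labels.
pot = [
    [[0.5, -1.0], [0.2, 0.3]],
    [[-0.4, 0.9], [1.1, 0.0]],
    [[0.0, 0.7], [-0.6, 0.4]],
]
sample = sample_chain(pot, random.Random(0))
print(sample)  # a list of 4 labels, each 0 or 1
```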

```python
# Padding and masking are built into the library.
dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])
```

![png](README_files/README_9_0.png)

![png](README_files/README_9_1.png)
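The idea behind length masking can be sketched in a few lines: entries of a padded batch that lie beyond a sequence's true length must not influence its result. The example below is plain Python on a linear chain, not a dependency tree, and it does not reflect how the library implements masking internally. The function names, the `pot[n][a][b]` layout, and the `PAD` filler value are assumptions.

```python
import math


def log_sum_exp(values):
    """Numerically stable log of a sum of exponentials."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def chain_log_partition(pot, length):
    """Log partition of a linear chain over its first `length` positions.

    pot[n][a][b] scores label a at position n followed by label b at
    position n + 1. Steps at index >= length - 1 are padding and are skipped.
    """
    num_labels = len(pot[0])
    alpha = [0.0] * num_labels
    for n in range(length - 1):
        alpha = [
            log_sum_exp([alpha[a] + pot[n][a][b] for a in range(num_labels)])
            for b in range(num_labels)
        ]
    return log_sum_exp(alpha)


def batched_log_partition(batch_pot, lengths):
    """Apply the masked recursion to every padded sequence in a batch."""
    return [chain_log_partition(pot, n) for pot, n in zip(batch_pot, lengths)]


PAD = 123.0  # arbitrary filler; it must never affect a result
full = [
    [[0.5, -1.0], [0.2, 0.3]],
    [[-0.4, 0.9], [1.1, 0.0]],
    [[0.0, 0.7], [-0.6, 0.4]],
]
# Same first two steps, true length 3, padded to the batch length of 4.
short = [full[0], full[1], [[PAD, PAD], [PAD, PAD]]]

results = batched_log_partition([full, short], lengths=[4, 3])
print(results)
```

Changing `PAD` to any other value leaves the second result unchanged, which is the property a masked implementation needs to guarantee.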

```python
# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10)
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))
```

![png](README_files/README_10_0.png)

## Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

* LinearChainCRF
* SemiMarkovCRF
* DependencyCRF
* NonProjectiveDependencyCRF
* TreeCRF
* NeuralPCFG / NeuralHMM

Each distribution includes:

* Argmax, sampling, entropy, partition, masking, log_probs, k-max

Extensions:

* Integration with `torchtext`, `pytorch-transformers`, `dgl`
* Adapters for generative structured models (CFG / HMM / HSMM)
* Common tree-structured parameterizations: TreeLSTM / SpanLSTM

## Low-level API

Everything is implemented through semiring dynamic programming.

* Log Marginals
* Max and MAP computation
* Sampling through specialized backprop
* Entropy and first-order semirings
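To illustrate what semiring dynamic programming means, the sketch below writes one forward recursion for a linear chain and runs it under two semirings: the log semiring yields the log partition, and the max semiring yields the score of the best sequence. It is a plain-Python illustration, not the library's API; the `Semiring` tuple, `chain_forward`, and the `pot[n][a][b]` layout are hypothetical names and conventions chosen here.

```python
import math
from typing import Callable, NamedTuple


class Semiring(NamedTuple):
    plus: Callable[[float, float], float]   # combines alternative paths
    times: Callable[[float, float], float]  # combines scores along one path
    zero: float                             # identity element for plus
    one: float                              # identity element for times


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))


LOG = Semiring(plus=log_add, times=lambda a, b: a + b, zero=-math.inf, one=0.0)
MAX = Semiring(plus=max, times=lambda a, b: a + b, zero=-math.inf, one=0.0)


def chain_forward(pot, semiring):
    """Forward recursion over a linear chain, generic in the semiring.

    pot[n][a][b] is the log potential for label a at position n followed
    by label b at position n + 1 (a layout chosen for this sketch).
    """
    num_labels = len(pot[0])
    alpha = [semiring.one] * num_labels
    for step in pot:
        new_alpha = []
        for b in range(num_labels):
            total = semiring.zero
            for a in range(num_labels):
                total = semiring.plus(total, semiring.times(alpha[a], step[a][b]))
            new_alpha.append(total)
        alpha = new_alpha
    # Combine the scores of all final labels.
    result = semiring.zero
    for value in alpha:
        result = semiring.plus(result, value)
    return result


# Hypothetical toy potentials: 3 transitions (4 positions), 2 labels.
pot = [
    [[0.5, -1.0], [0.2, 0.3]],
    [[-0.4, 0.9], [1.1, 0.0]],
    [[0.0, 0.7], [-0.6, 0.4]],
]
log_z = chain_forward(pot, LOG)       # log partition
best_score = chain_forward(pot, MAX)  # score of the best label sequence
print(log_z, best_score)
```

Only the semiring changes between the two calls; the recursion itself is shared, which is the design choice the list above points to.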

## Examples

* BERT Part-of-Speech
* BERT Dependency Parsing
* Unsupervised Learning
* Structured VAE

## Citation

```
@misc{alex2020torchstruct,
title={Torch-Struct: Deep Structured Prediction Library},
author={Alexander M. Rush},
year={2020},
eprint={2002.00876},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

This work was partially supported by NSF grant IIS-1901030.