Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/harvardnlp/pytorch-struct
Fast, general, and tested differentiable structured prediction in PyTorch
- Host: GitHub
- URL: https://github.com/harvardnlp/pytorch-struct
- Owner: harvardnlp
- License: mit
- Created: 2019-08-26T19:34:30.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2022-04-20T08:21:20.000Z (almost 3 years ago)
- Last Synced: 2025-02-07T18:09:34.661Z (13 days ago)
- Language: Jupyter Notebook
- Homepage: http://harvardnlp.github.io/pytorch-struct
- Size: 8.27 MB
- Stars: 1,108
- Watchers: 33
- Forks: 93
- Open Issues: 31
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- my-awesome - harvardnlp/pytorch-struct - 04 star:1.1k fork:0.1k Fast, general, and tested differentiable structured prediction in PyTorch (Jupyter Notebook)
- Awesome-pytorch-list-CNVersion - pytorch-struct
- awesome-nlp-note - CRF!!! harvardnlp/pytorch-struct
- awesome-list - Torch-Struct - A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications (Deep Learning Framework / High-Level DL APIs)
- Awesome-pytorch-list - pytorch-struct
- awesome-python-machine-learning-resources - GitHub - 44% open · ⏱️ 30.01.2022 (PyTorch utilities)
README
# Torch-Struct: Structured Prediction Library
[Coverage Status](https://coveralls.io/github/harvardnlp/pytorch-struct?branch=master)
A library of tested, GPU implementations of core structured prediction algorithms for deep learning applications.
* HMM / LinearChain-CRF
* HSMM / SemiMarkov-CRF
* Dependency Tree-CRF
* PCFG Binary Tree-CRF
* ...

Designed to be used as efficient batched layers in other PyTorch code.
[Tutorial paper](https://arxiv.org/abs/2002.00876) describing methodology.
## Getting Started
```python
!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
```

```python
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
```

```python
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5)
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])
```
```python
# Compute marginals
show(dist.marginals[0])
```
```python
# Compute argmax
show(dist.argmax.detach()[0])
```
```python
# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
```

```python
# Compute samples
show(dist.sample((1,)).detach()[0, 0])
```
```python
# Padding/Masking built into library.
dist = DependencyCRF(vals, lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])
```
```python
# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10)
chain[:, 0, :, 0] = 1
chain[:, -1,9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))
```
## Library
Full docs: http://nlp.seas.harvard.edu/pytorch-struct/
Current distributions implemented:
* LinearChainCRF
* SemiMarkovCRF
* DependencyCRF
* NonProjectiveDependencyCRF
* TreeCRF
* NeuralPCFG / NeuralHMM

Each distribution includes:
* Argmax, sampling, entropy, partition, masking, log_probs, k-max
Extensions:
* Integration with `torchtext`, `pytorch-transformers`, `dgl`
* Adapters for generative structured models (CFG / HMM / HSMM)
* Common tree structured parameterizations TreeLSTM / SpanLSTM

## Low-level API:
Everything implemented through semiring dynamic programming.
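To make the idea concrete, here is a minimal sketch in plain PyTorch. It does not use Torch-Struct's own semiring classes. It assumes a single unbatched linear chain whose potentials have shape `(N-1, C, C)`, indexed as `[step, from_state, to_state]`. The names `LOG_SEMIRING`, `MAX_SEMIRING`, and `chain_forward` are hypothetical and chosen for illustration only. The same recursion yields the log-partition under the log semiring and the best path score under the max semiring.

```python
import itertools
import torch

# A "semiring" here is just a (plus, times) pair acting on tensors.
# Illustrative names only; this is not Torch-Struct's API.
LOG_SEMIRING = (lambda x, dim: torch.logsumexp(x, dim=dim), torch.add)
MAX_SEMIRING = (lambda x, dim: x.max(dim=dim).values, torch.add)


def chain_forward(log_potentials, semiring):
    """Run the forward recursion over one linear chain.

    log_potentials[n, i, j] scores moving from state i at position n
    to state j at position n + 1 (an assumed convention for this sketch).
    """
    plus, times = semiring
    n_steps, n_states, _ = log_potentials.shape
    # 0.0 is the multiplicative identity of both semirings used here.
    alpha = torch.zeros(n_states)
    for n in range(n_steps):
        # Combine alpha[i] with potential[i, j], then reduce over i.
        alpha = plus(times(alpha.unsqueeze(-1), log_potentials[n]), dim=0)
    return plus(alpha, dim=0)


def brute_force_scores(log_potentials):
    """Score every state sequence explicitly (exponential; for checking only)."""
    n_steps, n_states, _ = log_potentials.shape
    scores = []
    for path in itertools.product(range(n_states), repeat=n_steps + 1):
        scores.append(
            sum(log_potentials[n, path[n], path[n + 1]] for n in range(n_steps))
        )
    return torch.stack(scores)


torch.manual_seed(0)
potentials = torch.randn(4, 3, 3)

log_partition = chain_forward(potentials, LOG_SEMIRING)
best_score = chain_forward(potentials, MAX_SEMIRING)
all_scores = brute_force_scores(potentials)
```

On this small chain, `log_partition` should agree with `torch.logsumexp(all_scores, dim=0)` and `best_score` with `all_scores.max()`, up to floating-point error. Only the semiring changes between the two calls.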
* Log Marginals
* Max and MAP computation
* Sampling through specialized backprop
* Entropy and first-order semirings.

## Examples
* BERT Part-of-Speech
* BERT Dependency Parsing
* Unsupervised Learning
* Structured VAE
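As a smaller, self-contained illustration of the "Log Marginals" item in the low-level API, the sketch below obtains edge marginals of a linear chain by differentiating the log-partition with respect to the log-potentials. It uses only plain PyTorch autograd and does not claim to mirror the library's internals. The function `log_partition` and the `[step, from_state, to_state]` layout are assumptions made for this example.

```python
import itertools
import torch


def log_partition(log_potentials):
    """Forward algorithm in log space for an (N-1, C, C) chain (assumed layout)."""
    alpha = torch.zeros(log_potentials.shape[-1])
    for step in log_potentials:
        # alpha[i] + step[i, j], reduced over the previous state i.
        alpha = torch.logsumexp(alpha.unsqueeze(-1) + step, dim=0)
    return torch.logsumexp(alpha, dim=0)


torch.manual_seed(0)
potentials = torch.randn(3, 2, 2, requires_grad=True)

# d log Z / d potentials[n, i, j] is the probability that a sampled
# path uses the edge i -> j at step n.
(marginals,) = torch.autograd.grad(log_partition(potentials), potentials)

# Brute-force check: enumerate all paths and accumulate their probabilities.
n_steps, n_states, _ = potentials.shape
expected = torch.zeros_like(potentials)
with torch.no_grad():
    paths = list(itertools.product(range(n_states), repeat=n_steps + 1))
    scores = torch.stack(
        [sum(potentials[n, p[n], p[n + 1]] for n in range(n_steps)) for p in paths]
    )
    probs = torch.softmax(scores, dim=0)
    for path, prob in zip(paths, probs):
        for n in range(n_steps):
            expected[n, path[n], path[n + 1]] += prob
```

Each step's marginals sum to one, and `marginals` should match `expected` up to floating-point error. Enumeration is exponential in chain length, so it is only useful as a check on tiny inputs.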
## Citation
```
@misc{alex2020torchstruct,
title={Torch-Struct: Deep Structured Prediction Library},
author={Alexander M. Rush},
year={2020},
eprint={2002.00876},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

This work was partially supported by NSF grant IIS-1901030.