https://github.com/toshas/torch_truncnorm
Truncated Normal Distribution in PyTorch
- Host: GitHub
- URL: https://github.com/toshas/torch_truncnorm
- Owner: toshas
- License: bsd-3-clause
- Created: 2020-10-07T12:44:27.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2021-08-15T18:20:18.000Z (over 3 years ago)
- Last Synced: 2023-03-10T01:23:20.001Z (over 1 year ago)
- Topics: distributions, pytorch, truncated-normal
- Language: Python
- Homepage:
- Size: 15.6 KB
- Stars: 55
- Watchers: 2
- Forks: 8
- Open Issues: 2
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING
- License: LICENSE
README
# torch_truncnorm
Truncated Normal distribution in PyTorch. The module provides:
- `TruncatedStandardNormal` class - truncation of the standard Normal distribution (zero mean, unit variance), parameterized
by the cut-off range `[a, b]` (similar to `scipy.stats.truncnorm`);
- `TruncatedNormal` class - a wrapper with extra `loc` and `scale` parameters of the parent Normal distribution;
- Differentiability with respect to the parameters of the distribution;
- Batching support.

# Why
I just needed differentiation with respect to the parameters of the distribution and found that the truncated normal
distribution is not bundled in `torch.distributions` as of 1.6.0.

# Known issues
`icdf` is numerically unstable; as a consequence, so is `rsample`, which relies on it. The same issue affects
`torch.distributions.normal.Normal`, so it is sort of *normal* (ba-dum-tss).

# Tests
```shell
CUDA_VISIBLE_DEVICES=0 python -m tests.test
```

# Links
https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
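The linked notes collect the closed-form expressions behind the truncated normal. As a rough illustration of how they map onto PyTorch, here is a minimal sketch of the standard case (parent N(0, 1) cut to `[a, b]`) using plain tensor ops: the log-density `log phi(x) - log(Phi(b) - Phi(a))` on `[a, b]`, and reparameterized sampling by inverse CDF. The function names `truncnorm_log_prob` and `truncnorm_rsample` are hypothetical and are not the `torch_truncnorm` API; `Phi` and its inverse are built from `torch.erf` and `torch.erfinv`.

```python
import math

import torch


def _std_normal_cdf(x: torch.Tensor) -> torch.Tensor:
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def _std_normal_icdf(p: torch.Tensor) -> torch.Tensor:
    # Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1); precision degrades as p -> 0 or 1
    return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)


def truncnorm_log_prob(x, a, b):
    """Log-density of N(0, 1) truncated to [a, b] (hypothetical helper)."""
    # log f(x) = log phi(x) - log(Phi(b) - Phi(a)) inside [a, b], -inf outside
    log_z = torch.log(_std_normal_cdf(b) - _std_normal_cdf(a))
    log_phi = -0.5 * x**2 - 0.5 * math.log(2.0 * math.pi)
    inside = (x >= a) & (x <= b)
    return torch.where(inside, log_phi - log_z, torch.full_like(log_phi, -math.inf))


def truncnorm_rsample(a, b, sample_shape=()):
    """Reparameterized samples via inverse-CDF sampling (hypothetical helper)."""
    cdf_a, cdf_b = _std_normal_cdf(a), _std_normal_cdf(b)
    # a and b may be batched; their broadcast shape is the batch shape
    batch_shape = torch.broadcast_shapes(cdf_a.shape, cdf_b.shape)
    # u ~ U(0, 1) does not depend on a or b, so gradients flow only through the CDFs
    u = torch.rand(
        tuple(sample_shape) + tuple(batch_shape), dtype=cdf_a.dtype, device=cdf_a.device
    )
    # x = Phi^{-1}(Phi(a) + u * (Phi(b) - Phi(a)))
    return _std_normal_icdf(cdf_a + u * (cdf_b - cdf_a))
```

Drawing `u` independently of `a` and `b` is what keeps the samples differentiable with respect to the cut-off range. The price is the instability described under Known issues: in float32, `truncnorm_rsample(torch.tensor(7.0), torch.tensor(8.0), (10,))` returns all `inf`, because `Phi(7)` and `Phi(8)` both round to 1 and the inverse CDF is evaluated at exactly 1. Float64 pushes this breaking point further into the tail but does not remove it.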