https://github.com/activatedgeek/torch-sgld
SGLD and cSGLD as a PyTorch Optimizer
csgld cyclical-learning-rate python pytorch pytorch-optimizers sgld
- Host: GitHub
- URL: https://github.com/activatedgeek/torch-sgld
- Owner: activatedgeek
- License: apache-2.0
- Created: 2023-03-24T18:02:49.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2023-05-04T12:21:58.000Z (almost 2 years ago)
- Last Synced: 2024-09-14T14:37:02.491Z (5 months ago)
- Topics: csgld, cyclical-learning-rate, python, pytorch, pytorch-optimizers, sgld
- Language: Jupyter Notebook
- Homepage:
- Size: 124 KB
- Stars: 7
- Watchers: 3
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# SGLD in PyTorch
[![PyPI version](https://badge.fury.io/py/torch-sgld.svg)](https://pypi.org/project/torch-sgld/)
This package implements [SGLD](https://icml.cc/2011/papers/398_icmlpaper.pdf)
and [cSGLD](https://arxiv.org/abs/1902.03932)
as a [PyTorch Optimizer](https://pytorch.org/docs/stable/optim.html).

## Installation
Install from `pip` as:
```shell
pip install torch-sgld
```

To install the latest version directly from source, run:
```shell
pip install git+https://github.com/activatedgeek/torch-sgld.git
```

## Usage
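For intuition about what the optimizer computes: the update in the SGLD paper linked above is a gradient step on an energy U (the negative log-posterior) plus injected Gaussian noise, θ ← θ − (ε/2)∇U(θ) + η with η ~ N(0, ε). SGLD uses a minibatch estimate of ∇U; the minimal sketch below instead uses the exact gradient of a toy 1-D energy U(θ) = θ²/2 (which reduces it to the unadjusted Langevin algorithm) so the result can be checked: the samples should have mean near 0 and variance near 1. The function names are hypothetical, and this is a plain-Python illustration rather than the `torch_sgld` implementation, whose `lr` may scale the gradient and noise differently.

```python
import math
import random


def grad_energy(theta: float, sigma2: float = 1.0) -> float:
    """Gradient of the toy energy U(theta) = theta**2 / (2 * sigma2),
    the negative log-density of N(0, sigma2) up to a constant."""
    return theta / sigma2


def sgld_chain(num_steps: int, eps: float, theta0: float = 0.0, seed: int = 0) -> list[float]:
    """Langevin updates: theta <- theta - (eps / 2) * grad U(theta) + N(0, eps).

    Hypothetical sketch: an exact gradient stands in for SGLD's minibatch estimate.
    """
    rng = random.Random(seed)
    theta = theta0
    samples = []
    for _ in range(num_steps):
        noise = rng.gauss(0.0, math.sqrt(eps))  # injected noise with variance eps
        theta = theta - 0.5 * eps * grad_energy(theta) + noise
        samples.append(theta)
    return samples


if __name__ == "__main__":
    samples = sgld_chain(num_steps=100_000, eps=0.1)[1_000:]  # drop burn-in
    mean = sum(samples) / len(samples)
    var = sum((s - mean) ** 2 for s in samples) / len(samples)
    # Mean should be near 0 and variance near 1 (the finite step size adds a small bias).
    print(f"mean={mean:.3f} var={var:.3f}")
```

Shrinking `eps` reduces the discretization bias at the cost of slower mixing, which is one motivation for decaying step-size schedules.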
The general idea is to take the usual gradient-based update loop in PyTorch and use the `SGLD` optimizer in place of a standard optimizer:

```python
from torch_sgld import SGLD

f = module()  # Construct a PyTorch nn.Module; `module`, `lr` and `num_steps` are placeholders.
sgld = SGLD(f.parameters(), lr=lr, momentum=.9)  # Add momentum to make it SG-HMC.
sgld_scheduler = None  # Optionally, a step-size scheduler from torch.optim.lr_scheduler.

for _ in range(num_steps):
    sgld.zero_grad()  # Clear gradients left over from the previous step.
    energy = f()  # Forward pass returns a scalar energy.
    energy.backward()
    sgld.step()
    if sgld_scheduler is not None:
        sgld_scheduler.step()  # Optional scheduler step.
```

`cSGLD` can be implemented by using a cyclical learning rate schedule.
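The cyclical schedule proposed in the cSGLD paper is a cosine decay that restarts at the beginning of each cycle: for K total iterations split into M cycles, α_k = (α_0/2)[cos(π·mod(k−1, ⌈K/M⌉)/⌈K/M⌉) + 1]. A minimal sketch of the resulting multiplier α_k/α_0 follows, in plain Python with 0-indexed steps; the function name and the `total_steps`/`num_cycles` parameters are hypothetical, and the notebook may implement its schedule differently.

```python
import math


def csgld_lr_multiplier(step: int, total_steps: int, num_cycles: int) -> float:
    """Cyclical cosine step-size multiplier alpha_k / alpha_0 (hypothetical helper).

    For a 0-indexed step k and cycle length L = ceil(total_steps / num_cycles):
        0.5 * (cos(pi * (k mod L) / L) + 1)
    so each cycle starts at the full step size and decays towards 0.
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    frac = (step % cycle_len) / cycle_len  # position within the current cycle, in [0, 1)
    return 0.5 * (math.cos(math.pi * frac) + 1.0)


if __name__ == "__main__":
    total_steps, num_cycles = 1000, 4  # hypothetical settings: 4 cycles of 250 steps
    for k in (0, 125, 249, 250):
        print(k, round(csgld_lr_multiplier(k, total_steps, num_cycles), 4))
```

To drive the optimizer with it, the multiplier can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, e.g. `LambdaLR(sgld, lr_lambda=lambda k: csgld_lr_multiplier(k, num_steps, num_cycles))`, which scales the initial `lr` by the returned value each time `sgld_scheduler.step()` is called. The cSGLD paper also distinguishes an exploration stage and a sampling stage within each cycle; this sketch covers only the step-size schedule.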
See the [toy_csgld.ipynb](./notebooks/toy_csgld.ipynb) notebook for a
complete example.

## License
Apache 2.0