https://github.com/ttitcombe/lotterytickethypothesis
PyTorch implementation of the lottery ticket hypothesis
- Host: GitHub
- URL: https://github.com/ttitcombe/lotterytickethypothesis
- Owner: TTitcombe
- License: apache-2.0
- Created: 2020-06-29T18:42:30.000Z (almost 5 years ago)
- Default Branch: master
- Last Pushed: 2020-07-14T17:28:53.000Z (almost 5 years ago)
- Last Synced: 2024-12-27T12:11:35.889Z (5 months ago)
- Topics: lottery-ticket-hypothesis, lottery-tickets, pruning, pytorch
- Language: Python
- Homepage:
- Size: 16.6 KB
- Stars: 1
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Lottery Ticket Hypothesis
PyTorch implementation of the [Lottery Ticket Hypothesis][lottery].
This implementation uses PyTorch's `prune` module (`torch.nn.utils.prune`).

## Prerequisites

Developed in Python `3.7`,
but other versions should work.

If using conda,
run
```bash
conda env create -f environment.yml
```
to create an environment,
`lottery-ticket`,
with all required packages.

## To Run

`main.py` trains a `LeNet`-esque classifier on MNIST
with several rounds of pruning.

```bash
usage: Finding Lottery Tickets on an MNIST classifier [-h] [--lr LR] [--bs BS]
                                                      [--epochs EPOCHS]
                                                      [--prune_pc PRUNE_PC]
                                                      [--prune_rounds PRUNE_ROUNDS]

optional arguments:
  -h, --help            show this help message and exit
  --lr LR               Learning rate (default 1e-3)
  --bs BS               Batch size (default 128)
  --epochs EPOCHS       Number of epochs (default 8)
  --prune_pc PRUNE_PC   Percentage of parameters to prune over the course of
                        the training process (default 0.2)
  --prune_rounds PRUNE_ROUNDS
                        Number of rounds of pruning to perform (default 5)
```
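The README does not show the code in `src/`, so the following is only a minimal sketch of the train-prune-repeat procedure that `main.py` describes, not the repository's implementation. It assumes a LeNet-300-100-style MLP, layer-wise magnitude pruning with `torch.nn.utils.prune.l1_unstructured`, and rewinding the surviving weights to their initial values after every round, as in the original Lottery Ticket paper. It also assumes that `--prune_pc` is the *total* fraction removed, spread geometrically over `--prune_rounds` rounds; the help text does not say whether it applies per round or overall. `LeNet300`, `per_round_rate`, `train` and `find_ticket` are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class LeNet300(nn.Module):
    """Hypothetical LeNet-300-100-style MLP for 28x28 MNIST images."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.fc1(torch.flatten(x, 1)))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


def per_round_rate(total_fraction: float, rounds: int) -> float:
    # Assumption: `total_fraction` (--prune_pc) is the overall fraction to remove.
    # Each round prunes a fraction of the weights that are still alive, so after
    # `rounds` rounds the surviving fraction is (1 - rate) ** rounds.
    return 1.0 - (1.0 - total_fraction) ** (1.0 / rounds)


def train(model: nn.Module, batches, epochs: int, lr: float) -> None:
    # A fresh optimizer each round, so no Adam state carries over between rounds.
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        for x, y in batches:
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()


def find_ticket(model, batches, epochs=8, lr=1e-3, prune_pc=0.2, prune_rounds=5):
    layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
    # Snapshot the initial parameters; survivors are rewound to these each round.
    init = [(m.weight.detach().clone(), m.bias.detach().clone()) for m in layers]
    rate = per_round_rate(prune_pc, prune_rounds)
    for _ in range(prune_rounds):
        train(model, batches, epochs, lr)
        for m in layers:
            # Once pruned, a layer keeps trainable values in `weight_orig` and a
            # 0/1 mask in `weight_mask`; `m.weight` is only recomputed on the next
            # forward pass, so rank by the up-to-date masked weights instead.
            if hasattr(m, "weight_mask"):
                scores = m.weight_orig * m.weight_mask
            else:
                scores = m.weight
            # Prune `rate` of the still-unpruned weights with the smallest |w|.
            prune.l1_unstructured(
                m, name="weight", amount=rate, importance_scores=scores.detach()
            )
        # Rewind the surviving weights (and the unpruned biases) to initialisation.
        with torch.no_grad():
            for m, (w0, b0) in zip(layers, init):
                m.weight_orig.copy_(w0)
                m.bias.copy_(b0)
    return model
```

With the README's defaults (`prune_pc=0.2`, `prune_rounds=5`) each round removes about 4.4% of the remaining weights in every layer, leaving roughly 80% of each layer's weights at the end. `batches` can be any re-iterable of `(images, labels)` pairs, such as a `DataLoader` over `torchvision.datasets.MNIST`. Applying the same rate to every layer, including the output layer, is a simplification; `prune.remove(m, "weight")` would make the final masks permanent.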
## Contributions
Contributions are welcome.
If opening a PR,
ensure the code conforms to `black` formatting
and `isort` import configurations.

Feel free to open an issue
to ask a question,
report a bug,
or request new features.

## Project Organization
```
├── README.md          <- The top-level README for developers using this project.
│
├── environment.yml    <- The conda environment file for creating the analysis environment, e.g.
│                         `conda env create -f environment.yml`.
│
├── main.py            <- The training script.
│
├── .gitignore         <- git-ignore configuration file.
│
├── data               <- Directory in which downloaded data will be stored. No data is provided in the repo.
│
├── src                <- Source code imported by `main.py`.
```

## TODO
- [X] Implement basic pruning
- [ ] Recreate experiments from the paper
- [ ] Use TensorBoard to visualise model and pruning progress

## License
See the full [license](./LICENSE).

[lottery]: https://arxiv.org/abs/1803.03635