https://github.com/sea-snell/grokking
unofficial re-implementation of "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets"
artificial-intelligence deep-learning grokking neural-network python pytorch transformer transformer-models
- Host: GitHub
- URL: https://github.com/sea-snell/grokking
- Owner: Sea-Snell
- License: mit
- Created: 2021-11-17T06:45:01.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2022-07-04T14:18:07.000Z (almost 4 years ago)
- Last Synced: 2025-01-12T21:33:24.426Z (over 1 year ago)
- Topics: artificial-intelligence, deep-learning, grokking, neural-network, python, pytorch, transformer, transformer-models
- Language: Python
- Homepage:
- Size: 1.82 MB
- Stars: 66
- Watchers: 4
- Forks: 14
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
## GROKKING: GENERALIZATION BEYOND OVERFITTING ON SMALL ALGORITHMIC DATASETS
### unofficial re-implementation of [this paper](https://mathai-iclr.github.io/papers/papers/MATHAI_29_paper.pdf) by Power et al.
### code written by Charlie Snell
Clone and install:
```
git clone https://github.com/Sea-Snell/grokking.git
cd grokking/
pip install -r requirements.txt
```
To roughly re-create Figure 1 in the paper, run:
```
export PYTHONPATH=$(pwd)/grokk_replica/
cd scripts/
python train_grokk.py
```

###### Running the above command should give curves like this.
Try different operations or learning and architectural hyperparameters by modifying the configurations in the `config/` directory. I use [Hydra](https://hydra.cc/docs/intro) to handle the configs (see the Hydra documentation for how to override configs from the command line).
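For context on what an "operation" means here: the paper trains on tables of binary operations over a small finite set, such as addition modulo a prime, with only a fraction of the table used for training. The following is a minimal sketch of that kind of dataset, not the code in this repository. The function name `make_modular_dataset`, its parameters, and the choice of `p=97` with a 50% split are illustrative assumptions.

```python
import random


def make_modular_dataset(p, operation, train_fraction, seed=0):
    """Build every (a, b, operation(a, b) mod p) triple, then split it at random.

    Hypothetical helper for illustration only; it is not taken from this repo.
    """
    examples = [(a, b, operation(a, b) % p) for a in range(p) for b in range(p)]
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    rng.shuffle(examples)
    n_train = int(train_fraction * len(examples))
    return examples[:n_train], examples[n_train:]


# Modular addition with half of the table held out for validation (assumed values).
train, val = make_modular_dataset(p=97, operation=lambda a, b: a + b, train_fraction=0.5)
print(len(train), len(val))
```

Swapping the `operation` argument (for example `lambda a, b: a * b`) gives a different table with the same structure.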
Training uses [Weights & Biases](https://wandb.ai/home) by default to generate plots in real time. If you would rather not use wandb, set `wandb.use_wandb=False` in `config/train_grokk.yaml` or pass it as an argument when calling `train_grokk.py`.
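On the command line, an override of this kind is a dotted `key=value` argument, e.g. `python train_grokk.py wandb.use_wandb=False`. The sketch below shows roughly how such an argument maps onto a nested config. It is a simplified illustration using only the standard library, not Hydra's actual implementation. `apply_override` is a hypothetical helper, and `base_config` is a stand-in for `config/train_grokk.yaml` containing only the one key named above.

```python
import ast
import copy


def apply_override(config, override):
    """Return a copy of `config` with one `dotted.key=value` override applied.

    Hypothetical helper; Hydra's real override grammar is richer than this.
    """
    key_path, raw_value = override.split("=", 1)
    try:
        # Interpret Python-style literals such as False, 3, or 1e-3.
        value = ast.literal_eval(raw_value)
    except (ValueError, SyntaxError):
        # Anything that is not a literal is kept as a plain string.
        value = raw_value
    updated = copy.deepcopy(config)  # leave the caller's config untouched
    node = updated
    *parents, leaf = key_path.split(".")
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = value
    return updated


# Stand-in config: only `wandb.use_wandb` is taken from the README (assumption).
base_config = {"wandb": {"use_wandb": True}}
config = apply_override(base_config, "wandb.use_wandb=False")
print(config["wandb"]["use_wandb"])
```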