Fast reinforcement learning 🚨
- Host: GitHub
- URL: https://github.com/codingfisch/flashrl
- Owner: codingfisch
- License: MIT
- Created: 2025-03-01T09:22:19.000Z (8 months ago)
- Default Branch: main
- Last Pushed: 2025-03-08T15:54:23.000Z (8 months ago)
- Last Synced: 2025-03-08T16:31:13.211Z (8 months ago)
- Topics: multi-agent-reinforcement-learning, reinforcement-learning, reinforcement-learning-environments
- Language: Cython
- Size: 81.1 KB
- Stars: 18
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# flashrl
`flashrl` does RL with **millions of steps/second 🚨 while being tiny**: ~200 lines of code
🛠️ `pip install flashrl` or clone the repo & `pip install -r requirements.txt`
- If cloned (or if envs changed), compile: `python setup.py build_ext --inplace`
💡 `flashrl` will always be **tiny**: **Read the code** (+ paste it into an LLM) to understand it!
## Quick Start 🚀
`flashrl` uses a `Learner` that holds an `env` and a `model` (default: `Policy` with LSTM)
```python
import flashrl as frl
learn = frl.Learner(frl.envs.Pong(n_agents=2**14))
curves = learn.fit(40, steps=16, desc='done')
frl.print_curve(curves['loss'], label='loss')
frl.play(learn.env, learn.model, fps=8)
learn.env.close()
```
`.fit` does RL with ~**10 million steps**: `40` iterations × `16` steps × `2**14` agents!
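A quick sanity check of that step count in plain Python:
```python
>>> 40 * 16 * 2**14  # iterations × steps × agents
10485760             # ≈ 10 million environment steps per .fit call
```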
**Run it yourself via `python train.py` and play against the AI** 💪
Click here to read a tiny doc 📖
`Learner` takes the arguments below (usage sketch after the list)
- `env`: RL environment
- `model`: A `Policy` model
- `device`: By default picks `mps` or `cuda` if available, else `cpu`
- `dtype`: By default `torch.bfloat16` if the device is `cuda`, else `torch.float32`
- `compile_no_lstm`: Speedup via `torch.compile` if `model` has no `lstm`
- `**kwargs`: Passed to the `Policy`, e.g. `hidden_size` or `lstm`
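For example, a minimal construction sketch overriding a few of these defaults; the `hidden_size` value is illustrative and is forwarded via `**kwargs` to the `Policy`:
```python
import torch
import flashrl as frl

learn = frl.Learner(
    frl.envs.Pong(n_agents=2**14),  # env, as in the Quick Start
    device='cuda',                  # override the auto-picked device
    dtype=torch.bfloat16,           # the default for cuda anyway
    hidden_size=128,                # illustrative; passed through to the Policy
)
```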
`Learner.fit` takes the arguments below (example call after the list)
- `iters`: Number of iterations
- `steps`: Number of steps in `rollout`
- `desc`: Progress bar description (e.g. `'reward'`)
- `log`: If `True`, `tensorboard` logging is enabled
- run `tensorboard --logdir=runs` and visit `http://localhost:6006` in the browser!
- `stop_func`: Function that stops training if it returns `True` e.g.
```python
...
def stop(kl, **kwargs):
    return kl > .1

curves = learn.fit(40, steps=16, stop_func=stop)
...
```
- `lr`, `anneal_lr` & the arguments of `ppo` after `bs`: Hyperparameters
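A sketch of a `fit` call combining a few of these arguments (the values are illustrative):
```python
# desc and log as documented above; 'reward' shows the reward in the progress bar
curves = learn.fit(40, steps=16, desc='reward', log=True)
```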
The most important functions in `flashrl/utils.py` are listed below (usage sketch follows)
- `print_curve`: Visualizes the loss across the `iters`
- `play`: Plays the environment in the terminal and takes
- `model`: A `Policy` model
- `playable`: If `True`, allows you to act (or decide to let the model act)
- `steps`: Number of steps
- `fps`: Frames per second
- `obs`: The env attribute that should be rendered as the observation
- `dump`: If `True`, no frame refresh -> Frames accumulate in the terminal
- `idx`: Agent index between `0` and `n_agents` (default: `0`)
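A sketch of a `play` call combining these arguments (the `steps` value is illustrative):
```python
# interactive session: act yourself, or hand control to the model
frl.play(learn.env, learn.model, playable=True, steps=256, fps=8, idx=0)
```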
## Environments 🕹️
**Each env is one Cython(=`.pyx`) file** in `flashrl/envs`. **That's it!**
To **add custom envs**, use `grid.pyx`, `pong.pyx` or `multigrid.pyx` as a **template**:
- `grid.pyx` for **single-agent** envs (~110 LOC)
- `pong.pyx` for **1 vs 1 agent** envs (~150 LOC)
- `multigrid.pyx` for **multi-agent** envs (~190 LOC)
| `Grid` | `Pong` | `MultiGrid` |
|---|---|---|
| Agent must reach goal | Agent must score | Agent must reach goal first |
## Acknowledgements 🙏
I want to thank
- [Joseph Suarez](https://github.com/jsuarez5341) for open sourcing RL envs in C(ython)! Star [PufferLib](https://github.com/PufferAI/PufferLib) ⭐
- [Costa Huang](https://github.com/vwxyzjn) for open sourcing high-quality single-file RL code! Star [cleanrl](https://github.com/vwxyzjn/cleanrl) ⭐
and last but not least...