https://github.com/frankroeder/goal_conditioned_rl
Goal-conditioned reinforcement learning like 🔥
- Host: GitHub
- URL: https://github.com/frankroeder/goal_conditioned_rl
- Owner: frankroeder
- License: mit
- Created: 2024-01-30T16:46:56.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-02-03T10:23:58.000Z (over 1 year ago)
- Last Synced: 2024-12-31T22:32:03.552Z (10 months ago)
- Topics: actor-critic, deep-deterministic-policy-gradient, deep-reinforcement-learning, distrax, droq, flax, goal-conditioned-rl, goals, gymnasium, gymnasium-robotics, jax, mpi4jax, optax, reinforcement-learning, robotics, soft-actor-critic
- Language: Python
- Homepage:
- Size: 29.3 KB
- Stars: 10
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Goal-Conditioned Reinforcement Learning (Jax/Flax/Optax)
This repository contains a collection of goal-conditioned reinforcement learning algorithms.
It is compatible with the latest [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) API and uses recent versions of JAX, Flax, and Optax.
We support multiprocessing via [mpi4jax](https://github.com/mpi4jax/mpi4jax), in the style of the deprecated OpenAI [baselines](https://github.com/openai/baselines).
## Supported Algorithms
- [x] Deep Deterministic Policy Gradient (DDPG [paper link](https://arxiv.org/abs/1509.02971))
- [x] Soft Actor-Critic (SAC [paper link](https://arxiv.org/abs/1801.01290))
- [x] DroQ ([paper link](https://arxiv.org/abs/2110.02034))

All algorithms make use of Hindsight Experience Replay (HER [paper link](https://arxiv.org/abs/1707.01495)).
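Hindsight Experience Replay relabels stored transitions with goals that were actually achieved later in the same episode, so even failed episodes produce successful training examples under a sparse reward. The framework-free sketch below illustrates the "future" relabeling strategy from the HER paper together with a Fetch-style sparse reward. It rests on assumptions that are not taken from this repository's code: the episode layout (`obs`, `ag`, `g`, `actions`), a relabeling probability of 0.8 (four relabeled goals per original goal), and a 0.05 success threshold; `sample_her_transitions` is a hypothetical name.

```python
import math
import random
from typing import Dict, List, Optional, Sequence

Vector = Sequence[float]


def sparse_reward(achieved_goal: Vector, desired_goal: Vector, threshold: float = 0.05) -> float:
    """Fetch-style sparse reward: 0.0 within `threshold` of the goal, -1.0 otherwise."""
    return 0.0 if math.dist(achieved_goal, desired_goal) <= threshold else -1.0


def sample_her_transitions(
    episode: Dict[str, List[List[float]]],
    batch_size: int,
    relabel_probability: float = 0.8,  # 1 - 1 / (1 + k) with k = 4 future goals per original goal
    rng: Optional[random.Random] = None,
) -> List[Dict[str, object]]:
    """Sample transitions and replace their goal with a future achieved goal ("future" strategy).

    Assumed (hypothetical) layout: "obs" and "ag" hold T + 1 entries, "g" and "actions" hold T.
    """
    if rng is None:
        rng = random.Random()
    horizon = len(episode["actions"])
    batch = []
    for _ in range(batch_size):
        t = rng.randrange(horizon)
        goal = episode["g"][t]
        if rng.random() < relabel_probability:
            # Use a goal achieved later in the same episode, at one of steps t + 1 ... horizon.
            goal = episode["ag"][rng.randint(t + 1, horizon)]
        batch.append({
            "obs": episode["obs"][t],
            "action": episode["actions"][t],
            "next_obs": episode["obs"][t + 1],
            "goal": goal,
            # Recompute the reward for the (possibly new) goal from the next achieved goal.
            "reward": sparse_reward(episode["ag"][t + 1], goal),
        })
    return batch


# Toy episode (hypothetical data): the object moves 0.1 per step and never reaches x = 5.
horizon = 10
achieved = [[0.1 * i, 0.0, 0.0] for i in range(horizon + 1)]
episode = {
    "obs": achieved,  # stand-in observation: just the achieved goal
    "ag": achieved,
    "g": [[5.0, 0.0, 0.0]] * horizon,
    "actions": [[0.1, 0.0, 0.0]] * horizon,
}
batch = sample_her_transitions(episode, batch_size=256, rng=random.Random(42))
successes = sum(tr["reward"] == 0.0 for tr in batch)
print(f"{successes} of {len(batch)} sampled transitions now carry a success reward")
```

Without relabeling, every transition in this toy episode has reward -1; relabeling with a goal achieved at step `t + 1` turns some of them into successes, which is the signal HER exploits.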
## Installation
- `git clone https://github.com/frankroeder/goal_conditioned_rl.git`
- pip users: `pip install -r requirements.txt`
- conda users: `conda env create -f conda_env.yaml`
- system libraries (OpenMPI, required by mpi4jax): `apt install libopenmpi-dev`

### JAX CUDA Support
> See the JAX installation guide: https://github.com/google/jax#installation

To install on a machine with an NVIDIA GPU, run:
```bash
# install packages
pip install -r requirements.txt
# if necessary, replace the CPU-only jaxlib with the CUDA 12 build
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
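After installing the CUDA build, a quick way to confirm that JAX actually picked up the GPU is to ask for its default backend. The helper below is a small sketch; its name is our own and not part of this repository, while `jax.default_backend()` and `jax.devices()` are standard JAX functions. A result of `cpu` on a GPU machine usually means the CUDA wheel was not installed correctly.

```python
import importlib.util


def describe_jax_backend() -> str:
    """Report which backend JAX would use, without failing if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax

    backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    print(describe_jax_backend())
```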
## Training
### Single process
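The SAC and DroQ commands below both use `agent=sac` and differ only in `agent.critic.dropout`. In the DroQ paper, each hidden layer of the Q-functions applies dropout followed by layer normalization before the activation, which keeps the critics stable when many gradient updates are made per environment step. The pure-Python sketch below illustrates one such hidden layer for a single input vector. It assumes this repository follows the paper's layer order (its actual critics are Flax modules), omits the learned LayerNorm scale and bias, and uses hypothetical function names.

```python
import math
import random
from typing import List, Sequence


def dropout(x: Sequence[float], rate: float, rng: random.Random, training: bool = True) -> List[float]:
    """Inverted dropout: zero each unit with probability `rate`, rescale survivors by 1 / (1 - rate)."""
    if not training or rate == 0.0:
        return list(x)
    keep = 1.0 - rate
    return [xi / keep if rng.random() < keep else 0.0 for xi in x]


def layer_norm(x: Sequence[float], eps: float = 1e-6) -> List[float]:
    """Normalize to zero mean and unit variance (learned scale and bias omitted)."""
    mean = sum(x) / len(x)
    var = sum((xi - mean) ** 2 for xi in x) / len(x)
    return [(xi - mean) / math.sqrt(var + eps) for xi in x]


def droq_hidden_layer(
    x: Sequence[float],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
    dropout_rate: float,
    rng: random.Random,
    training: bool = True,
) -> List[float]:
    """Dense -> Dropout -> LayerNorm -> ReLU, the hidden-layer order described in the DroQ paper."""
    h = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    h = dropout(h, dropout_rate, rng, training)
    h = layer_norm(h)
    return [max(0.0, hi) for hi in h]


rng = random.Random(0)
x = [0.5, -1.0, 2.0]
weights = [[0.1 * (i + j) for j in range(3)] for i in range(4)]  # 4 hidden units, 3 inputs
bias = [0.0] * 4
train_out = droq_hidden_layer(x, weights, bias, dropout_rate=0.01, rng=rng)
eval_out = droq_hidden_layer(x, weights, bias, dropout_rate=0.01, rng=rng, training=False)
print(train_out, eval_out)
```

In this sketch, `dropout_rate=0.0` turns the dropout step into the identity, mirroring how the SAC command below disables critic dropout.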
```bash
# SAC
python train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her agent.critic.dropout=0.0
# DDPG
python train.py n_epochs=10 agent=ddpg env_name=FetchPush-v2 hindsight=her
# DroQ (the SAC agent with dropout in the critic)
python train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her agent.critic.dropout=0.01
```
### Multiple processes
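With `mpirun -np 4`, four processes each collect their own rollouts. In the OpenAI baselines HER implementation, which this README names as the model for its multiprocessing, workers keep identical network parameters by averaging their gradients with an MPI all-reduce before every optimizer step. The sketch below simulates that all-reduce in plain Python, so no MPI is needed. `allreduce_mean` and `sgd_step` are hypothetical names, and whether this repository averages gradients in exactly this way through mpi4jax is an assumption.

```python
from typing import List, Sequence

Gradient = List[float]


def allreduce_mean(local_grads: Sequence[Gradient]) -> Gradient:
    """Simulate an MPI Allreduce with SUM, followed by division by the number of processes."""
    n_procs = len(local_grads)
    summed = [sum(values) for values in zip(*local_grads)]
    return [s / n_procs for s in summed]


def sgd_step(params: Gradient, grad: Gradient, lr: float) -> Gradient:
    """Plain gradient-descent step (stand-in for an Optax optimizer update)."""
    return [p - lr * g for p, g in zip(params, grad)]


# Four simulated workers start from the same parameters but compute gradients on different data.
params = [1.0, -2.0]
local_grads = [[0.4, 0.0], [0.0, 0.8], [0.2, 0.4], [0.2, -0.4]]
mean_grad = allreduce_mean(local_grads)
# Every worker applies the same averaged gradient, so the parameters stay synchronized.
new_params = [sgd_step(params, mean_grad, lr=0.5) for _ in local_grads]
print(mean_grad, new_params[0])
```

In a real run, the all-reduce is a collective MPI call across processes rather than a Python loop over lists.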
```bash
mpirun -np 4 python -u train.py n_epochs=10 agent=sac env_name=FetchPush-v2 hindsight=her
```
## Enjoy your trained agent
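Replaying a trained agent amounts to running its policy in the standard Gymnasium interaction loop: `reset` returns `(observation, info)` and `step` returns `(observation, reward, terminated, truncated, info)`. The sketch below shows that loop for a goal-conditioned dictionary observation with `observation`, `achieved_goal`, and `desired_goal` keys. A tiny stand-in environment replaces a real Gymnasium-Robotics task so the snippet runs without MuJoCo; `ToyReachEnv`, `greedy_policy`, and `run_episode` are hypothetical, and this is not the code of `demo.py`.

```python
import math
from typing import Any, Callable, Dict, List, Optional, Tuple

Obs = Dict[str, List[float]]


class ToyReachEnv:
    """Hypothetical 2-D point-reaching task with a Fetch-style dictionary observation."""

    def __init__(self, max_episode_steps: int = 50, threshold: float = 0.05) -> None:
        self.max_episode_steps = max_episode_steps
        self.threshold = threshold
        self.pos, self.goal, self.t = [0.0, 0.0], [0.3, -0.2], 0

    def _obs(self) -> Obs:
        return {
            "observation": list(self.pos),
            "achieved_goal": list(self.pos),
            "desired_goal": list(self.goal),
        }

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> Tuple[Obs, Dict[str, Any]]:
        # `seed` and `options` only mirror the Gymnasium signature; this toy task is deterministic.
        self.pos, self.goal, self.t = [0.0, 0.0], [0.3, -0.2], 0
        return self._obs(), {}

    def step(self, action: List[float]) -> Tuple[Obs, float, bool, bool, Dict[str, Any]]:
        self.t += 1
        # Clip each action component to [-0.05, 0.05] and move the point.
        self.pos = [p + max(-0.05, min(0.05, a)) for p, a in zip(self.pos, action)]
        success = math.dist(self.pos, self.goal) <= self.threshold
        reward = 0.0 if success else -1.0
        truncated = self.t >= self.max_episode_steps  # episode ends only through the time limit
        return self._obs(), reward, False, truncated, {"is_success": float(success)}


def greedy_policy(obs: Obs) -> List[float]:
    """Stand-in for a trained actor: move straight toward the desired goal."""
    return [g - a for g, a in zip(obs["desired_goal"], obs["achieved_goal"])]


def run_episode(env: ToyReachEnv, policy: Callable[[Obs], List[float]], seed: int = 0) -> float:
    """Roll out one episode and return the success flag reported at its final step."""
    obs, info = env.reset(seed=seed)
    while True:
        obs, reward, terminated, truncated, info = env.step(policy(obs))
        if terminated or truncated:
            return info["is_success"]


env = ToyReachEnv()
success_rate = sum(run_episode(env, greedy_policy, seed=s) for s in range(10)) / 10
print(f"success rate: {success_rate:.2f}")
```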
```bash
python demo.py --demo_path
# or
python demo.py --wandb_url
```
## Results
More results will follow.
## References
- https://github.com/TianhongDai/hindsight-experience-replay
- https://github.com/akakzia/decstr
- https://github.com/frankroeder/hipss