https://github.com/thu-ml/srpo
Codes accompanying the paper "Score Regularized Policy Optimization through Diffusion Behavior" (ICLR 2024).
- Host: GitHub
- URL: https://github.com/thu-ml/srpo
- Owner: thu-ml
- License: mit
- Created: 2023-10-11T07:44:56.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-02-10T12:46:24.000Z (about 2 years ago)
- Last Synced: 2024-05-22T11:33:58.845Z (almost 2 years ago)
- Topics: behavior-regularization, d4rl, diffusion, generative, offline, reinforcement-learning, rl, score-based-models, srpo
- Language: Python
- Homepage:
- Size: 592 KB
- Stars: 28
- Watchers: 6
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Score Regularized Policy Optimization through Diffusion Behavior
Huayu Chen, Cheng Lu, Zhengyi Wang, Hang Su, Jun Zhu

## D4RL experiments
### Requirements
Install [PyTorch](https://pytorch.org/), [MuJoCo](https://github.com/deepmind/mujoco), and [D4RL](https://github.com/Farama-Foundation/D4RL) before running the experiments.
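A quick way to confirm the three dependencies are visible to your interpreter is sketched below. The import names in `REQUIRED_MODULES` are assumptions rather than something this README specifies (some D4RL setups rely on `mujoco_py` instead of `mujoco`), so adjust the list to your environment.

```python
from importlib.util import find_spec

# Assumed top-level import names; they are not stated in this README.
# Swap "mujoco" for "mujoco_py" if your D4RL installation uses the older bindings.
REQUIRED_MODULES = ["torch", "mujoco", "d4rl"]


def missing_modules(names):
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```

The check only looks the modules up without importing them, so it does not verify that MuJoCo itself loads correctly.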
### Running
Download the pretrained behavior and critic checkpoints from [here](https://drive.google.com/drive/folders/1N0qC6lakTtwLa7oE0B_9jHfwCj65Irxx?usp=drive_link) and store them under `./SRPO_model_factory/`.
Alternatively, pretrain the behavior model and the critic yourself by running the following two commands, respectively:
```bash
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_behavior.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
```
```bash
TASK="halfcheetah-medium-v2"; seed=0; python3 -u train_critic.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed}
```
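If you want to pretrain for several tasks or seeds, the two commands above can be generated programmatically. The following is a minimal sketch, not part of the repository: `pretrain_commands` is a hypothetical helper, and the seed list `[0, 1, 2]` is only an illustration. It prints the commands instead of running them, so you can inspect them first.

```python
import shlex


def pretrain_commands(task, seed):
    """Build the behavior and critic pretraining commands for one task and seed."""
    expid = f"{task}-baseline-seed{seed}"
    common = ["--expid", expid, "--env", task, "--seed", str(seed)]
    scripts = ("train_behavior.py", "train_critic.py")
    return [["python3", "-u", script, *common] for script in scripts]


if __name__ == "__main__":
    # Hypothetical task and seed lists; edit them to match your experiments.
    for task in ["halfcheetah-medium-v2"]:
        for seed in [0, 1, 2]:
            for cmd in pretrain_commands(task, seed):
                print(shlex.join(cmd))
```

Each returned command is a plain argument list, so it can also be passed to `subprocess.run` unchanged once you are happy with the output.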
Finally, train the policy, passing the paths of the behavior and critic checkpoints:
```bash
TASK="halfcheetah-medium-v2"; seed=0
python3 -u train_policy.py --expid ${TASK}-baseline-seed${seed} --env $TASK --seed ${seed} \
  --actor_load_path ./SRPO_model_factory/${TASK}-baseline-seed${seed}/behavior_ckpt200.pth \
  --critic_load_path ./SRPO_model_factory/${TASK}-baseline-seed${seed}/critic_ckpt150.pth
```
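The policy command expects both checkpoint files to exist already, so it can help to check for them before launching a run. The sketch below rebuilds the two paths from the command above; `checkpoint_paths` and `missing_checkpoints` are hypothetical helpers, not functions from this repository, and the file names `behavior_ckpt200.pth` and `critic_ckpt150.pth` are simply copied from that command.

```python
from pathlib import Path


def checkpoint_paths(task, seed, root="./SRPO_model_factory"):
    """Return the (behavior, critic) checkpoint paths used in the policy command."""
    run_dir = Path(root) / f"{task}-baseline-seed{seed}"
    return run_dir / "behavior_ckpt200.pth", run_dir / "critic_ckpt150.pth"


def missing_checkpoints(task, seed, root="./SRPO_model_factory"):
    """Return the expected checkpoint paths that are not present on disk."""
    return [p for p in checkpoint_paths(task, seed, root) if not p.is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints("halfcheetah-medium-v2", 0)
    if missing:
        for path in missing:
            print("Missing checkpoint:", path)
    else:
        print("Both checkpoints found; train_policy.py can be started.")
```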
## License
MIT