https://github.com/RohollahHS/BAD
The official Pytorch implementation of “BAD: Bidirectional Auto-regressive Diffusion for Text-to-Motion Generation”
- Host: GitHub
- URL: https://github.com/RohollahHS/BAD
- Owner: RohollahHS
- Created: 2024-09-10T16:24:28.000Z (4 months ago)
- Default Branch: master
- Last Pushed: 2024-10-22T05:02:54.000Z (2 months ago)
- Last Synced: 2024-10-23T07:21:03.391Z (2 months ago)
- Language: Python
- Homepage:
- Size: 4.96 MB
- Stars: 22
- Watchers: 3
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
# [BAD: Bidirectional Auto-regressive Diffusion for Text-to-Motion Generation](https://www.arxiv.org/abs/2409.10847)
### [[Paper]](https://www.arxiv.org/abs/2409.10847) [[Project Page]](https://rohollahhs.github.io/BAD-page/) [[Colab]](https://colab.research.google.com/drive/1vsLKHR0DVY4moKC0SFuM8IVLJDHyLTkn?usp=sharing)

![Sample Image](visualization/quality-comp-walk_page-0001.jpg)
If you find our code or paper helpful, please consider starring our repository and citing us.
```
@article{hosseyni2024bad,
title={BAD: Bidirectional Auto-regressive Diffusion for Text-to-Motion Generation},
author={Hosseyni, S Rohollah and Rahmani, Ali Ahmad and Seyedmohammadi, S Jamal and Seyedin, Sanaz and Mohammadi, Arash},
journal={arXiv preprint arXiv:2409.10847},
year={2024}
}
```

## News
📢 **2024-09-24** --- Initialized the webpage and git project.
## Getting Ready
### 1. Conda Environment

For training and evaluation, we used the following conda environment, which is based on the [MMM](https://github.com/exitudio/MMM.git) environment:
```
conda env create -f environment.yml
conda activate bad
pip install git+https://github.com/openai/CLIP.git
```

We ran into issues when using the above environment for generation and visualization, so we use a second environment for those steps. Changing the versions of some packages in the first environment, particularly numpy, may also work. The second environment is based on the [Momask](https://github.com/EricGuo5513/momask-codes.git) environment, with additional packages such as smplx from the [MDM](https://github.com/GuyTevet/motion-diffusion-model.git) environment:
```
conda env create -f environment2.yml
conda activate bad2
pip install git+https://github.com/openai/CLIP.git
```

### 2. Models and Dependencies
#### Download Pre-trained Models
```
bash dataset/prepare/download_models.sh
```

#### Download SMPL Files
For rendering.
```
bash dataset/prepare/download_smpl_files.sh
```

#### Download Evaluation Models and GloVe
For evaluation only.
```
bash dataset/prepare/download_extractor.sh
bash dataset/prepare/download_glove.sh
```

#### Troubleshooting

gdown may fail with the error "Cannot retrieve the public link of the file. You may need to change the permission to 'Anyone with the link', or have had many accesses". A potential fix is to upgrade it with `pip install --upgrade --no-cache-dir gdown`, as suggested in https://github.com/wkentaro/gdown/issues/43.

#### (Optional) Download Manually

Visit [[Google Drive]](https://drive.google.com/drive/folders/1sHajltuE2xgHh91H9pFpMAYAkHaX9o57?usp=drive_link) to download the models and evaluators manually.

### 3. Get Data
We use two 3D human motion-language datasets: HumanML3D and KIT-ML. Details and download links for both are available [here](https://github.com/EricGuo5513/HumanML3D.git).

**HumanML3D** - The dataset should be organized as follows:
```
./dataset/HumanML3D/
├── new_joint_vecs/
├── texts/
├── Mean.npy # same as in HumanML3D
├── Std.npy # same as in HumanML3D
├── train.txt
├── val.txt
├── test.txt
├── train_val.txt
└── all.txt
```

**KIT-ML** - To download and extract the KIT-ML dataset, run the following scripts:
```
bash dataset/prepare/download_kit.sh
bash dataset/prepare/extract_kit.sh
```
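Both datasets store one feature array per motion in `new_joint_vecs/`, list motion IDs in the split `.txt` files, and provide `Mean.npy`/`Std.npy` statistics. The following is a rough sketch, not the repository's data loader. It assumes that each split file lists one motion ID per line and that `new_joint_vecs/<id>.npy` has shape `(frames, feature_dim)`. Under those assumptions, it shows how features could be z-normalized with the dataset statistics. All function names are hypothetical.

```python
import os

import numpy as np


def load_split_ids(data_root: str, split: str) -> list[str]:
    """Read motion IDs (assumed one per line) from e.g. train.txt or test.txt."""
    with open(os.path.join(data_root, f"{split}.txt")) as f:
        return [line.strip() for line in f if line.strip()]


def load_normalized_motion(data_root: str, motion_id: str) -> np.ndarray:
    """Load one motion and z-normalize it with the dataset-wide Mean/Std.

    The statistics are reloaded on every call for simplicity; cache them in real use.
    """
    mean = np.load(os.path.join(data_root, "Mean.npy"))
    std = np.load(os.path.join(data_root, "Std.npy"))
    motion = np.load(os.path.join(data_root, "new_joint_vecs", f"{motion_id}.npy"))
    return (motion - mean) / std


def denormalize(motion: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Invert the normalization, e.g. for generated features."""
    return motion * std + mean
```

For example, `load_split_ids("./dataset/HumanML3D", "train")` would return the training IDs, each of which can then be passed to `load_normalized_motion`.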
If you face any issues, refer to the [HumanML3D repository](https://github.com/EricGuo5513/HumanML3D.git).

## Training
### Stage 1: VQ-VAE
```
python train_vq.py --exp_name 'train_vq' \
--dataname t2m \
--total_batch_size 256
```
- **`--exp_name`**: The name of your experiment.
- **`--dataname`**: Dataset name; use `t2m` for HumanML3D and `kit` for KIT-ML dataset.
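Stage 1 trains a VQ-VAE that turns motion features into discrete codebook indices. The Stage 2 transformer then works with these indices, as in typical VQ-VAE plus transformer pipelines. The central operation is a nearest-neighbour lookup in the learned codebook. The NumPy sketch below shows only that lookup, using the standard squared Euclidean distance. It is not the repository's encoder or quantizer, and the shapes are illustrative assumptions.

```python
import numpy as np


def quantize(latents: np.ndarray, codebook: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Map each latent vector to its nearest codebook entry.

    latents:  (num_tokens, dim) encoder outputs
    codebook: (codebook_size, dim) learned code vectors
    Returns (indices, quantized) with shapes (num_tokens,) and (num_tokens, dim).
    """
    # Squared distance for every (latent, code) pair: ||z||^2 - 2 z.e + ||e||^2.
    dists = (
        np.sum(latents**2, axis=1, keepdims=True)
        - 2.0 * latents @ codebook.T
        + np.sum(codebook**2, axis=1)
    )
    indices = np.argmin(dists, axis=1)  # discrete token ids
    return indices, codebook[indices]   # ids and their code vectors
```

For instance, with codebook rows `[0, 0]`, `[1, 1]`, and `[5, 5]`, the latents `[0.1, -0.2]`, `[0.9, 1.2]`, and `[4, 6]` map to indices `0`, `1`, and `2`.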
### Stage 2: Transformer
```
python train_t2m_trans.py --exp_name 'train_tr' \
--dataname t2m \
--time_cond \
--z_0_attend_to_all \
--unmasked_tokens_not_attend_to_mask_tokens \
--total_batch_size 256 \
--vq_pretrained_path ./output/vq/vq_last.pth
```
- **`--z_0_attend_to_all`**: Specifies the causality condition for mask tokens, where each mask token attends to the last `T-p+1` mask tokens. If `z_0_attend_to_all` is not activated, each mask token attends to the first `p` mask tokens.
- **`--time_cond`**: Uses time as one of the conditions for training the transformer.
- **`--unmasked_tokens_not_attend_to_mask_tokens`**: Prohibits mask tokens from attending to other mask tokens.
- **`--vq_pretrained_path`**: The path to your pretrained VQ-VAE.

## Evaluation
For Order-Agnostic Autoregressive Sampling (OAAS), set `rand_pos` to `False`. With `rand_pos=False`, the token with the highest probability is always selected, and no `top_p`, `top_k`, or `temperature` is applied. With `rand_pos=True`, the OAAS metrics worsen significantly, whereas for Confidence-Based Sampling (CBS) they improve significantly. We do not know why random sampling hurts OAAS; it may be a bug. We would be extremely grateful if anyone could help fix this issue.
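The sketch below makes the `rand_pos` distinction concrete for a single token position. With `rand_pos=False`, it takes the argmax of the logits and ignores the sampling knobs. With `rand_pos=True`, it draws a token at random after optional temperature scaling and top-k filtering. The function name and its `temperature`/`top_k` arguments are hypothetical and do not mirror the repository's sampling code. Top-p filtering is omitted for brevity.

```python
from typing import Optional

import numpy as np


def select_token(logits: np.ndarray, rand_pos: bool, temperature: float = 1.0,
                 top_k: Optional[int] = None,
                 rng: Optional[np.random.Generator] = None) -> int:
    """Pick one codebook index from a 1-D array of logits for a single position."""
    if not rand_pos:
        # rand_pos=False: always take the most probable token; temperature/top_k are unused.
        return int(np.argmax(logits))
    if rng is None:
        rng = np.random.default_rng()
    scaled = logits / temperature
    if top_k is not None:
        # Keep logits at or above the k-th largest value (ties may keep a few more).
        cutoff = np.sort(scaled)[-top_k]
        scaled = np.where(scaled >= cutoff, scaled, -np.inf)
    # Numerically stable softmax; filtered entries get probability 0.
    probs = np.exp(scaled - np.max(scaled))
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs))
```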
```
python GPT_eval_multi.py --exp_name "eval" \
--sampling_type OAAS \
--z_0_attend_to_all \
--time_cond \
--unmasked_tokens_not_attend_to_mask_tokens \
--num_repeat_inner 1 \
--resume_pth ./output/vq/vq_last.pth \
--resume_trans ./output/t2m/trans_best_fid.pth
```
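MModality measures how much the motions generated for the same caption differ from one another, so it needs several generations per caption. A common formulation, modeled here on T2M-GPT-style evaluation code, works as follows. It draws a fixed number of random index pairs over the repeats without replacement. It then averages the distances between the evaluator features of each pair, and it requires strictly more repeats than pairs. The sketch below follows that formulation; the function name and the 10-pair default are assumptions rather than this repository's code.

```python
from typing import Optional

import numpy as np


def multimodality(features: np.ndarray, num_pairs: int = 10,
                  rng: Optional[np.random.Generator] = None) -> float:
    """Average feature distance between random pairs of repeats of the same caption.

    features: (num_captions, num_repeats, dim) evaluator embeddings, where each
              caption was generated num_repeats times.
    """
    assert features.ndim == 3
    num_repeats = features.shape[1]
    # Mirror the strict check: more repeats than sampled pairs.
    assert num_repeats > num_pairs, "increase the number of repeats per caption"
    if rng is None:
        rng = np.random.default_rng()
    # Two independent index draws; a pair may occasionally coincide and add 0.
    first = rng.choice(num_repeats, num_pairs, replace=False)
    second = rng.choice(num_repeats, num_pairs, replace=False)
    dists = np.linalg.norm(features[:, first] - features[:, second], axis=2)
    return float(dists.mean())
```

With only one repeat per caption, this computation fails the check. That is why a single repeat suffices for the other metrics but not for MModality.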
- **`--sampling_type`**: Sampling strategy, either `OAAS` or `CBS`.
- **`--num_repeat_inner`**: Number of repeated generations per caption. To compute MModality, set it above 10 (e.g., 20); for the other metrics, 1 is enough.
- **`--resume_pth`**: The path to your pretrained VQ-VAE.
- **`--resume_trans`**: The path to your pretrained transformer.

For Confidence-Based Sampling (CBS), `rand_pos=True` significantly improves FID compared to `rand_pos=False`:
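The sketch below illustrates the idea behind confidence-based decoding with one iteration in the style of MaskGIT-like iterative decoding, which we assume CBS resembles. A candidate token is chosen for every position, either by argmax or by random draw depending on `rand_pos`. Candidates at still-masked positions are ranked by their probability, and only the most confident ones are committed. The `MASK` sentinel, the function name, and the fixed `num_to_unmask` are illustrative assumptions, not the repository's implementation.

```python
import numpy as np

MASK = -1  # hypothetical sentinel for a position that has not been decoded yet


def cbs_step(tokens: np.ndarray, logits: np.ndarray, num_to_unmask: int,
             rand_pos: bool, rng: np.random.Generator) -> np.ndarray:
    """One confidence-based decoding step.

    tokens: (seq_len,) current token ids, MASK where not yet decoded
    logits: (seq_len, codebook_size) transformer outputs for every position
    """
    # Row-wise softmax.
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)
    if rand_pos:
        # Stochastic: draw one candidate token per position from its distribution.
        candidates = np.array([rng.choice(probs.shape[1], p=p) for p in probs])
    else:
        # Deterministic: take the argmax at every position.
        candidates = probs.argmax(axis=1)
    # Confidence = probability of the chosen candidate; decoded positions are excluded.
    confidence = probs[np.arange(len(tokens)), candidates]
    confidence[tokens != MASK] = -np.inf
    k = min(num_to_unmask, int(np.sum(tokens == MASK)))
    chosen = np.argsort(-confidence)[:k]  # most confident masked positions
    new_tokens = tokens.copy()
    new_tokens[chosen] = candidates[chosen]
    return new_tokens
```

In this sketch, repeating the step with a schedule for `num_to_unmask` until no `MASK` remains yields a full token sequence for the VQ-VAE decoder.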
```
python GPT_eval_multi.py --exp_name "eval" \
--z_0_attend_to_all \
--time_cond \
--sampling_type CBS \
--rand_pos \
--unmasked_tokens_not_attend_to_mask_tokens \
--num_repeat_inner 1 \
--resume_pth ./output/vq/vq_last.pth \
--resume_trans ./output/t2m/trans_best_fid.pth
```

To evaluate the four temporal editing tasks (inpainting, outpainting, prefix prediction, and suffix prediction), use `eval_edit.py`. We used OAAS for the temporal editing results reported in Table 3 of the paper.
```
python eval_edit.py --exp_name "eval" \
--edit_task inbetween \
--z_0_attend_to_all \
--time_cond \
--sampling_type OAAS \
--unmasked_tokens_not_attend_to_mask_tokens \
--num_repeat_inner 1 \
--resume_pth ./output/vq/vq_last.pth \
--resume_trans ./output/t2m/trans_best_fid.pth
```
- **`--edit_task`**: Four edit tasks are available: `inbetween`, `outpainting`, `prefix`, and `suffix`.

## Generation
To generate a motion sequence, run the following:
```
python generate.py --caption 'a person jauntily skips forward.' \
--length 196 \
--z_0_attend_to_all \
--time_cond \
--sampling_type OAAS \
--unmasked_tokens_not_attend_to_mask_tokens \
--resume_pth ./output/vq/vq_last.pth \
--resume_trans ./output/t2m/trans_best_fid.pth
```
- **`--length`**: The length of the motion sequence. If not provided, a length estimator will be used to predict the length of the motion sequence based on the caption.
- **`--caption`**: Text prompt used for generating the motion sequence.

For temporal editing, run the following:
```
python generate.py --temporal_editing \
--caption 'a person jauntily skips forward.' \
--caption_inbetween 'a man walks in a clockwise circle and then sits.' \
--length 196 \
--edit_task inbetween \
--z_0_attend_to_all \
--time_cond \
--sampling_type OAAS \
--unmasked_tokens_not_attend_to_mask_tokens \
--resume_pth ./output/vq/vq_last.pth \
--resume_trans ./output/t2m/trans_best_fid.pth
```
- **`--caption_inbetween`**: Text prompt used for generating the `inbetween`/`outpainting`/`prefix`/`suffix` motion sequence.
- **`--edit_task`**: Four edit tasks are available: `inbetween`, `outpainting`, `prefix`, and `suffix`.

For long sequence generation, run the following:
```
python generate.py --long_seq_generation \
--long_seq_captions 'a person runs forward and jumps.' 'a person crawls.' 'a person does a cart wheel.' 'a person walks forward up stairs and then climbs down.' 'a person sits on the chair and then steps up.' \
--long_seq_lengths 128 196 128 128 128 \
--z_0_attend_to_all \
--time_cond \
--sampling_type OAAS \
--unmasked_tokens_not_attend_to_mask_tokens \
--resume_pth ./output/vq/vq_last.pth \
--resume_trans ./output/t2m/trans_best_fid.pth
```
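Each caption passed to `--long_seq_captions` needs a matching entry in `--long_seq_lengths`, and each length must lie between 40 and 196, per the flag notes for this command. The hypothetical pre-flight check below catches mismatches before launching `generate.py`. It is not part of the repository. It also reports the summed duration of the requested segments, assuming the lengths are frame counts at HumanML3D's 20 fps.

```python
FPS = 20  # assumption: lengths are frames at the HumanML3D rate of 20 fps
MIN_LEN, MAX_LEN = 40, 196  # limits stated for --long_seq_lengths


def check_long_seq_args(captions: list[str], lengths: list[int]) -> float:
    """Validate caption/length pairs and return the summed duration in seconds."""
    if len(captions) != len(lengths):
        raise ValueError(f"{len(captions)} captions but {len(lengths)} lengths")
    for caption, length in zip(captions, lengths):
        if not MIN_LEN <= length <= MAX_LEN:
            raise ValueError(f"length {length} for {caption!r} is outside [{MIN_LEN}, {MAX_LEN}]")
    return sum(lengths) / FPS
```

For the five captions in the command above, the lengths `128 196 128 128 128` pass the check and sum to 708 frames, or 35.4 seconds under the 20 fps assumption.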
- **`--long_seq_generation`**: Enables long sequence generation.
- **`--long_seq_captions`**: Specifies multiple captions.
- **`--long_seq_lengths`**: Specifies multiple lengths (between 40 and 196), one for each caption.

## Visualization
The above commands save `.bvh` and `.mp4` files in the `./output/visualization/` directory. The `.bvh` files can be rendered in Blender. Please refer to [this link](https://github.com/EricGuo5513/momask-codes?tab=readme-ov-file#dancers-visualization) for more information.
To render the motion sequence with SMPL, pass the `.mp4` and `.npy` files generated by `generate.py` to `visualization/render_mesh.py`. The following command creates `.obj` files that can be easily imported into Blender. The script runs [SMPLify](https://smplify.is.tue.mpg.de/) and also requires a GPU.
```
python visualization/render_mesh.py \
--input_path output/visualization/animation/a_person_jauntily_skips_forwar_196/sample103_repeat0_len196.mp4 \
--npy_path output/visualization/joints/a_person_jauntily_skips_forwar_196/sample103_repeat0_len196.npy
```
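Before running the SMPLify step, it can help to sanity-check the joint file. The sketch below assumes, without confirmation from this README, several things about the `.npy` file under `output/visualization/joints/`. It assumes the file holds joint positions of shape `(frames, 22, 3)`, the 22-joint skeleton used by HumanML3D. It also assumes that joint 0 is the root and that y is the up axis. Adjust these if the file differs. The helper is hypothetical and not part of the repository.

```python
import numpy as np


def summarize_joints(npy_path: str, num_joints: int = 22) -> dict:
    """Load a generated joint array and report a few basic facts about it."""
    joints = np.load(npy_path)
    if joints.ndim != 3 or joints.shape[1:] != (num_joints, 3):
        raise ValueError(f"unexpected joint array shape {joints.shape}")
    root = joints[:, 0]  # assumption: joint 0 is the root (pelvis)
    # assumption: y is up, so x and z span the ground plane
    ground_steps = np.diff(root[:, [0, 2]], axis=0)
    return {
        "frames": joints.shape[0],
        "root_start": root[0].tolist(),
        "root_end": root[-1].tolist(),
        # Horizontal distance travelled by the root, summed frame to frame.
        "ground_path_length": float(np.linalg.norm(ground_steps, axis=1).sum()),
    }
```

For the sample above, `summarize_joints` would be called with the same path passed to `--npy_path`.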
- **`--input_path`**: Path to the `.mp4` file, created by `generate.py`.
- **`--npy_path`**: Path to the `.npy` file, created by `generate.py`.

To render the `.obj` files in Blender, you can use the scripts in the [visualization/blender_scripts](https://github.com/RohollahHS/BAD/tree/master/visualization/blender_scripts) directory. First, open Blender, go to **File -> Import -> Wavefront (.obj)**, navigate to the directory containing the `.obj` files, and press `A` to select and import all of them. Next, copy the script from [visualization/blender_scripts/framing_coloring.py](https://github.com/RohollahHS/BAD/blob/master/visualization/blender_scripts/framing_coloring.py) into the **Scripting** tab in Blender and run it. Finally, render the animation in the **Render** tab.
## Acknowledgement
We would like to express our sincere gratitude to [MMM](https://github.com/exitudio/MMM.git), [Momask](https://github.com/EricGuo5513/momask-codes.git), [MDM](https://github.com/GuyTevet/motion-diffusion-model.git), and [T2M-GPT](https://github.com/Mael-zys/T2M-GPT.git) for their outstanding open-source contributions.