https://github.com/graph-0/jodo


# JODO

----

The implementation of [Learning Joint 2D & 3D Diffusion Models for
Complete Molecule Generation](https://arxiv.org/abs/2305.12347).

[Figure: a molecule represented as a 3D point cloud and a 2D bonding graph]

[Figure: the generative diffusion process]

----

[Figure: molecules generated by JODO trained on the GEOM-Drugs dataset]

[Figure: molecules generated by JODO trained on the QM9 dataset with explicit hydrogen atoms]

----

## Dependencies
* [PyTorch >= 1.11](https://pytorch.org/)
* [PyG 2.1](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html)
* See requirements.txt for others.

## Dataset

We recommend using our processed dataset files provided [here](https://zenodo.org/record/7966493).

Download datasets:
```bash
# 718MB
wget https://zenodo.org/record/7966493/files/data.zip
unzip data.zip
```

If you want to construct the GEOM-Drugs dataset from scratch:
* The raw GEOM dataset is available [here](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/JNGTDF).
* Download `rdkit_folder.tar.gz` and unpack it.
* Run `python build_geom_dataset.py --data_dir YOUR_DATA_PATH`.

## Generated Molecules
We provide pickles of 10,000 molecules generated by JODO on different datasets in `./rdkit_mols`.
Molecules are saved as RDKit Mol objects, so you can load the list and analyze it directly with RDKit.

```python
# Example: load molecules generated by JODO trained on the GEOM-Drugs dataset.
# RDKit must be installed, because the pickle contains RDKit Mol objects.
import pickle

with open('rdkit_mols/geom_jodo_ancestral_ckpt_35.pkl', 'rb') as f:
    mol_list = pickle.load(f)
```
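
As one possible starting point for further analysis, the sketch below summarizes a loaded list by SMILES string. The helper `summarize_mols` and its `to_smiles` parameter are illustrative names, not part of this repository; by default the helper calls RDKit's `Chem.MolToSmiles`, imported lazily.

```python
from collections import Counter


def summarize_mols(mol_list, to_smiles=None):
    """Count total and unique molecules by their SMILES string.

    `to_smiles` is a hypothetical hook (not part of JODO) that maps one
    molecule to a string; it defaults to RDKit's Chem.MolToSmiles.
    """
    if to_smiles is None:
        from rdkit import Chem  # lazy import: only needed for the default
        to_smiles = Chem.MolToSmiles
    counts = Counter(to_smiles(mol) for mol in mol_list)
    total = sum(counts.values())
    return {
        "total": total,
        "unique": len(counts),
        "uniqueness": len(counts) / total if total else 0.0,
        "most_common": counts.most_common(3),
    }
```

After loading a pickle as shown above, `summarize_mols(mol_list)` gives a quick view of how many distinct molecules the list contains.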

## Evaluation
We construct a comprehensive evaluation pipeline for molecule generation, including 2D molecular graph metrics,
3D geometry metrics, and substructure geometry alignment metrics.
* For the 3D geometry metrics, we follow https://github.com/ehoogeboom/e3_diffusion_for_molecules: a distance
lookup table predicts bonds, and we report the same stability metrics for 3D geometry comparisons.
* However, stability metrics for 3D geometry can be misleading in some situations. Some methods achieve a high stability
ratio but fail on FCD and alignment MMD, implying poor molecule generation quality.
This phenomenon is more pronounced on the GEOM-Drugs dataset because it contains more atypical interatomic distances.
* We recommend using the stability metric cautiously, preferably in combination with other metrics, to evaluate
molecular quality.
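
To make the idea of a distance lookup table concrete, here is a minimal sketch of how bonds can be predicted from interatomic distances and how a stability check can follow from them. The bond lengths are approximate textbook values for a few element pairs, and the margins, valences, and function names are assumptions for illustration; the table and thresholds in the referenced repository are the authoritative ones.

```python
import math

# Illustrative typical bond lengths in picometres (NOT the full table).
# Keys are alphabetically sorted element pairs.
SINGLE = {("C", "C"): 154, ("C", "H"): 109, ("C", "N"): 147,
          ("C", "O"): 143, ("H", "H"): 74, ("H", "N"): 101, ("H", "O"): 96}
DOUBLE = {("C", "C"): 134, ("C", "N"): 129, ("C", "O"): 120}
TRIPLE = {("C", "C"): 120, ("C", "N"): 116}
MARGINS = (10, 5, 3)  # assumed tolerances (pm) for single/double/triple
VALENCE = {"H": 1, "C": 4, "N": 3, "O": 2}  # assumed allowed valences


def bond_order(elem_a, elem_b, dist_angstrom):
    """Return 0, 1, 2 or 3 for a pair of atoms at the given distance."""
    key = tuple(sorted((elem_a, elem_b)))
    dist_pm = dist_angstrom * 100
    if key not in SINGLE or dist_pm >= SINGLE[key] + MARGINS[0]:
        return 0  # unknown pair, or too far apart to be bonded
    if key in TRIPLE and dist_pm < TRIPLE[key] + MARGINS[2]:
        return 3
    if key in DOUBLE and dist_pm < DOUBLE[key] + MARGINS[1]:
        return 2
    return 1


def molecule_is_stable(symbols, coords):
    """True if every atom's summed bond order equals its allowed valence.

    Only elements listed in VALENCE are supported in this sketch.
    """
    valence = [0] * len(symbols)
    for i in range(len(symbols)):
        for j in range(i + 1, len(symbols)):
            order = bond_order(symbols[i], symbols[j],
                               math.dist(coords[i], coords[j]))
            valence[i] += order
            valence[j] += order
    return all(v == VALENCE[s] for v, s in zip(valence, symbols))
```

Because the check depends only on distances falling inside fixed windows, it says nothing about whether the resulting molecule is chemically plausible overall, which is one reason to pair it with the other metrics.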

To evaluate your models with our pipeline conveniently, you can save your generated molecules as a list of RDKit Mol
objects and run `eval_rdkit_pkl.py`.

Take QM9 as an example:
```shell
# Molecules with 3D positions and atom types, without bonds
python eval_rdkit_pkl.py --dataset_name qm9 --type 3D --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

# Molecules with atom and bond types, without 3D positions
python eval_rdkit_pkl.py --dataset_name qm9 --type 2D --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

# Molecules with atom types, bond types and 3D positions
python eval_rdkit_pkl.py --dataset_name qm9 --type both --sub_geometry=True --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH
```
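
To produce the pickle that `--pkl_path` points to, a minimal sketch is given below. The helper names `save_mol_list` and `load_mol_list` are hypothetical, and the code assumes only what is stated above: the file holds a plain Python list of RDKit Mol objects.

```python
import pickle
from pathlib import Path


def save_mol_list(mols, pkl_path):
    """Pickle a list of molecules to `pkl_path` and return the path."""
    path = Path(pkl_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(list(mols), f)
    return path


def load_mol_list(pkl_path):
    """Load the list back, e.g. to sanity-check it before evaluation."""
    with open(pkl_path, "rb") as f:
        return pickle.load(f)
```

In practice `mols` would be the RDKit Mol objects produced by your own model, and the saved path is what you pass as `YOUR_MOL_PATH`.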

## Checkpoint

Our checkpoints are provided [here](https://zenodo.org/record/8002902).

Download checkpoints:
```bash
# Unconditional Generation: QM9, GEOM-Drugs (2.8GB)
wget https://zenodo.org/record/8002902/files/exp_uncond.zip
unzip exp_uncond.zip

# Conditional Generation: single quantum property on QM9 (3.1GB)
wget https://zenodo.org/record/8002902/files/exp_cond.zip
unzip exp_cond.zip

# Conditional Generation: multi properties (1.6GB)
wget https://zenodo.org/record/8002902/files/exp_cond_multi.zip
unzip exp_cond_multi.zip

# Molecular Graph Generation: ZINC250k, MOSES (3.9GB)
wget https://zenodo.org/record/8002902/files/exp_2d.zip
unzip exp_2d.zip
```

## Unconditional Generation

QM9 Training Example:
```shell
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_qm9_jodo
```
* Select GPU IDs with `CUDA_VISIBLE_DEVICES`; multiple GPUs are supported.

QM9 Sampling Example:
```shell
# sample from our pretrained checkpoint
CUDA_VISIBLE_DEVICES=2 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_qm9_jodo --config.eval.ckpts '30' --config.eval.batch_size 2500 --config.sampling.steps 1000
```
* Set `--config.eval.batch_size` to control GPU memory usage.
* Set the number of sampling iterations via `--config.sampling.steps`. Good results can be obtained with step counts ranging from 1000 down to 50.
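
To compare several step counts, the sampling command above can be driven from a small script. This is a sketch under assumptions: the helper names are hypothetical, the intermediate step counts (500, 100) are arbitrary choices within the stated range, and it is not confirmed here whether successive runs overwrite each other's outputs in the same workdir.

```python
import os
import subprocess


def build_eval_command(steps, ckpt="30", batch_size=2500):
    """Return the argv list for one QM9 sampling run with `steps` iterations."""
    return [
        "python", "main.py",
        "--config", "configs/vpsde_qm9_uncond_jodo.py",
        "--mode", "eval",
        "--workdir", "exp_uncond/vpsde_qm9_jodo",
        "--config.eval.ckpts", ckpt,
        "--config.eval.batch_size", str(batch_size),
        "--config.sampling.steps", str(steps),
    ]


def run_sweep(step_counts=(1000, 500, 100, 50), gpu="0"):
    """Run sampling once per step count on the selected GPU."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    for steps in step_counts:
        # check=True stops the sweep if any run fails.
        subprocess.run(build_eval_command(steps), env=env, check=True)
```

Lowering `batch_size` in `build_eval_command` is the knob to reach for if a run exhausts GPU memory.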

GEOM-Drugs Training Example:
```shell
# Base
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_base --config.model.n_layers 6 --config.model.nf 128

# Medium
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_media

# Large
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_large --config.model.nf 384 --config.training.n_iters 1500000
```

GEOM-Drugs Sampling Example:
```shell
# Base
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_base --config.model.n_layers 6 --config.model.nf 128 --config.eval.ckpts '30' --config.eval.batch_size 800 --config.sampling.steps 1000

# Medium
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_media --config.eval.ckpts '30' --config.eval.batch_size 1000 --config.sampling.steps 1000

# Large
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_large --config.model.nf 384 --config.eval.ckpts '30' --config.eval.batch_size 500 --config.sampling.steps 1000
```

Using the simplified DGT without extra attention heads can also achieve relatively good performance:
```shell
# QM9 Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_qm9_jodo_sim --config.model.name DGT_concat_sim

# GEOM-Drugs Medium Training
CUDA_VISIBLE_DEVICES=2,3 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_media_sim --config.model.name DGT_concat_sim
```

## Conditional Generation

```shell
# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_gap --config.cond_property gap
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_homo --config.cond_property homo
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_lumo --config.cond_property lumo
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_mu --config.cond_property mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_Cv --config.cond_property Cv
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_alpha --config.cond_property alpha

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_gap --config.cond_property gap --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_homo --config.cond_property homo --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_lumo --config.cond_property lumo --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_mu --config.cond_property mu --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_Cv --config.cond_property Cv --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_alpha --config.cond_property alpha --config.eval.ckpts '40'
```

* Choose the conditioning property (`alpha`, `gap`, `homo`, `lumo`, `mu`, or `Cv`) with `--config.cond_property`.
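
Since the commands above differ only in the property name, they can be generated rather than typed by hand. The sketch below assumes the workdir naming pattern seen in those commands; `cond_command` is a hypothetical helper that also rejects property names outside the listed set.

```python
import shlex

COND_PROPERTIES = ("alpha", "gap", "homo", "lumo", "mu", "Cv")


def cond_command(prop, mode="train", ckpt=None, gpu="0"):
    """Build the shell command for one single-property run on QM9."""
    if prop not in COND_PROPERTIES:
        raise ValueError(f"unknown property {prop!r}; expected one of {COND_PROPERTIES}")
    if mode not in ("train", "eval"):
        raise ValueError(f"mode must be 'train' or 'eval', got {mode!r}")
    argv = [
        "python", "main.py",
        "--config", "configs/vpsde_qm9_cond_jodo.py",
        "--mode", mode,
        "--workdir", f"exp_cond/vpsde_qm9_cond_jodo_{prop}",
        "--config.cond_property", prop,
    ]
    if ckpt is not None:
        argv += ["--config.eval.ckpts", str(ckpt)]
    return f"CUDA_VISIBLE_DEVICES={gpu} " + shlex.join(argv)


if __name__ == "__main__":
    # Print the sampling command for every supported property.
    for prop in COND_PROPERTIES:
        print(cond_command(prop, mode="eval", ckpt=40))
```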

```shell
# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_Cv_mu --config.cond_property1 Cv --config.cond_property2 mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_gap_mu --config.cond_property1 gap --config.cond_property2 mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_alpha_mu --config.cond_property1 alpha --config.cond_property2 mu

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_Cv_mu --config.cond_property1 Cv --config.cond_property2 mu --config.eval.ckpts '50'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_gap_mu --config.cond_property1 gap --config.cond_property2 mu --config.eval.ckpts '50'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_alpha_mu --config.cond_property1 alpha --config.cond_property2 mu --config.eval.ckpts '50'
```
* Set multiple conditioning properties via `--config.cond_property1` and `--config.cond_property2`.

## Molecular Graph Generation

ZINC250k:
```shell
# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_zinc_2d_jodo.py --mode train --workdir exp_2d/vpsde_zinc_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_zinc_2d_jodo.py --mode eval --workdir exp_2d/vpsde_zinc_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000 --config.eval.ckpts '5'
```
* You can train a smaller model with `--config.model.nf 256 --config.model.n_heads 16 --config.model.n_layers 8`.

MOSES:
```shell
# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_moses_2d_jodo.py --mode train --workdir exp_2d/vpsde_moses_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_moses_2d_jodo.py --mode eval --workdir exp_2d/vpsde_moses_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000 --config.eval.ckpts '4'
```

Training CDGS on QM9 and GEOM-Drugs:
```shell
# QM9
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_2d_cdgs.py --mode train --workdir exp_2d/vpsde_qm9_2d_cdgs

# GEOM-Drugs
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_2d_cdgs.py --mode train --workdir exp_2d/vpsde_geom_2d_cdgs
```

## Citation

```bibtex
@article{huang2023learning,
title={Learning Joint 2D \& 3D Diffusion Models for Complete Molecule Generation},
author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
journal={arXiv preprint arXiv:2305.12347},
year={2023}
}

@article{huang2023conditional,
title={Conditional Diffusion Based on Discrete Graph Structures for Molecular Graph Generation},
author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
journal={arXiv preprint arXiv:2301.00427},
year={2023}
}
```