Chem-MRL: SMILES Matryoshka Representation Learning Embedding Model
- Host: GitHub
- URL: https://github.com/emapco/chem-mrl
- Owner: emapco
- License: apache-2.0
- Created: 2025-01-17T08:05:03.000Z (5 months ago)
- Default Branch: main
- Last Pushed: 2025-03-29T09:44:04.000Z (2 months ago)
- Last Synced: 2025-05-01T13:19:44.495Z (about 1 month ago)
- Topics: chemoinformatics, embedding-models, latent-space, matryoshka-representation-learning, smiles
- Language: Python
- Homepage:
- Size: 31.4 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# CHEM-MRL
Chem-MRL is a SMILES embedding transformer model that leverages Matryoshka Representation Learning (MRL) to generate efficient, truncatable embeddings for downstream tasks such as classification, clustering, and database querying.
The model employs [SentenceTransformers' (SBERT)](https://sbert.net/) [2D Matryoshka Sentence Embeddings](https://sbert.net/examples/training/matryoshka/README.html) (`Matryoshka2dLoss`) to enable truncatable embeddings with minimal accuracy loss, improving query performance and flexibility in downstream applications.
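As an illustrative sketch (not an API from this repository), the snippet below shows the typical way Matryoshka-style embeddings are consumed: encode SMILES with sentence-transformers, truncate the vectors to a smaller dimensionality, and re-normalize. The model path is a placeholder.
```python
# Minimal sketch, assuming a trained Chem-MRL checkpoint is available locally.
import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("path/to/trained_chem_mrl")  # placeholder path
embeddings = model.encode(["CCO", "CC(=O)Oc1ccccc1C(=O)O"], convert_to_numpy=True)

# Truncate to the first 256 dimensions and re-normalize (standard MRL usage).
truncated = embeddings[:, :256]
truncated = truncated / np.linalg.norm(truncated, axis=1, keepdims=True)
print(truncated.shape)
```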
Datasets should consist of SMILES pairs and their corresponding [Morgan fingerprint](https://www.rdkit.org/docs/GettingStartedInPython.html#morgan-fingerprints-circular-fingerprints) Tanimoto similarity scores. Currently, datasets must be in Parquet format.
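For illustration, a dataset in the expected shape could be assembled as follows. The column names mirror the defaults used in the training examples below; this is a sketch built from standard RDKit and pandas calls, not a script shipped with this repository.
```python
# Sketch: build a SMILES-pair / Tanimoto-similarity dataset and save it as Parquet.
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

def morgan_fp(smiles: str):
    # 2048-bit Morgan fingerprint (radius 2) for a SMILES string
    return AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(smiles), 2, nBits=2048)

pairs = [("CCO", "CCN"), ("c1ccccc1", "c1ccccc1O")]  # toy examples
rows = [
    {
        "smiles_a": a,
        "smiles_b": b,
        "similarity": DataStructs.TanimotoSimilarity(morgan_fp(a), morgan_fp(b)),
    }
    for a, b in pairs
]
pd.DataFrame(rows).to_parquet("train.parquet", index=False)
```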
Hyperparameter tuning indicates that a custom Tanimoto similarity loss function, `TanimotoSentLoss`, based on [CoSENTLoss](https://kexue.fm/archives/8847), outperforms alternatives based on [Tanimoto similarity](https://jcheminf.biomedcentral.com/articles/10.1186/s13321-015-0069-3/tables/2), CoSENTLoss, [AnglELoss](https://arxiv.org/pdf/2309.12871), and cosine similarity.
## Installation
**Install with pip**
```bash
pip install chem-mrl
```

**Install from source code**
```bash
pip install -e .
```

## Usage
### Hydra & Training Scripts
Hydra configuration files are in `chem_mrl/conf`. The base config defines shared arguments, while model-specific configs are located in `chem_mrl/conf/model`. Use `chem_mrl_config.yaml` or `classifier_config.yaml` to run specific models.
The `scripts` directory provides training scripts with Hydra for parameter management:
- **Train a Chem-MRL model:**
```bash
python scripts/train_chem_mrl.py train_dataset_path=/path/to/training.parquet val_dataset_path=/path/to/val.parquet
```
- **Train a linear classifier:**
```bash
python scripts/train_classifier.py train_dataset_path=/path/to/training.parquet val_dataset_path=/path/to/val.parquet
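# Other configuration values can be overridden with Hydra's key=value syntax; the exact
# keys mirror the YAML files in chem_mrl/conf (for example, a nested model option could be
# overridden as model.<option>=<value> -- illustrative, check the config files).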
```

### Basic Training Workflow
To train a model, initialize the configuration with dataset paths and model parameters, then pass it to `ChemMRLTrainer` for training.
```python
from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define training configuration
config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,  # Predefined model name; can be any transformer model name or path compatible with sentence-transformers
        n_dims_per_step=3,  # Model-specific hyperparameter
        use_2d_matryoshka=True,  # Enable 2D MRL
        # Additional parameters specific to 2D MRL models
        n_layers_per_step=2,
        kl_div_weight=0.7,  # Weight for KL divergence regularization
        kl_temperature=0.5,  # Temperature parameter for KL loss
    ),
    train_dataset_path="train.parquet",  # Path to training data
    val_dataset_path="val.parquet",  # Path to validation data
    test_dataset_path="test.parquet",  # Optional test dataset
    smiles_a_column_name="smiles_a",  # Column with first molecule's SMILES representation
    smiles_b_column_name="smiles_b",  # Column with second molecule's SMILES representation
    label_column_name="similarity",  # Similarity score between molecules
)

# Initialize trainer and start training
trainer = ChemMRLTrainer(config)
# Returns the test evaluation metric if a test dataset is provided;
# otherwise returns the final validation evaluation metric.
test_eval_metric = trainer.train()
```

### Experimental
#### Train a Query Model
To train a query model, configure the model to use the specialized query tokenizer.
The query tokenizer supports the following query types:
- `similar`: computes the SMILES similarity between two molecular structures; used for retrieving similar SMILES.
- `substructure`: determines whether a given substructure is present in the second SMILES string.

Supported query formats for the `smiles_a` column:
- `similar {smiles}`
- `substructure {smiles}`

```python
from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        use_query_tokenizer=True,  # Train a query model
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="query",
    smiles_b_column_name="target_smiles",
    label_column_name="similarity",
)
trainer = ChemMRLTrainer(config)
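# With this configuration, the "query" column in train.parquet holds strings in the
# formats listed above, e.g. "similar CC(=O)Oc1ccccc1C(=O)O" or "substructure c1ccccc1"
# (illustrative values).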
```

#### Latent Attention Layer
The Latent Attention Layer is an experimental component designed to enhance the representation learning of transformer-based models by introducing a trainable latent dictionary. This mechanism applies cross-attention between token embeddings and a set of learnable latent vectors before pooling. The output of this layer contributes to both **1D Matryoshka loss** (as the final layer output) and **2D Matryoshka loss** (by integrating into all-layer outputs). Note: initial tests suggest that, with the default configuration, the latent attention layer leads to overfitting.
```python
from chem_mrl.models import LatentAttentionLayer
from chem_mrl.schemas import BaseConfig, ChemMRLConfig, LatentAttentionConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
        latent_attention_config=LatentAttentionConfig(
            hidden_dim=768,  # Transformer hidden size
            num_latents=512,  # Number of learnable latents
            num_cross_heads=8,  # Number of attention heads
            cross_head_dim=32,  # Dimensionality of each head
            output_normalize=True,  # Apply L2 normalization to outputs
        ),
        use_2d_matryoshka=True,
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
)

# Train a model with latent attention
trainer = ChemMRLTrainer(config)
```

### Custom Evaluation Callbacks
You can provide a callback function that is executed every `evaluation_steps` steps, allowing custom logic such as logging, early stopping, or model checkpointing.
```python
from chem_mrl.schemas import BaseConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer

# Define a callback function for logging evaluation metrics
def eval_callback(score: float, epoch: int, steps: int):
    print(f"Step {steps}, Epoch {epoch}: Evaluation Score = {score}")

config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
)

# Train with callback
trainer = ChemMRLTrainer(config)
# Callback executed every `evaluation_steps`
val_eval_metric = trainer.train(eval_callback=eval_callback)
```

### W&B Integration
This library includes a `WandBTrainerExecutor` class for seamless Weights & Biases (W&B) integration. It handles authentication, initialization, and logging at the frequency specified by `evaluation_steps`.
```python
from chem_mrl.schemas import BaseConfig, WandbConfig, ChemMRLConfig
from chem_mrl.constants import BASE_MODEL_NAME
from chem_mrl.trainers import ChemMRLTrainer, WandBTrainerExecutor
from chem_mrl.schemas.Enums import WatchLogOption

# Define W&B configuration for experiment tracking
wandb_config = WandbConfig(
    project_name="chem_mrl_test",  # W&B project name
    run_name="test",  # Name for the experiment run
    use_watch=True,  # Enable model watching for tracking gradients
    watch_log=WatchLogOption.all,  # Log all model parameters and gradients
    watch_log_freq=1000,  # Logging frequency
    watch_log_graph=True,  # Log the model computation graph
)

# Configure training with W&B integration
config = BaseConfig(
    model=ChemMRLConfig(
        model_name=BASE_MODEL_NAME,
    ),
    train_dataset_path="train.parquet",
    val_dataset_path="val.parquet",
    smiles_a_column_name="smiles_a",
    smiles_b_column_name="smiles_b",
    label_column_name="similarity",
    evaluation_steps=1000,
    wandb=wandb_config,
)

# Initialize trainer and W&B executor
trainer = ChemMRLTrainer(config)
executor = WandBTrainerExecutor(trainer)
executor.execute() # Handles training and W&B logging
```

## Classifier
This repository includes code for training a linear classifier with optional dropout regularization. The classifier categorizes substances from their SMILES representations into labeled categories.
Hyperparameter tuning shows that cross-entropy loss (`softmax` option) outperforms self-adjusting dice loss in terms of accuracy, making it the preferred choice for molecular property classification.
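As an illustrative sketch (not part of this repository), a classification dataset with the column names used in the examples below could be written like this:
```python
# Sketch: a minimal classification dataset in Parquet format with the
# "smiles" and "label" columns referenced in the configurations below.
import pandas as pd

df = pd.DataFrame(
    {
        "smiles": ["CCO", "c1ccccc1", "CC(=O)O"],
        "label": ["alcohol", "aromatic", "carboxylic_acid"],  # placeholder category labels
    }
)
df.to_parquet("train_classification.parquet", index=False)
```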
### Usage
#### Basic Classification Training
To train a classifier, configure the model with dataset paths and column names, then initialize `ClassifierTrainer` to start training.
```python
from chem_mrl.schemas import BaseConfig, ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer

# Define classification training configuration
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",  # Pretrained MRL model path
    ),
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
    smiles_a_column_name="smiles",  # Column containing SMILES representations of molecules
    label_column_name="label",  # Column containing classification labels
)

# Initialize and train the classifier
trainer = ClassifierTrainer(config)
trainer.train()
```

#### Training with Dice Loss
For imbalanced classification tasks, **Dice Loss** can improve performance by focusing on hard-to-classify samples. Below is a configuration that enables the self-adjusting Dice loss, which introduces additional hyperparameters.
```python
from chem_mrl.schemas import BaseConfig, ClassifierConfig
from chem_mrl.trainers import ClassifierTrainer
from chem_mrl.schemas.Enums import ClassifierLossFctOption, DiceReductionOption

# Define classification training configuration with Dice Loss
config = BaseConfig(
    model=ClassifierConfig(
        model_name="path/to/trained_mrl_model",
        loss_func=ClassifierLossFctOption.selfadjdice,
        dice_reduction=DiceReductionOption.sum,  # Reduction method for Dice Loss ('mean' or 'sum')
        dice_gamma=1.0,  # Smoothing factor hyperparameter
    ),
    train_dataset_path="train_classification.parquet",  # Path to training dataset
    val_dataset_path="val_classification.parquet",  # Path to validation dataset
    smiles_a_column_name="smiles",
    label_column_name="label",
)

# Initialize and train the classifier with Dice Loss
trainer = ClassifierTrainer(config)
trainer.train()
```

## References
- Chithrananda, Seyone, et al. "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." _arXiv [Cs.LG]_, 2020. [Link](http://arxiv.org/abs/2010.09885).
- Ahmad, Walid, et al. "ChemBERTa-2: Towards Chemical Foundation Models." _arXiv [Cs.LG]_, 2022. [Link](http://arxiv.org/abs/2209.01712).
- Kusupati, Aditya, et al. "Matryoshka Representation Learning." _arXiv [Cs.LG]_, 2022. [Link](https://arxiv.org/abs/2205.13147).
- Li, Xianming, et al. "2D Matryoshka Sentence Embeddings." _arXiv [Cs.CL]_, 2024. [Link](http://arxiv.org/abs/2402.14776).
- Bajusz, Dávid, et al. "Why is the Tanimoto Index an Appropriate Choice for Fingerprint-Based Similarity Calculations?" _J Cheminform_, 7, 20 (2015). [Link](https://doi.org/10.1186/s13321-015-0069-3).
- Li, Xiaoya, et al. "Dice Loss for Data-imbalanced NLP Tasks." _arXiv [Cs.CL]_, 2020. [Link](https://arxiv.org/abs/1911.02855).
- Reimers, Nils, and Gurevych, Iryna. "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks." _Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing_, 2019. [Link](https://arxiv.org/abs/1908.10084).
- Lee, Chankyu, et al. "NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models." _arXiv [Cs.CL]_, 2025. [Link](https://arxiv.org/abs/2405.17428).