https://github.com/jmaczan/darts-toolkit
Differentiable Architecture Search Toolkit in PyTorch Lightning
artificial-intelligence darts deep-learning differentiable-architecture-search machine-learning nas neural-architecture-search pc-darts pcdarts pytorch pytorch-lightning research
- Host: GitHub
- URL: https://github.com/jmaczan/darts-toolkit
- Owner: jmaczan
- License: gpl-3.0
- Created: 2024-10-15T09:19:25.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2024-12-17T12:32:26.000Z (10 months ago)
- Last Synced: 2025-09-05T06:49:26.690Z (about 1 month ago)
- Topics: artificial-intelligence, darts, deep-learning, differentiable-architecture-search, machine-learning, nas, neural-architecture-search, pc-darts, pcdarts, pytorch, pytorch-lightning, research
- Language: Python
- Homepage:
- Size: 228 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
- Citation: CITATION.cff
# `darts-toolkit`
Differentiable Architecture Search Toolkit in PyTorch Lightning
> [!TIP]
> Boost your research and use solid engineering practices out of the box.

Use this toolkit to:
- Research your own DARTS algorithm with pre-built components, or create your own components
- Use existing DARTS algorithms, like [Partially-Connected](https://arxiv.org/abs/1907.05737) [Differentiable Architecture Search](https://arxiv.org/abs/1806.09055) (PC-DARTS)
- Configure hyperparameters with `yaml` files
- Scale to multiple GPUs with minimal effort through PyTorch Lightning's `Trainer`
- Visualize your neural network architecture

## Examples
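DARTS makes the choice of operation on each edge differentiable: instead of picking one operation, it outputs a softmax-weighted sum of all candidate operations, weighted by learnable architecture parameters `alpha`. PC-DARTS additionally sends only a `1/K` fraction of the channels through that mixture to save memory, letting the rest bypass it. The framework-free sketch below illustrates both ideas on plain Python lists. It is not this toolkit's implementation: the candidate operations in `CANDIDATE_OPS` are hypothetical toy stand-ins, the sketch simply uses the first `len(x) // k` channels, and the channel shuffle and edge normalization that PC-DARTS also applies are omitted.

```python
import math
from typing import Callable, Dict, List

Tensor = List[float]  # toy feature map: one scalar activation per channel

# Hypothetical candidate operations for one edge; real DARTS uses convolutions,
# pooling, skip connections and a "none" (zero) operation.
CANDIDATE_OPS: Dict[str, Callable[[Tensor], Tensor]] = {
    "none": lambda x: [0.0] * len(x),
    "skip_connect": lambda x: list(x),
    "double": lambda x: [2.0 * v for v in x],  # stand-in for a learnable op
}


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over the architecture logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x: Tensor, alpha: List[float]) -> Tensor:
    """DARTS relaxation: softmax(alpha)-weighted sum of all candidate ops."""
    weights = softmax(alpha)
    out = [0.0] * len(x)
    for w, op in zip(weights, CANDIDATE_OPS.values()):
        for i, v in enumerate(op(x)):
            out[i] += w * v
    return out


def partial_channel_mixed_op(x: Tensor, alpha: List[float], k: int) -> Tensor:
    """PC-DARTS-style edge: the first len(x) // k channels go through the
    mixed op; the remaining channels bypass it unchanged."""
    split = len(x) // k
    return mixed_op(x[:split], alpha) + x[split:]


x = [1.0, 2.0, 3.0, 4.0]
alpha = [0.0, 0.0, math.log(2)]  # op weights 0.25 / 0.25 / 0.5
print(mixed_op(x, alpha))                       # ≈ [1.25, 2.5, 3.75, 5.0]
print(partial_channel_mixed_op(x, alpha, k=2))  # ≈ [1.25, 2.5, 3.0, 4.0]
```

Because every output is a smooth function of `alpha`, the architecture parameters can be trained by gradient descent alongside the network weights, which is what makes the search "differentiable".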
#### Find a network architecture for image recognition
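The example loads `src/config.yaml` with `load_config` and then indexes a handful of keys. As a hypothetical illustration, the sketch below writes just those keys as the nested `dict` that `load_config` is assumed to return (the real config may hold more settings, such as search-space or optimizer options). It adds two illustrative helpers that are not part of the toolkit: `missing_keys` checks that the keys the examples index directly are present, and `device_kwargs` reproduces the examples' `accelerator`/`devices` selection. All values are placeholders.

```python
from typing import Any, Dict, List, Union

# Hypothetical minimal config with only the keys the README examples read.
# The values are placeholders, not recommended settings.
EXAMPLE_CONFIG: Dict[str, Any] = {
    "training": {
        "max_epochs": 50,       # search-phase epochs
        "derived_epochs": 100,  # epochs for training the derived architecture
        "gpus": 1,              # optional; leave unset to let Lightning choose
    },
    "logging": {
        "log_dir": "logs",
        "experiment_name": "pc_darts_cifar10",
    },
}

# Keys the two examples index directly (a missing one raises KeyError there).
REQUIRED_KEYS = {
    "training": ("max_epochs", "derived_epochs"),
    "logging": ("log_dir", "experiment_name"),
}


def missing_keys(config: Dict[str, Any]) -> List[str]:
    """Return the 'section.key' names the examples need but the config lacks."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        for key in keys:
            if key not in config.get(section, {}):
                missing.append(f"{section}.{key}")
    return missing


def device_kwargs(config: Dict[str, Any]) -> Dict[str, Union[str, int]]:
    """Same accelerator/devices logic as the examples' pl.Trainer calls."""
    gpus = config["training"].get("gpus")
    return {"accelerator": "gpu" if gpus else "auto", "devices": gpus or "auto"}


print(missing_keys(EXAMPLE_CONFIG))   # []
print(device_kwargs(EXAMPLE_CONFIG))  # {'accelerator': 'gpu', 'devices': 1}
```

Checking the keys up front surfaces a typo in the YAML file before a long search run starts, rather than when the derived-model trainer is built.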
```py
import os

# PyTorch Lightning imports (the `lightning.pytorch` namespace works too)
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

from darts_toolkit.models import LPCDARTSLightningModule
from darts_toolkit.data import CIFAR10DataModule
from darts_toolkit.utils.yaml import load_config

# Load configuration
config = load_config(os.path.join("src", "config.yaml"))

# Create data module
data_module = CIFAR10DataModule(config)

# Search phase: create the search model and its trainer
search_model = LPCDARTSLightningModule(config)
search_trainer = pl.Trainer(
    max_epochs=config["training"]["max_epochs"],
    accelerator="gpu" if config["training"].get("gpus") else "auto",
    devices=config["training"].get("gpus") or "auto",
    callbacks=[RichProgressBar()],
    logger=TensorBoardLogger(
        config["logging"]["log_dir"],
        name=f"{config['logging']['experiment_name']}_search",
    ),
)

# Train the search model
search_trainer.fit(search_model, data_module)

# Test the search model
search_trainer.test(search_model, datamodule=data_module)
```

#### Train a derived architecture
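`derive_architecture()` presumably turns the searched architecture weights into a discrete cell. In the original DARTS paper, that step keeps, on every edge, the operation with the largest softmax weight other than `none`, and, for each intermediate node, only the two strongest incoming edges. The plain-Python sketch below illustrates that rule; the `(source, target)` edge keys, the operation names, and the logits are hypothetical, and this toolkit's `derive_architecture()` may represent or select things differently.

```python
import math
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # (source node, target node) inside one cell


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one edge's architecture logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def derive_cell(
    alphas: Dict[Edge, List[float]], op_names: List[str], k: int = 2
) -> Dict[int, List[Tuple[int, str]]]:
    """Keep, for every target node, its k strongest incoming edges,
    each labelled with its best operation other than 'none'."""
    candidates: Dict[int, List[Tuple[float, int, str]]] = {}
    for (src, dst), logits in alphas.items():
        weights = softmax(logits)  # softmax over all ops, 'none' included
        best_w, best_op = max(
            (w, name) for w, name in zip(weights, op_names) if name != "none"
        )
        candidates.setdefault(dst, []).append((best_w, src, best_op))
    cell: Dict[int, List[Tuple[int, str]]] = {}
    for dst, edges in candidates.items():
        edges.sort(reverse=True)  # strongest edge first
        cell[dst] = sorted((src, op) for _, src, op in edges[:k])
    return cell


# Hypothetical op names and logits for two intermediate nodes (2 and 3).
OPS = ["none", "skip_connect", "conv_3x3"]
alphas = {
    (0, 2): [0.0, 1.0, 0.0],
    (1, 2): [0.0, 0.0, 2.0],
    (0, 3): [5.0, 0.1, 0.2],  # most of the weight is on "none"
    (1, 3): [0.0, 0.5, 0.0],
    (2, 3): [0.0, 0.0, 1.5],
}
print(derive_cell(alphas, OPS))
# {2: [(0, 'skip_connect'), (1, 'conv_3x3')], 3: [(1, 'skip_connect'), (2, 'conv_3x3')]}
```

Node 3's edge from node 0 is dropped: most of its weight sits on `none`, so its best real operation is weak compared with the other two incoming edges.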
```py
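from pytorch_lightning.callbacks import ModelCheckpoint

# Continues the search example above: reuses config, data_module, search_model,
# pl, RichProgressBar and TensorBoardLogger. DerivedPCDARTSModel is assumed to
# come from darts_toolkit; its import path is not shown in this README.

# Derive and train the final architecture
derived_architecture = search_model.derive_architecture()
derived_model = DerivedPCDARTSModel(
    derived_architecture=derived_architecture, config=config
)

derived_trainer = pl.Trainer(
    max_epochs=config["training"]["derived_epochs"],
    accelerator="gpu" if config["training"].get("gpus") else "auto",
    devices=config["training"].get("gpus") or "auto",
    # Keep the checkpoint with the best validation accuracy
    callbacks=[ModelCheckpoint(monitor="val_acc", mode="max"), RichProgressBar()],
    logger=TensorBoardLogger(
        config["logging"]["log_dir"],
        name=f"{config['logging']['experiment_name']}_derived",
    ),
)

# Train the derived model. data_module.train_dataloader() appears to return a
# dict of loaders keyed by split; only its "train" loader is used here.
derived_trainer.fit(
    derived_model,
    train_dataloaders=data_module.train_dataloader()["train"],
    val_dataloaders=data_module.val_dataloader(),
)

# Test the derived model
derived_trainer.test(derived_model, datamodule=data_module)
```

## Install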
Using pip:
```sh
pip install git+https://github.com/jmaczan/darts-toolkit.git
```

Using uv:
```sh
uv pip install git+https://github.com/jmaczan/darts-toolkit.git
```

## Install (for development)
```sh
git clone https://github.com/jmaczan/darts-toolkit.git
cd darts-toolkit

# Install using uv (recommended)
uv pip install -e .

# Or install using pip
pip install -e .
```

## Prerequisites

This project uses [uv](https://docs.astral.sh/uv/getting-started/installation/) for package management.

It also uses [Ruff](https://docs.astral.sh/ruff/installation/) for formatting when you work on the project in VS Code: install the Ruff extension by Astral Software from the Extensions Marketplace and you're good to go.

To install the project's dependencies with uv, run:
```sh
uv sync
```

## Run
```sh
uv run python -m src.models.lightning_pc_darts
```

## Cite
If you use this software in your research, please use the following citation:
```bibtex
@software{Maczan_PCDARTS_2024,
author = {Maczan, Jędrzej Paweł},
title = {Differentiable Architecture Search Toolkit in PyTorch Lightning},
url = {https://github.com/jmaczan/darts-toolkit},
year = {2024},
publisher = {GitHub}
}
```

## License
GNU GPLv3
## Author
Jędrzej Maczan, 2024