Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/sea-snell/micro_config
an opinionated config framework for deep learning based on hierarchical python dataclasses
- Host: GitHub
- URL: https://github.com/sea-snell/micro_config
- Owner: Sea-Snell
- License: mit
- Created: 2022-05-24T00:33:24.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2022-07-10T01:06:45.000Z (over 2 years ago)
- Last Synced: 2024-10-11T04:21:37.642Z (about 1 month ago)
- Topics: config, dataclasses, deep-learning, python, pytorch
- Language: Python
- Homepage:
- Size: 4.44 MB
- Stars: 9
- Watchers: 4
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# micro_config.py
*an opinionated python dataclass config framework for deep learning*

I hope this approach to configurations makes your life easier; if it doesn't, please submit a pull request or an issue and I'll see what I can do. This config system is certainly still under development, so I'm open to any new ideas or suggestions.
# Installation
pip install:
``` shell
pip install micro-config
```

or install from source:
> Place `micro_config.py` at the root of your project.
# Repo Guide
The config framework is defined in `micro_config.py`.
The rest of the repo provides a demo of how one might actually use `micro_config.py` in a pytorch deep learning project. Specifically, I implement transformer language model training on wikitext in pytorch. For another example of how to use `micro_config.py` in jax (flax or haiku), see the [jax_v_pytorch repo](https://github.com/Sea-Snell/jax_v_pytorch/tree/main).
To run the demo:
1. navigate to the root directory
2. `pip install -r requirements.txt`
3. `export PYTHONPATH="$PWD"`
4. `cd scripts`
5. `python train_lm.py`

Optionally, you can pass commandline arguments to `train_lm.py` like:
``` shell
python train_lm.py epochs=1 bsize=16 model.transformer_config.hidden_dim=256
```

Overview of the demo project code:
* `scripts/train_lm.py` defines the training configuration and script execution.
* `base_configs.py` defines config schemas and defaults for all main config objects: `WikiDataConfig`, `TransformerConfig`, `LMModelConfig`, `AdamWConfig`
* `general_train_loop.py` defines the config schema and script for training models.
* `src/` defines all of the core demo project code.

# Quick Start / Walkthrough
*Most demo code in this section is adapted from the demo project provided in the repo.*
## Python dataclasses provide a more natural and flexible config definition interface than `.yaml` files.
* All config schemas should be defined as subclasses of `ConfigScript` or `ConfigScriptModel` and decorated with `@dataclass`.
* A `ConfigScript` first defines a parameter schema and, optionally, default config values.

For example, a simple dataset object configuration:
``` python
from dataclasses import dataclass, asdict
from micro_config import ConfigScript

# data config
@dataclass
class WikiDataConfig(ConfigScript):
    f_path: str='data/wikitext-2-raw/wiki.train.raw'
    max_len: int=256
```

## `ConfigScript`s load associated objects or functions.
* To do this, all `ConfigScript`s implement `unroll(self, metaconfig: MetaConfig)`.
* The `metaconfig` parameter is another dataclass which specifies configs for the config framework. Feel free to subclass `MetaConfig`.

For example, loading the dataset from the config:
``` python
from dataclasses import dataclass, asdict
from micro_config import ConfigScript, MetaConfig
from src.data import WikitextDataset
import torch
import os

# data config
@dataclass
class WikiDataConfig(ConfigScript):
    f_path: str='data/wikitext-2-raw/wiki.train.raw'
    max_len: int=256

    def unroll(self, metaconfig: MetaConfig):
        # metaconfig.convert_path converts paths relative to metaconfig.project_root into absolute paths
        return WikitextDataset(metaconfig.convert_path(self.f_path), self.max_len)

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True)
    data_config = WikiDataConfig(max_len=512)
    data = data_config.unroll(metaconfig)
```

## Configurations can be defined hierarchically.
* You can define `ConfigScript`s as parameters of other `ConfigScript`s.
* You can define lists or dictionaries of `ConfigScript`s as parameters of a `ConfigScript` by wrapping your list or dict in `ConfigScriptList` or `ConfigScriptDict` respectively.

For example, the LM model config below defines `ConfigScript`s for both a dataset and a `transformer_config` as parameters:
``` python
from dataclasses import dataclass, field
from micro_config import MetaConfig
from base_configs import ConfigScriptModel
from src.lm import LMModel
import os

# WikiDataConfig and TransformerConfig are defined as in the earlier examples
# (see base_configs.py in the repo)

# model config
@dataclass
class LMModelConfig(ConfigScriptModel):
    dataset: WikiDataConfig=field(default_factory=lambda: WikiDataConfig())
    transformer_config: TransformerConfig=field(default_factory=lambda: TransformerConfig(max_length=256))

    def unroll(self, metaconfig: MetaConfig):
        dataset = self.dataset.unroll(metaconfig)
        transformer_config = self.transformer_config.unroll(metaconfig)
        return LMModel(dataset, transformer_config, self.device)

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True)

    model_config = LMModelConfig(
        checkpoint_path=None,
        strict_load=True,
        device='cpu',
        dataset=WikiDataConfig(f_path='data/wikitext-2-raw/wiki.train.raw', max_len=256),
        transformer_config=TransformerConfig(
            max_length=256,
            heads=12,
            hidden_dim=768,
            attn_dim=64,
            intermediate_dim=3072,
            num_blocks=12,
            dropout=0.1,
        ),
    )
    model = model_config.unroll(metaconfig)
```

`ConfigScriptModel` (not provided with `micro_config` out of the box), as used above, is a subclass of `ConfigScript` which defines some default functionality for loading a pytorch module returned by `unroll` and placing it on a specified device. You can look inside `base_configs.py` (or in the [jax_v_pytorch repo](https://github.com/Sea-Snell/jax_v_pytorch/tree/main)) to see examples of how to implement special functionality like this.
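Since `ConfigScriptModel` comes from the demo's `base_configs.py` rather than from `micro_config` itself, here is a rough sketch of how such a base class *could* be written. The field names (`checkpoint_path`, `strict_load`, `device`) are taken from the configs above; the wrapping mechanism via `__init_subclass__` is an illustrative assumption, not necessarily how the repo implements it:

``` python
from dataclasses import dataclass
from typing import Optional

import torch

from micro_config import ConfigScript, MetaConfig


@dataclass
class ConfigScriptModel(ConfigScript):
    checkpoint_path: Optional[str] = None
    strict_load: bool = True
    device: str = 'cpu'

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # wrap the subclass's own unroll (if it defines one) so the nn.Module
        # it returns is moved to the configured device and, optionally,
        # initialized from a checkpoint
        if 'unroll' not in cls.__dict__:
            return
        inner_unroll = cls.__dict__['unroll']

        def unroll(self, metaconfig: MetaConfig) -> torch.nn.Module:
            model = inner_unroll(self, metaconfig).to(self.device)
            if self.checkpoint_path is not None:
                state = torch.load(metaconfig.convert_path(self.checkpoint_path),
                                   map_location=self.device)
                model.load_state_dict(state, strict=self.strict_load)
            return model

        cls.unroll = unroll
```

With a base like this, a subclass such as `LMModelConfig` only has to build the bare module in its `unroll`; device placement and checkpoint loading come along automatically.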
## Configs and scripts are unified: a config is to a script as a script is to a config.
* `unroll(self, metaconfig: MetaConfig)` can be used not only to load objects, but also to define script logic.

For example, let's define a simple configurable training loop:
``` python
from dataclasses import dataclass, asdict

from torch.utils.data import DataLoader
from tqdm import tqdm

from src.utils import combine_logs
from micro_config import ConfigScript, MetaConfig
from base_configs import ConfigScriptModel

@dataclass
class TrainLoop(ConfigScript):
    train_dataset: ConfigScript
    eval_dataset: ConfigScript
    model: ConfigScriptModel
    optim: ConfigScript
    epochs: int=10
    bsize: int=32

    def unroll(self, metaconfig: MetaConfig):
        print('using config:', asdict(self))
        # device comes from the metaconfig here (e.g. a MetaConfig subclass with a `device` field)
        device = metaconfig.device
        train_dataset = self.train_dataset.unroll(metaconfig)
        eval_dataset = self.eval_dataset.unroll(metaconfig)
        model = self.model.unroll(metaconfig)
        model.train()
        train_dataloader = DataLoader(train_dataset, batch_size=self.bsize)
        eval_dataloader = DataLoader(eval_dataset, batch_size=self.bsize)
        optim = self.optim.unroll(metaconfig)(model)
        step = 0
        for epoch in range(self.epochs):
            for x in tqdm(train_dataloader):
                loss, logs = model.get_loss(x.to(device))
                optim.zero_grad()
                loss.backward()
                optim.step()
                model.eval()
                val_x = next(iter(eval_dataloader))
                _, val_logs = model.get_loss(val_x.to(device))
                print({'train': combine_logs([logs]), 'val': combine_logs([val_logs]), 'step': (step+1)})
                model.train()
                step += 1
        return model
```

## Objects returned by `unroll(self, metaconfig: MetaConfig)` respect the reference structure of the config hierarchy.
* If the same config object is referenced multiple times in a config hierarchy, the object's `unroll(self, metaconfig: MetaConfig)` method will only be called once and its output cached; subsequent calls return the cached output. If you don't want this caching behavior, you can subclass `ConfigScriptNoCache` instead.
For example, `train_dataset` is referenced twice in `train_config_script`:
``` python
import torch
import os

from micro_config import MetaConfig

# WikiDataConfig, LMModelConfig, TransformerConfig, AdamWConfig, and TrainLoop
# are defined as in the previous examples

train_dataset = WikiDataConfig(f_path='data/wikitext-2-raw/wiki.train.raw', max_len=256)
eval_dataset = WikiDataConfig(f_path='data/wikitext-2-raw/wiki.valid.raw', max_len=256)

model = LMModelConfig(
    checkpoint_path=None,
    strict_load=True,
    device='cpu',
    dataset=train_dataset,
    transformer_config=TransformerConfig(
        max_length=256,
        heads=12,
        hidden_dim=768,
        attn_dim=64,
        intermediate_dim=3072,
        num_blocks=12,
        dropout=0.1,
    ),
)

train_config_script = TrainLoop(
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    model=model,
    optim=AdamWConfig(lr=1e-4, weight_decay=0.01),
    epochs=10,
    bsize=16,
)

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True)
    # run the script
    train_config_script.unroll(metaconfig)
```

The dataset object configured by `train_dataset` will only be loaded once in the above hierarchy, even though both `LMModelConfig` and `TrainLoop` take it as input.
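For instance, a config whose product should be rebuilt on every reference can opt out of caching by subclassing `ConfigScriptNoCache`. A minimal sketch; the `RNGConfig` below is a hypothetical example, not part of the demo, and only the `ConfigScriptNoCache` base class comes from `micro_config`:

``` python
from dataclasses import dataclass

import torch

from micro_config import ConfigScriptNoCache, MetaConfig


# hypothetical config: every consumer should get its own generator,
# so we opt out of unroll caching by subclassing ConfigScriptNoCache
@dataclass
class RNGConfig(ConfigScriptNoCache):
    seed: int = 0

    def unroll(self, metaconfig: MetaConfig) -> torch.Generator:
        # re-executed on every reference in the hierarchy, never cached
        return torch.Generator().manual_seed(self.seed)
```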
## A method for parsing commandline args is provided.
* `parse_args()` parses the command line arguments into a dictionary
* `deep_replace(config, **kwargs)` implements a nested version of the standard `dataclasses.replace` function

``` python
from micro_config import parse_args, deep_replace, MetaConfig
import os

# train_config_script as defined in the previous example

if __name__ == "__main__":
    metaconfig = MetaConfig(project_root=os.path.dirname(__file__),
                            verbose=True)
    train_config_script = deep_replace(train_config_script, **parse_args())
    # run the script
    train_config_script.unroll(metaconfig)
```

To edit any arguments in the hierarchy through the commandline, call the script like so:
``` shell
python train_lm.py epochs=1 bsize=16 model.transformer_config.hidden_dim=256
```
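
Conceptually, each dotted key expands into nested `dataclasses.replace` calls from the leaf config back up to the root. The toy sketch below uses stand-in dataclasses (not the demo's real configs) to illustrate roughly what `deep_replace` does with `model.transformer_config.hidden_dim=256`; the actual argument parsing and traversal inside `micro_config` may differ in detail:

``` python
from dataclasses import dataclass, field, replace


# toy stand-ins for the demo configs, just to illustrate nested replacement
@dataclass
class TransformerConfig:
    hidden_dim: int = 768

@dataclass
class LMModelConfig:
    transformer_config: TransformerConfig = field(default_factory=TransformerConfig)

@dataclass
class TrainLoop:
    model: LMModelConfig = field(default_factory=LMModelConfig)
    epochs: int = 10

train = TrainLoop()

# `model.transformer_config.hidden_dim=256` on the command line corresponds to
# replacing the leaf value and rebuilding each parent config on the way up:
train = replace(
    train,
    model=replace(
        train.model,
        transformer_config=replace(train.model.transformer_config, hidden_dim=256),
    ),
)

assert train.epochs == 10
assert train.model.transformer_config.hidden_dim == 256
```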