https://github.com/vchitect/optix
Memory Efficient Training Framework for Large Video Generation Model
- Host: GitHub
- URL: https://github.com/vchitect/optix
- Owner: Vchitect
- License: apache-2.0
- Created: 2024-01-26T09:21:55.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-04-12T07:55:08.000Z (about 2 years ago)
- Last Synced: 2024-04-12T15:14:01.874Z (about 2 years ago)
- Language: Python
- Size: 97.7 KB
- Stars: 17
- Watchers: 4
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Optix
Optix: Memory Efficient Training Framework for Large Video Generation Model
**Update**:
- Supports training [Latte](https://github.com/Vchitect/Latte) on 320 frames of 512×512 video on an A100 without sequence parallelism; see the [Latte training example](./example/train_latte_t2v.py)
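For a sense of scale, the arithmetic below computes the raw size of one such clip. The 3 color channels and 4-byte (fp32) values are assumptions, not figures stated in this README, and activations during training typically need far more memory than the input itself.

```py
# Raw size of one 320-frame, 512x512 clip.
# Assumptions (not from the README): 3 color channels, fp32 (4 bytes per value).
frames, height, width = 320, 512, 512
channels, bytes_per_value = 3, 4

num_values = frames * channels * height * width
size_gib = num_values * bytes_per_value / 2**30

print(f"{num_values:,} values, {size_gib:.4f} GiB")
# prints "251,658,240 values, 0.9375 GiB"
```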
**Features**:
- 4x the batch size when training with high-resolution images
- 1.2x the training throughput on average
- Optix remains effective in DiT model training!
Results of training stable-diffusion models:


*Baseline config*: TF32, gradient checkpointing
*Tested on*: A100 80GB; PyTorch 2.1.2+cu118
## Getting started
### Install Optix
`python setup.py develop`
### Install dependencies
Refer to [requirements](./requiresments.txt)
## API Usage
```py
import optix
# Optimize the model (fused ops, DDP, etc.), set up the optimizer, and create an EMA model
model, vae, opt, ema = optix.compile(model, vae, learning_rate=1e-5, weight_decay=1e-5, use_ema=True)
# Or, without creating an EMA model:
model, vae, opt, _ = optix.compile(model, vae, learning_rate=1e-5, weight_decay=1e-5)
# Use `sliced_vae` in place of the original vae.encode code:
# with torch.no_grad():
# x = vae.encode(x)
# if not args.use_video:
# x = x.latent_dist.sample().mul_(vae.config.scaling_factor)
model_input = optix.sliced_vae(vae, model_input, use_autocast=True, nhwc=True)
```
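This README does not describe how `optix.sliced_vae` works internally. One common way to cap peak memory during encoding is to process the batch in fixed-size slices and concatenate the results. The sketch below illustrates only that idea, with plain Python lists standing in for tensors; `encode_in_slices`, `slice_size`, and `toy_encode` are hypothetical names, not part of the Optix API.

```py
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def encode_in_slices(
    encode: Callable[[Sequence[T]], List[U]],
    batch: Sequence[T],
    slice_size: int,
) -> List[U]:
    """Apply `encode` to `batch` one slice at a time and join the outputs.

    Hypothetical helper: at most `slice_size` items reach `encode` per call,
    so its peak working set depends on the slice, not on the full batch.
    """
    if slice_size <= 0:
        raise ValueError("slice_size must be positive")
    outputs: List[U] = []
    for start in range(0, len(batch), slice_size):
        outputs.extend(encode(batch[start:start + slice_size]))
    return outputs


# Toy stand-in for a VAE encoder; records how many items it sees per call.
seen_sizes: List[int] = []


def toy_encode(items: Sequence[float]) -> List[float]:
    seen_sizes.append(len(items))
    return [x * 0.5 for x in items]


latents = encode_in_slices(toy_encode, [1.0, 2.0, 3.0, 4.0, 5.0], slice_size=2)
```

The output matches what a single full-batch call would return, while the encoder never sees more than `slice_size` items at once.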
Keyword arguments for `optix.compile` and their default values:
```py
{
    'use_ema': False,                # create an EMA model
    'compile_vae': True,             # [PERF] compile the VAE with torch.compile; recommended for torch>2.0
    'ddp': True,                     # automatically wrap the UNet in a DDP module
    'dp_group': None,                # DDP communication group; default is None
    'gradient_checkpointing': True,  # [PERF] ON by default; can be turned off for a speedup at small batch sizes
    'xformer': True,                 # [PERF] using xformers can give a small speedup
    'fusedln': True,                 # [PERF] using fusedln can give a speedup
    'compile_unet': False,           # [PERF] not stable yet, so OFF by default
    'vae_channels_last': True,       # [PERF] use channels_last memory format for the VAE
    'optim': 'adamw',                # optimizer type
    'learning_rate': 1e-5,           # optimizer parameter
    'weight_decay': 0,               # optimizer parameter
    'hybrid_zero': True,             # [PERF] for multi-node training, hybrid ZeRO can be faster
}
```
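For context on `use_ema`: an exponential moving average (EMA) keeps a smoothed copy of the model weights that is updated after each optimizer step. The sketch below shows the standard update rule on plain floats. The `ema_update` helper and the `decay` values are assumptions for illustration, and Optix's own EMA implementation may differ.

```py
from typing import List


def ema_update(
    ema_params: List[float],
    model_params: List[float],
    decay: float = 0.999,  # assumed value; not taken from Optix
) -> List[float]:
    """Return decay * ema + (1 - decay) * param for each parameter pair."""
    if len(ema_params) != len(model_params):
        raise ValueError("parameter lists must have the same length")
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, model_params)]


ema = [1.0, 2.0]
weights = [0.0, 4.0]
# One update step: each EMA value moves 10% of the way toward the current weight.
ema = ema_update(ema, weights, decay=0.9)
```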
These keyword arguments can be passed directly to `optix.compile`, for example:
```py
model, vae, opt, _ = optix.compile(model, vae, learning_rate=1e-5, weight_decay=1e-5,
                                   use_ema=False, compile_vae=False, optim='sgd',
                                   xformer=False)
```
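One plausible way for a function like `optix.compile` to combine caller overrides with the defaults listed earlier is a plain dictionary merge that rejects unknown keys. This is a sketch of that pattern, not Optix's actual code: `resolve_options` is a hypothetical helper, and `DEFAULTS` is copied from the defaults in this README.

```py
from typing import Any, Dict

# Defaults copied from the README's keyword-argument listing.
DEFAULTS: Dict[str, Any] = {
    "use_ema": False,
    "compile_vae": True,
    "ddp": True,
    "dp_group": None,
    "gradient_checkpointing": True,
    "xformer": True,
    "fusedln": True,
    "compile_unet": False,
    "vae_channels_last": True,
    "optim": "adamw",
    "learning_rate": 1e-5,
    "weight_decay": 0,
    "hybrid_zero": True,
}


def resolve_options(**overrides: Any) -> Dict[str, Any]:
    """Merge caller overrides into a copy of DEFAULTS (hypothetical helper)."""
    unknown = sorted(set(overrides) - set(DEFAULTS))
    if unknown:
        raise TypeError(f"unknown keyword argument(s): {', '.join(unknown)}")
    # Later keys win, so overrides replace defaults; DEFAULTS itself is untouched.
    return {**DEFAULTS, **overrides}


options = resolve_options(
    learning_rate=1e-5, weight_decay=1e-5,
    use_ema=False, compile_vae=False, optim="sgd", xformer=False,
)
```

Rejecting unknown keys up front means a misspelled option such as `xformers=False` fails loudly instead of being silently ignored.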
## Examples
Stable Diffusion: [train_sd_unet.py](./example/train_sd_unet.py)
DiT: [train_dit.py](./example/train_dit.py)