Ecosyste.ms: Awesome

An open API service indexing awesome lists of open source software.

Awesome Lists | Featured Topics | Projects

https://github.com/xavierjiezou/litmnist

Deep learning project template based on PyTorch Lightning and Hydra.
https://github.com/xavierjiezou/litmnist

config deep-leaning hydra image-classification mnist python pytorch pytorch-lightning template yaml

Last synced: 2 months ago
JSON representation

Deep learning project template based on PyTorch Lightning and Hydra.

Awesome Lists containing this project

README

        

[![logo](https://raw.githubusercontent.com/XavierJiezou/LitMNIST/main/images/logo.png)](https://pixelied.com/editor/design/6282f5970515730397249959)

# LitMNIST

基于 [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) + [Hydra](https://github.com/facebookresearch/hydra) 的深度学习项目模板。

*(以 [MNIST](http://yann.lecun.com/exdb/mnist/) 分类任务为例)*

点击 [Use this template](https://github.com/XavierJiezou/LitMNIST/generate) 即可使用该模板来初始化你的新仓库。



GitHub Workflow Release Status


GitHub Workflow Test Status


GitHub Workflow Lint Status



Codacy Badge


codecov


PyPI


PyPI - Downloads



GitHub stars


GitHub forks


GitHub issues


GitHub license





Python


PyTorch


Lightning


Config: hydra


Code style: black


观看演示

报告错误

功能需求


English

简体中文

喜欢这个项目吗?请考虑捐赠([微信](https://raw.githubusercontent.com/XavierJiezou/ys-dl/main/image/wechat.jpg) | [支付宝](https://raw.githubusercontent.com/XavierJiezou/ys-dl/main/image/alipay.jpg)),以帮助它改善!

## 目录

LitMNIST

- [演示](#演示)
- [安装](#安装)
- [运行](#运行)
- [结构](#结构)
- [用法](#用法)
- [基础](#基础)
- [进阶](#进阶)
- [提示](#提示)
- [更新](#更新)
- [证书](#证书)
- [参考](#参考)

## 演示

![demo](https://raw.githubusercontent.com/XavierJiezou/LitMNIST/main/images/demo.jpg)

## 安装

> 开始之前,你必须熟练使用 [PyTorch Lightning](https://www.pytorchlightning.ai/),并对 [Hydra](https://hydra.cc/) 有一定的了解。

1. 克隆仓库到本地

```bash
git clone https://github.com/XavierJiezou/LitMNIST.git
cd LitMNIST
```

2. 创建并激活 conda 虚拟环境

```bash
conda create -n myenv python=3.8
conda activate myenv
```

3. 安装项目依赖包(如需安装 GPU 版 PyTorch,请参考[官网安装教程](https://pytorch.org/get-started/))

```bash
pip install -r requirements.txt
```

## 运行

### CPU

```bash
python train.py
```

### GPU

```bash
python train.py \
trainer.gpus=4 \
+trainer.strategy=ddp_find_unused_parameters_false \
datamodule.num_workers=16 \
datamodule.pin_memory=True \
datamodule.persistent_workers=True
```

## 结构

项目的主要目录结构如下:

```bash
├── configs # 存放 Hydra 配置文件
│ ├── callbacks # Callbacks 配置(例如 EarlyStopping、ModelCheckpoint 等)
│ ├── datamodule # Datamodule 配置(例如 batch_size、num_workers 等)
│ ├── debug # 调试配置
│ ├── experiment # 实验配置
│ ├── hparams_search # 超参数搜索配置
│ ├── local # 本地配置(暂时可以忽略)
│ ├── log_dir # 日志存放目录配置
│ ├── logger # 日志配置
│ ├── model # 模型配置
│ ├── trainer # Trainer 配置
│ │
│ ├── test.yaml # 测试的主要配置
│ └── train.yaml # 训练的主要配置

├── data # 存放项目数据

├── logs # 存放项目日志(Hydra 日志 和 PyTorch Lightning loggers 生成的日志)

├── src # 项目源代码
│ ├── datamodules # LightningDataModule
│ ├── models # 存放基于原生 PyTorch 框架编写的模型
│ ├── litmodules # LightningModule
│ ├── utils # 存放一些实用的脚本(例如数据预处理的脚本)
│ │
│ ├── testing_pipeline.py # 测试流水线(实例化对象)
│ └── training_pipeline.py # 训练流水线(实例化对象)

├── tests # 单元测试(可选)

├── test.py # 开始测试(加载配置文件)
├── train.py # 开始训练(加载配置文件)

├── .env # 存储私有环境变量(例如 wandb 的 API_KEY)【注意:该文件不受版本控制】
├── .gitignore # 设置版本控制需要排除的文件或目录(例如 .env 文件)
├── requirements.txt # 项目依赖环境(pip install -r requirements.txt)
└── README.md # 项目概述文档
```

## 用法

本仓库是一个基于 [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) + [Hydra](https://github.com/facebookresearch/hydra) 的深度学习项目模板。因此你仅需要套用该模板,并作出如下修改:

1. 编写你自己的 PyTorch `nn.Module` 模型(参见 [src/models/simple_densenet.py](/src/models/simple_densenet.py))
2. 编写你自己的 PyTorch Lightning `LightningModule`(参见 [src/litmodules/mnist_litmodule.py](/src/litmodules/mnist_litmodule.py))
3. 编写你自己的 PyTorch Lightning `LightningDataModule`(参见 [src/datamodules/mnist_datamodule.py](/src/datamodules/mnist_datamodule.py))
4. 编写你自己的实验配置文件(参见 [configs/experiment/example.yaml](/configs/experiment/example.yaml))
5. 使用选定的实验配置运行训练代码:`python train.py experiment=`

### 基础

[train.py](/train.py) 集成了模型**训练**、**验证**及**测试**的一整套工作流,安装好环境后,运行即可:

```bash
python train.py
```

[test.py](/test.py) 仅包含**测试**步骤,允许你单独加载预训练模型进行测试(但要指定模型路径):

```bash
python test.py ckpt_path=checkpoints/last.ckpt
```

### 进阶

- 从命令行覆盖任何配置参数

> Hydra 允许你轻松覆盖配置文件中定义的任何参数。

`train.py` 默认从 [configs/train.yaml](configs/train.yaml) 中获取参数。因此,你可以先修改 `yaml` 配置文件中的参数,然后再运行。

或者,你也可以在命令行中直接指定参数。命令行中参数的优先级要大于 `yaml` 配置文件中参数的优先级。

```bash
python train.py trainer.max_epochs=3
```

对于某些不太重要的参数,它们没有在 `yaml` 配置文件中定义,因此你在命令行中指定的时候必须添加 `+`:

```bash
python train.py +trainer.precision=16
```

- 在 CPU、GPU、多 GPU 和 TPU 上训练

> PyTorch Lightning 使得在不同硬件上训练模型变得容易。

在 CPU 上训练

```bash
python train.py trainer.gpus=0
```

在 GPU 上训练

```bash
python train.py trainer.gpus=1
```

在 TPU 上训练

```bash
python train.py +trainer.tpu_cores=8
```

基于 DDP(Distributed Data Parallel,分布式数据并行)的训练【4 个 GPU】

```bash
python train.py trainer.gpus=4 +trainer.strategy=ddp
```

基于 DDP(Distributed Data Parallel,分布式数据并行)的训练【8 个 GPU,两个节点】

```bash
python train.py trainer.gpus=4 +trainer.num_nodes=2 +trainer.strategy=ddp
```

- 混合精度训练

> PyTorch Lightning 允许你使用半精度或混合精度以减少训练期间的内存占用。(在 GPU 上能够实现 3 倍的加速效果,但可能损失精度)

```bash
python train.py trainer.gpus=1 +trainer.precision=16
```

- 使用 PyTorch Lightning 中的日志记录器来记录训练日志

> PyTorch Lightning 集成了多种主流日志记录框架,包括 TensorBoard 和 Weights&Biases等。

这里以 [wandb](https://wandb.ai/) 为例,展示如何使用:

1. 安装 `wandb`

```bash
pip install wandb
```

2. 转到 [wandb.ai/authorize](https://wandb.ai/authorize) 获取 **API key**

3. 执行 `login` 命令(需要用到上一步获取的 **API key**)

```bash
wandb login
```

4. 在 [configs/logger/](/configs/logger/) 目录下新建一个名为 `wandb.yaml` 的文件,并写入以下内容

```yaml
wandb:
_target_: pytorch_lightning.loggers.wandb.WandbLogger
project: "mnist"
```

5. 执行训练代码的时候指定 `logger` 为 `wandb`

```bash
python train.py logger=wandb
```

- 根据自定义实验配置来训练模型

> 配置文件见 [configs/experiment/](/configs/experiment/)

```bash
python train.py experiment=example
```

- 带回调函数的训练

> 配置文件见 [configs/callbacks/](/configs/callbacks/)

```python
python train.py callbacks=default
```

- 使用 Pytorch Lightning 中的训练策略

> 点击[这里](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html)了解 Pytorch Lightning 中的各种训练策略

梯度裁剪来避免梯度爆炸

```bash
python train.py +trainer.gradient_clip_val=0.5
```

随机加权平均可以使您的模型更好地泛化

```bash
python train.py +trainer.stochastic_weight_avg=true
```

梯度累积

```bash
python train.py +trainer.accumulate_grad_batches=10
```

- 轻松调试

> 配置文件见 [configs/debug/](/configs/debug/)

默认调试模式(运行 1 个 epoch)

```bash
python train.py debug=default
```

仅对 test epoch 进行调试

```bash
python train.py debug=test_only
```

执行一次 train,val 和 test 步骤(仅使用 1 个 batch)

```bash
python train.py +trainer.fast_dev_run=true
```

训练完成后打印各个阶段的执行时间(用于快速发现训练瓶颈)

```bash
python train.py +trainer.profiler="simple"
```

- 断点续训

```bash
python train.py trainer.resume_from_checkpoint="/path/to/name.ckpt"
```

- 一次执行多个实验

例如,下方代码将按顺序运行所有参数组合(共 6 个)的实验。

```bash
python train.py -m datamodule.batch_size=32,64,128 litmodule.lr=0.001,0.0005
```

此外,你也可以执行 [/configs/experiment/](/configs/experiment/) 目录下的的所有实验

```bash
python train.py -m 'experiment=glob(*)'
```

- 使用 Optuna 进行超参数搜索

> [Optuna Sweeper plugin | Hydra](https://hydra.cc/docs/next/plugins/optuna_sweeper/)

1. 安装 **hydra-optuna-sweeper** 插件

```bash
pip install hydra-optuna-sweeper
```

2. 修改 [configs/hparams_search/](/configs/hparams_search/) 目录下的配置文件

3. 执行训练代码的时候指定 `hparams_search`

```bash
python train.py -m hparams_search=mnist_optuna
```

- 使用 Tab 键智能提示可选配置参数

> [Tab completion | Hydra](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion/)

```bash
$ eval "$(python train.py -sc install=bash)" # 安装
$ python train.py logger= # 按下 Tab 键后会智能提示有哪些可选参数
logger=comet logger=csv logger=many_loggers logger=mlflow logger=neptune logger=tensorboard logger=wandb
```

## 提示

- 在 `.env` 文件中设置私有环境变量

1. 例如,你可以将 [comet](https://www.comet.ml/) 的 **API Key** 添加到 `.env` 文件中

```bash
COMET_API_KEY="xxx"
```

2. 并在配置文件 [configs/logger/comet.yaml](/configs/logger/comet.yaml) 中进行调用

```yaml
comet:
_target_: pytorch_lightning.loggers.comet.CometLogger
api_key: ${oc.env:COMET_API_KEY}
```

3. 在训练的时候指定 `logger` 参数为 `comet`

```bash
python train.py logger=comet
```

注意:`.env` 文件不应受版本控制,因此我们已将其添加到 `.gitignore` 文件中了。

- 使用 [torchmetrics](https://github.com/PytorchLightning/metrics) 计算指标

推荐使用 PytorchLightning 官方提供的 [torchmetrics](https://github.com/PytorchLightning/metrics) 库来计算指标(像准确率,F1 score 和混淆矩阵等)。这对于多 GPU 训练尤为重要!并且,推荐对每个步骤使用不同的指标实例,以确保所有 GPU 进程都有正确的累积值。下面给出了一个简单示例。

```python
from torchmetrics.classification.accuracy import Accuracy

class LitModel(LightningModule):
def __init__(self)
self.train_acc = Accuracy()
self.val_acc = Accuracy()

def training_step(self, batch, batch_idx):
...
acc = self.train_acc(predictions, targets)
self.log("train/acc", acc)
...

def validation_step(self, batch, batch_idx):
...
acc = self.val_acc(predictions, targets)
self.log("val/acc", acc)
...
```

- 可以使用 [DVC](https://github.com/iterative/dvc) 对数据和模型这些大文件进行版本控制

```bash
dvc init
dvc add data/MNIST
git add data/MNIST.dvc data/.gitignore
git commit -m "Add raw data"
```

## 更新

见 [CHANGELOG.md](/CHANGELOG.md)

## 证书

[MIT License](/License)

## 参考

此模板引用了以下仓库并进行了一些细微的修改。

[![Readme Card](https://github-readme-stats.vercel.app/api/pin/?username=ashleve&repo=lightning-hydra-template)](https://github.com/ashleve/lightning-hydra-template)