https://github.com/akira4o4/xtrainer
PyTorch CV Trainer
https://github.com/akira4o4/xtrainer
classification computer-vision framework inference linux macos mlflow multitask python pytorch segmentation trainer windows
Last synced: 2 months ago
JSON representation
PyTorch CV Trainer
- Host: GitHub
- URL: https://github.com/akira4o4/xtrainer
- Owner: akira4O4
- Created: 2024-07-09T10:10:29.000Z (over 1 year ago)
- Default Branch: v3.0
- Last Pushed: 2024-11-08T10:00:01.000Z (over 1 year ago)
- Last Synced: 2025-07-08T23:07:22.525Z (9 months ago)
- Topics: classification, computer-vision, framework, inference, linux, macos, mlflow, multitask, python, pytorch, segmentation, trainer, windows
- Language: Python
- Homepage:
- Size: 467 KB
- Stars: 1
- Watchers: 1
- Forks: 1
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
#
xTrainer
Version: 1.0.0-dev
Author: Li Linfeng
**
Language: [English](README_EN.md)**
---
## 简介
这是一个基于PyTorch的**分类**,**分割**,**多任务**的训练和推理框架,支持PyTorch原生模型和自定义模型,设计多种优化方案,模块化的设计了整个训练推理框架
---
## 特点
- 支持`Linux`,`Windows`,`Macos`
- 支持`分类`任务,`分割`任务,`多任务`
- 支持`训练`和`推理`
- 支持训练数据`预加载`
- 支持`自定义模型`
- 支持`MLflow`参数跟踪
- 支持模块化添加Loss
- 分类任务中数据`平衡采样`
- 可视化`混淆矩阵`
- 自动平衡多任务训练中数据量
---
## 安装
```bash
cd
git clone https://github.com/akira4O4/xTrainer.git
cd xTrainer
pip install -r requirements.txt
```
---
## 配置文件
### 默认配置文件路径
`xTrainer/configs/default.yaml`
### 配置参数解析
| 参数名字 | 默认值 | 数据类型 | 描述 |
|-------------------------|-------------|-------------|------------------------------------------------------------------------------------------------|
| `mode` | `train` | `str` | 运行模式
训练:`train`
测试:`test` |
| `task` | | `str` | 任务类型
分类:classification
分割:segmentation
多任务:multitask |
| `project` | | `str` | 项目路径 |
| `experiment` | `exp` | `str` | 每次实验名称 |
| `seed` | `0` | `int` | 随机种子 |
| `topk` | `[1,2]` | `List[int]` | 分类topk范围 |
| `device` | `0` | `int` | 模型运行设备 |
| `epochs ` | `100` | `int` | 最大轮训次数 |
| `workers` | ` 0` | `int` | dataloader多进程数 |
| `not_val` | `False` | `bool` | 是否进行验证,True:只训练不验证 |
| `model` | | `str` | 模型名称 |
| `pretrained` | ` True` | `bool` | 是否加载预训练模型,模型来自Pytorch Hub |
| `weight` | | `str` | 预训练模型路径,来自本地模型 |
| `wh` | `[256,256]` | `List[int]` | 输入图像宽高 |
| `amp` | `True` | `bool` | 是否使用自动混合精度进行训练 |
| `cache` | `False` | `bool` | 是否使用数据预加载
开启后程序会提前**全部**加载所有数据 |
| `deterministic` | `True` | `bool` | 用于启用确定性模式 |
| `save_period` | `5` | `int` | 每训练x次就进行一次模型保存 |
| `classification.batch` | | `int` | 分类任务的batch数 |
| `classifiction.classes` | | `int` | 分类任务的类别数 |
| `classification.train ` | | `str` | 分类任务的训练数据路径 |
| `classification.val` | | `str` | 分类任务的验证数据路径 |
| `segmentation.batch` | | `int` | 分割任务的batch数 |
| `segmentation.classes` | | `int` | 分割任务的类别数(**分割类别数需要包含背景**) |
| `segmentation.train ` | | `str` | 分割任务的训练数据路径 |
| `segmentation.val ` | | `str` | 分割任务的验证数据路径 |
| `optimizer` | `auto` | `str` | 优化器名称
auto="AdamW"
支持优化器=["Adam", "Adamax", "AdamW", "NAdam", "RAdam","RMSProp","SGD"] |
| `cos_lr ` | `False` | `bool` | 是否使用余弦退火学习率 |
| `lr0` | `0.001` | `float` | 初始学习率 |
| `lrf` | `0.01` | `float` | 最低学习率下降比例,最低学习率=`lr0*lrf ` |
| `momentum` | `0.937` | `float` | 优化器冲量 |
| `alpha` | `auto` | `List[int]` | Focal Loss参数 |
| `gamma` | `2.0` | `float` | Focal Loss参数 |
| `smooth` | `1.0` | `float` | 分割loss中的稳定参数
极小目标:1e-6
正常目标:1.0 |
| `loss_sum_weights` | `[1,1]` | `List[int]` | 多任务中,分类loss于分割loss加权比例 |
| `seg_loss_sum_weights` | `[1,1,1]` | `List[int]` | 多个分割loss中的加权比例 |
| `source` | | `str` | 测试数据路径 |
| `test_weight` | | `str` | 测试权重路径 |
| `cls_thr` | | `List[int]` | 分类任务阈值 |
| `seg_thr` | | `List[int]` | 分割任务阈值(**不需要包含背景**) |
| `mlflow_url` | `localhost` | `str` | mlflow URI |
| `mlflow_port` | `5000 ` | `int` | mlflow端口 |
---
## 如何使用自定义模型
---
## 训练数据格式
### 分类任务
### 分割任务
### 多任务
---
## 训练
---
## 预测
---
## 导出ONNX、TorchScript