{"id":23102660,"url":"https://github.com/dataxujing/detr_transformer","last_synced_at":"2025-08-16T14:33:01.132Z","repository":{"id":112305121,"uuid":"289813332","full_name":"DataXujing/detr_transformer","owner":"DataXujing","description":"transformer used in object detection [training DETR on your own dataset]","archived":false,"fork":false,"pushed_at":"2020-08-24T03:08:29.000Z","size":699,"stargazers_count":48,"open_issues_count":5,"forks_count":10,"subscribers_count":1,"default_branch":"master","last_synced_at":"2025-06-08T02:07:33.357Z","etag":null,"topics":["object-detection","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/DataXujing.png","metadata":{"files":{"readme":"README.md","changelog":"change.py","contributing":".github/CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":".github/CODE_OF_CONDUCT.md","threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2020-08-24T02:57:29.000Z","updated_at":"2025-01-21T08:53:57.000Z","dependencies_parsed_at":"2023-05-12T14:00:32.079Z","dependency_job_id":null,"html_url":"https://github.com/DataXujing/detr_transformer","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/DataXujing/detr_transformer","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DataXujing%2Fdetr_transformer","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DataXujing%2Fdetr_transformer/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DataXujing%2Fdetr_transformer/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/G
itHub/repositories/DataXujing%2Fdetr_transformer/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/DataXujing","download_url":"https://codeload.github.com/DataXujing/detr_transformer/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DataXujing%2Fdetr_transformer/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":270723411,"owners_count":24634375,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","status":"online","status_checked_at":"2025-08-16T02:00:11.002Z","response_time":91,"last_error":null,"robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":true,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["object-detection","transformer"],"created_at":"2024-12-17T00:00:17.092Z","updated_at":"2025-08-16T14:33:01.095Z","avatar_url":"https://github.com/DataXujing.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"## Training [DETR](https://github.com/facebookresearch/detr) on Your Own Dataset\n\n**徐静**\n\n![](pic/DETR.png)\n\n### 0. Requirements\n\n```shell\n# python \u003e= 3.5\ncython\nsubmitit\ntorch\u003e=1.4.0\ntorchvision\u003e=0.5.0\nscipy\nonnx\nonnxruntime\n# pip3 install torch==1.4.0+cu92 torchvision==0.5.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html\n```\n\n### 1. Build the dataset\n\nThe training data follows the COCO format and is stored in the `myData` folder under the project directory:\n\n```shell\n./myData\n└─coco                # dataset name\n    ├─annotations     # COCO-format annotation JSON files\n      ├─instances_train.json\n      
├─instances_val.json\n    ├─train           # training images\n      ├─xxx.jpg\n    ├─val             # validation images\n      └─xxxx.jpg\n```\n\n### 2. Modify the code\n\n**1. Adapt the COCO-pretrained model**\n\nRun `./change.py` to change the number of output class nodes in the COCO-pretrained network to match your dataset:\n\n```shell\npython3 change.py\n# writes detr_r50_{class_num}.pth to the project folder\n```\n\n**2. Modify the `build` function in `./datasets/coco.py`**\n\n```python\ndef build(image_set, args):\n    root = Path(args.coco_path)\n    assert root.exists(), 'provided COCO path {} does not exist'.format(root)\n    mode = 'instances'\n    PATHS = {\n        \"train\": (root / \"train\", root / \"annotations\" / '{}_train.json'.format(mode)),\n        \"val\": (root / \"val\", root / \"annotations\" / '{}_val.json'.format(mode)),\n    }\n\n    img_folder, ann_file = PATHS[image_set]\n    dataset = CocoDetection(str(img_folder), str(ann_file), transforms=make_coco_transforms(image_set), return_masks=args.masks)  # \u003c----------- on Python 3.5 the paths must be passed as str(img_folder), str(ann_file)\n    return dataset\n```\n\n**3. Modify the `build` function in `./models/detr.py`**\n\n```python\ndef build(args):\n    num_classes = 3+1   # \u003c--------------- your number of classes (excluding background) + 1!!! DETR expects num_classes = max category id + 1\n    if args.dataset_file == \"coco_panoptic\":  # panoptic segmentation\n        num_classes = 3+1   # \u003c-------------\n    device = torch.device(args.device)\n\n    backbone = build_backbone(args)\n\n    transformer = build_transformer(args)\n\n    model = DETR(\n        backbone,\n        transformer,\n        num_classes=num_classes,\n        num_queries=args.num_queries,\n        aux_loss=args.aux_loss,\n    )\n    if args.masks:\n        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))\n    matcher = build_matcher(args)\n    weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}\n    weight_dict['loss_giou'] = args.giou_loss_coef\n    if args.masks:\n        weight_dict[\"loss_mask\"] = args.mask_loss_coef\n        weight_dict[\"loss_dice\"] = args.dice_loss_coef\n    # TODO this is a hack\n    if args.aux_loss:\n        
aux_weight_dict = {}\n        for i in range(args.dec_layers - 1):\n            aux_weight_dict.update({k + '_{}'.format(i): v for k, v in weight_dict.items()})\n        weight_dict.update(aux_weight_dict)\n\n    losses = ['labels', 'boxes', 'cardinality']\n    if args.masks:\n        losses += [\"masks\"]\n    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,\n                             eos_coef=args.eos_coef, losses=losses)\n    criterion.to(device)\n    postprocessors = {'bbox': PostProcess()}\n    if args.masks:\n        postprocessors['segm'] = PostProcessSegm()\n        if args.dataset_file == \"coco_panoptic\":\n            is_thing_map = {i: i \u003c= 90 for i in range(201)}\n            postprocessors[\"panoptic\"] = PostProcessPanoptic(is_thing_map, threshold=0.85)\n\n    return model, criterion, postprocessors\n```\n\n**4. Modify `main.py`**\n\n```python\ndef get_args_parser():\n    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)\n    parser.add_argument('--lr', default=1e-4, type=float)   # \u003c----------lr\n    parser.add_argument('--lr_backbone', default=1e-5, type=float)  # \u003c-------lr_backbone\n    parser.add_argument('--batch_size', default=2, type=int)       # \u003c---------batch size\n    parser.add_argument('--weight_decay', default=1e-4, type=float)\n    parser.add_argument('--epochs', default=300, type=int)         # \u003c---------epochs\n    parser.add_argument('--lr_drop', default=200, type=int)\n    parser.add_argument('--clip_max_norm', default=0.1, type=float,\n                        help='gradient clipping max norm')\n\n    # Model parameters\n    parser.add_argument('--frozen_weights', type=str, default=None,\n                        help=\"Path to the pretrained model. 
If set, only the mask head will be trained\")\n    # * Backbone\n    parser.add_argument('--backbone', default='resnet50', type=str,\n                        help=\"Name of the convolutional backbone to use\")\n    parser.add_argument('--dilation', action='store_true',\n                        help=\"If true, we replace stride with dilation in the last convolutional block (DC5)\")\n    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),\n                        help=\"Type of positional embedding to use on top of the image features\")\n\n    # * Transformer\n    parser.add_argument('--enc_layers', default=6, type=int,\n                        help=\"Number of encoding layers in the transformer\")\n    parser.add_argument('--dec_layers', default=6, type=int,\n                        help=\"Number of decoding layers in the transformer\")\n    parser.add_argument('--dim_feedforward', default=2048, type=int,\n                        help=\"Intermediate size of the feedforward layers in the transformer blocks\")\n    parser.add_argument('--hidden_dim', default=256, type=int,\n                        help=\"Size of the embeddings (dimension of the transformer)\")\n    parser.add_argument('--dropout', default=0.1, type=float,\n                        help=\"Dropout applied in the transformer\")\n    parser.add_argument('--nheads', default=8, type=int,\n                        help=\"Number of attention heads inside the transformer's attentions\")\n    parser.add_argument('--num_queries', default=100, type=int,\n                        help=\"Number of query slots\")     # \u003c----------- maximum number of boxes predicted per image\n    parser.add_argument('--pre_norm', action='store_true')\n\n    # * Segmentation\n    parser.add_argument('--masks', action='store_true',\n                        help=\"Train segmentation head if the flag is provided\")\n\n    # Loss\n    parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',\n       
                 help=\"Disables auxiliary decoding losses (loss at each layer)\")\n    # * Matcher\n    parser.add_argument('--set_cost_class', default=1, type=float,\n                        help=\"Class coefficient in the matching cost\")\n    parser.add_argument('--set_cost_bbox', default=5, type=float,\n                        help=\"L1 box coefficient in the matching cost\")\n    parser.add_argument('--set_cost_giou', default=2, type=float,\n                        help=\"giou box coefficient in the matching cost\")\n    # * Loss coefficients\n    parser.add_argument('--mask_loss_coef', default=1, type=float)\n    parser.add_argument('--dice_loss_coef', default=1, type=float)\n    parser.add_argument('--bbox_loss_coef', default=5, type=float)\n    parser.add_argument('--giou_loss_coef', default=2, type=float)\n    parser.add_argument('--eos_coef', default=0.1, type=float,\n                        help=\"Relative classification weight of the no-object class\")\n\n    # dataset parameters\n    parser.add_argument('--dataset_file', default='coco')\n    parser.add_argument('--coco_path', default='myData/coco', type=str)   # \u003c----------- change this default to your dataset path\n    parser.add_argument('--coco_panoptic_path', type=str)\n    parser.add_argument('--remove_difficult', action='store_true')\n\n    parser.add_argument('--output_dir', default='',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=42, type=int)\n    parser.add_argument('--resume', default='', help='resume from checkpoint')\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--eval', action='store_true')\n    parser.add_argument('--num_workers', default=2, type=int)\n\n    # distributed training parameters\n    
parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')\n    return parser\n```\n\n### 3. Training\n\n```shell\npython3 main.py --dataset_file \"coco\" --coco_path \"myData/coco\" --epochs 500 --lr=1e-4 --batch_size=8 --num_workers=4 --output_dir=\"outputs\" --resume=\"detr_r50_4.pth\"\n```\n\n### 4. Testing\n\n```shell\n# inference on a single image\npython3 inference_img.py\n```\n\n![](pic/3e325ba8174811ea919400e04c510bc1.jpg)\n\n```shell\n# inference on a video\npython3 inference_video.py\n```\n\n```shell\n# plot the training log\npython3 ./util/plot_utils.py\n```\n\n![](pic/log.png)\n\n### 5. ONNX\n\n**TODO**","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdataxujing%2Fdetr_transformer","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdataxujing%2Fdetr_transformer","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdataxujing%2Fdetr_transformer/lists"}