{"id":15038940,"url":"https://github.com/vainf/torch-pruning","last_synced_at":"2025-05-12T15:33:05.536Z","repository":{"id":37663431,"uuid":"228203350","full_name":"VainF/Torch-Pruning","owner":"VainF","description":"[CVPR 2023] DepGraph: Towards Any Structural Pruning","archived":false,"fork":false,"pushed_at":"2025-04-12T10:30:01.000Z","size":10489,"stargazers_count":2982,"open_issues_count":304,"forks_count":346,"subscribers_count":34,"default_branch":"master","last_synced_at":"2025-04-23T17:20:02.805Z","etag":null,"topics":["channel-pruning","cvpr2023","depgraph","efficient-deep-learning","model-compression","network-pruning","pruning","structural-pruning","structured-pruning"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2301.12900","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/VainF.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":"AUTHORS","dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2019-12-15T15:07:24.000Z","updated_at":"2025-04-23T07:44:18.000Z","dependencies_parsed_at":"2023-02-16T20:32:04.342Z","dependency_job_id":"7b0127fb-4ec9-4db7-a71d-6cf5f9c49e73","html_url":"https://github.com/VainF/Torch-Pruning","commit_stats":{"total_commits":1293,"total_committers":21,"mean_commits":61.57142857142857,"dds":0.4338747099767981,"last_synced_commit":"224d7f80d764f31f13ba3ce9c166f111e5c307bc"},"previous_names":[],"tags_count":35,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/VainF%2FTorch-Pruning","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/VainF%2FTorch-Pruning/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/VainF%2FTorch-Pruning/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/VainF%2FTorch-Pruning/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/VainF","download_url":"https://codeload.github.com/VainF/Torch-Pruning/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253766117,"owners_count":21960847,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["channel-pruning","cvpr2023","depgraph","efficient-deep-learning","model-compression","network-pruning","pruning","structural-pruning","structured-pruning"],"created_at":"2024-09-24T20:40:52.071Z","updated_at":"2025-05-12T15:33:05.503Z","avatar_url":"https://github.com/VainF.png","language":"Python","readme":"\n\u003cdiv align=\"center\"\u003e\n\u003cimg src=\"https://user-images.githubusercontent.com/18592211/232830417-0b21a874-516e-4420-8984-4de414a35085.png\" width=\"400px\"\u003e\u003c/img\u003e\n\u003ch2\u003e\u003c/h2\u003e\n\u003ch3\u003eTowards Any 
<img src="assets/intro.png" width="50%">
</div>

<p align="center">
  <a href="https://github.com/VainF/Torch-Pruning/actions"><img src="https://img.shields.io/badge/tests-passing-9c27b0.svg" alt="Test Status"></a>
  <a href="https://pytorch.org/"><img src="https://img.shields.io/badge/PyTorch-1.x %20%7C%202.x-673ab7.svg" alt="Tested PyTorch Versions"></a>
  <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-4caf50.svg" alt="License"></a>
  <a href="https://pepy.tech/project/Torch-Pruning"><img src="https://static.pepy.tech/badge/Torch-Pruning?color=2196f3" alt="Downloads"></a>
  <a href="https://github.com/VainF/Torch-Pruning/releases/latest"><img src="https://img.shields.io/badge/Latest%20Version-1.5.2-3f51b5.svg" alt="Latest Version"></a>
  <a href="https://colab.research.google.com/drive/1TRvELQDNj9PwM-EERWbF3IQOyxZeDepp?usp=sharing">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
  <a href="https://arxiv.org/abs/2301.12900" target="_blank"><img src="https://img.shields.io/badge/arXiv-2301.12900-009688.svg" alt="arXiv"></a>
</p>

Torch-Pruning (TP) is a framework for structural pruning with the following features:

* **General-purpose Pruning Toolkit:** TP enables structural pruning for a wide range of deep neural networks. Unlike [torch.nn.utils.prune](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html), which zeroizes parameters via masking, Torch-Pruning employs an algorithm called ⚡ **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)** to group and physically remove coupled parameters.
* **Examples**: Pruning off-the-shelf models from HuggingFace, Timm, and Torchvision, including [Large Language Models (LLMs)](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), [Segment Anything Model (SAM)](https://github.com/czg1225/SlimSAM), [Diffusion Models](https://github.com/VainF/Diff-Pruning), [Vision Transformers](https://github.com/VainF/Isomorphic-Pruning), [ConvNext](https://github.com/VainF/Isomorphic-Pruning), [Yolov7](examples/yolov7/), [yolov8](examples/yolov8/), [Swin Transformers](examples/transformers#swin-transformers-from-hf-transformers), [BERT](examples/transformers#bert-from-hf-transformers), FasterRCNN, SSD, ResNe(X)t, DenseNet, RegNet, DeepLab, etc. A detailed list can be found in 🎨 [Examples](examples).
For more technical details, please refer to our CVPR'23 paper:
> [**DepGraph: Towards Any Structural Pruning**](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)   
> *[Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Mingli Song](https://person.zju.edu.cn/en/msong), [Michael Bi Mi](https://dblp.org/pid/317/0937.html), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)*    
> *[xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore*

### Update:
- 🔥 2025.03.24 Examples for pruning [**DeepSeek-R1-Distill**](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs).
- 🔥 2024.11.17 We are working to add more [**examples for LLMs**](https://github.com/VainF/Torch-Pruning/tree/master/examples/LLMs), such as Llama-2/3, Phi-3, and Qwen-2/2.5.
- 🔥 2024.09.27 Check out our latest work, [**MaskLLM (NeurIPS 24 Spotlight)**](https://github.com/NVlabs/MaskLLM), for learnable semi-structured sparsity of LLMs.
- 🔥 2024.07.20 Added [**Isomorphic Pruning (ECCV'24)**](https://arxiv.org/abs/2407.04616), a SOTA method for Vision Transformers and modern CNNs.

### **Contact Us:**
Please do not hesitate to open an [issue](https://github.com/VainF/Torch-Pruning/issues) if you encounter any problems with the library or the paper.   
Or join our WeChat group for more discussions: ✉️ [Group-2](https://github.com/user-attachments/assets/3fe4c487-5a5b-43fd-bf64-a5ee62c3dec1) (>200/500), ✉️ [Group-1](https://github.com/VainF/Torch-Pruning/assets/18592211/35d66130-eb03-4dcb-ad75-8df784460ad3) (500/500, FULL).

## Table of Contents
- [Installation](#installation)
- [Quickstart](#quickstart)
   - [Why Torch-Pruning?](#why-torch-pruning)
   - [How It Works: DepGraph](#how-it-works-depgraph)
   - [High-level Pruners](#high-level-pruners)
     - [Global Pruning and Isomorphic Pruning](#global-pruning-and-isomorphic-pruning)
     - [Pruning Ratios](#pruning-ratios)
     - [Sparse Training (Optional)](#sparse-training-optional)
     - [Interactive Pruning](#interactive-pruning)
     - [Pruning by Masking](#pruning-by-masking)
     - [Group-level Pruning](#group-level-pruning)
     - [Modify static attributes or forward functions](#modify-static-attributes-or-forward-functions)
   - [Save & Load](#save-and-load)
   - [Low-level Pruning Functions](#low-level-pruning-functions)
   - [Customized Layers](#customized-layers)
   - [Reproduce Paper Results](#reproduce-paper-results)
     - [Our Results on {ResNet-56 / CIFAR-10 / 2.00x}](#our-results-on-resnet-56--cifar-10--200x)
     - [Latency](#latency)
   - [Series of Works](#series-of-works)
- [Citation](#citation)

## Installation

Torch-Pruning only relies on PyTorch and NumPy, and is compatible with PyTorch 1.x and 2.x. To install the latest version, run:

```bash
pip install torch-pruning --upgrade
```
For an editable installation:
```bash
git clone https://github.com/VainF/Torch-Pruning.git
cd Torch-Pruning && pip install -e .
```

## Quickstart

Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the [Tutorials](https://github.com/VainF/Torch-Pruning/wiki).

### Why Torch-Pruning?

In structural pruning, the removal of a single parameter may affect multiple layers. For example, pruning an output dimension of a linear layer requires the removal of the corresponding input dimension in the following linear layer, as shown in (a). This dependency between layers makes it challenging to prune complicated networks manually. Torch-Pruning addresses this issue with a graph-based algorithm, **[DepGraph](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)**, which automatically identifies dependencies and collects groups for pruning.

<div align="center">
<img src="assets/dep.png" width="100%">
</div>
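To make the dependency in case (a) concrete, here is a minimal hand-rolled sketch in plain PyTorch (not the TP API; `fc1`, `fc2`, and the pruned indices are hypothetical). Removing output dimensions of one linear layer forces us to remove the matching input dimensions of the next one:

```python
import torch
import torch.nn as nn

fc1, fc2 = nn.Linear(8, 16), nn.Linear(16, 4)
keep = [i for i in range(16) if i not in (2, 6, 9)]  # drop output dims 2, 6, 9

fc1_pruned = nn.Linear(8, len(keep))
fc1_pruned.weight.data = fc1.weight.data[keep]      # prune rows (output dims)
fc1_pruned.bias.data = fc1.bias.data[keep]

fc2_pruned = nn.Linear(len(keep), 4)
fc2_pruned.weight.data = fc2.weight.data[:, keep]   # prune matching columns (input dims)
fc2_pruned.bias.data = fc2.bias.data                # output dims of fc2 are untouched

x = torch.randn(1, 8)
print(fc2_pruned(fc1_pruned(x)).shape)  # torch.Size([1, 4])
```

Tracking such couplings by hand quickly becomes infeasible for residual connections, concatenations, and attention blocks; DepGraph automates exactly this bookkeeping.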
### How It Works: DepGraph

> [!IMPORTANT]  
> Please make sure that autograd is enabled, since TP analyzes the model structure with PyTorch's autograd. This means ``torch.no_grad()`` (or anything similar) must be removed when building the dependency graph.

```python
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. Build the dependency graph for a resnet18. This requires a dummy input for forwarding.
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))

# 2. To prune the output channels of model.conv1, find the corresponding group with a pruning function and pruning indices.
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9])

# 3. Do the pruning
if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0.
    group.prune()

# 4. Save & Load
model.zero_grad() # clear gradients to avoid a large file size
torch.save(model, 'model.pth') # !! no .state_dict here since the structure has been changed after pruning
model = torch.load('model.pth') # load the pruned model. You may need torch.load('model.pth', weights_only=False) for PyTorch 2.6.0+.
```

The above example shows the core algorithm, DepGraph, which captures the dependencies in structural pruning. The target layer `model.conv1` is coupled with multiple layers, necessitating their simultaneous removal in structural pruning. We can print the group to inspect the internal dependencies. In the outputs below, "A => B" indicates that pruning operation "A" triggers pruning operation "B", and `group[0]` refers to the root of the pruning. For more details about grouping, please refer to [Wiki - DepGraph & Group](https://github.com/VainF/Torch-Pruning/wiki/3.-DepGraph-&-Group).
```python
print(group.details()) # or print(group)
```
```
--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs (3) =[2, 6, 9]  (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[2] prune_out_channels on bn1 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs (3) =[2, 6, 9] 
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs (3) =[2, 6, 9] 
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs (3) =[2, 6, 9] 
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs (3) =[2, 6, 9] 
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs (3) =[2, 6, 9] 
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(61, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs (3) =[2, 6, 9] 
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(61, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs (3) =[2, 6, 9] 
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(61, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 61, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
--------------------------------
```
#### How to scan all groups (Advanced)

There might be many groups in a model. We can use ``DG.get_all_groups(ignored_layers, root_module_types)`` to scan all prunable groups sequentially. Each group begins with a layer that matches one of the ``nn.Module`` classes in ``root_module_types``. The ``ignored_layers`` parameter is used to skip layers that should not be pruned; for example, we can skip the first convolutional layer of a ResNet.

```python
import torch.nn as nn

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # Handle groups in sequential order
    idxs = [2, 4, 6] # your pruning indices, feel free to change them
    group.prune(idxs=idxs)
    print(group)
```

### High-level Pruners

> [!NOTE]  
> **The pruning ratio**: In TP, ``pruning_ratio`` refers to the pruning ratio of channels/dims. Since both the input and output dims of a layer may be pruned at ratio $p$, the actual parameter pruning ratio will be roughly $1-(1-p)^2$. To remove 50% of the parameters, you may use ``pruning_ratio=0.30``, which leads to an actual parameter pruning ratio of $1-(1-0.3)^2=0.51$ (51% of parameters removed).
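A quick numeric check of this relationship (plain Python; the helper name `channel_to_param_ratio` is ours, not a TP API):

```python
# Channel pruning ratio p -> approximate parameter pruning ratio 1-(1-p)^2,
# since both the input and output dims of a weight tensor shrink by a factor of (1-p).
def channel_to_param_ratio(p: float) -> float:
    return 1 - (1 - p) ** 2

print(channel_to_param_ratio(0.30))  # 0.51 -> ~51% of parameters removed
print(channel_to_param_ratio(0.50))  # 0.75 -> ~75% of parameters removed
```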
With DepGraph, we developed several high-level pruners to facilitate effortless pruning. By specifying the desired channel pruning ratio, a pruner will scan all prunable groups, estimate weight importance, and perform pruning. You can then fine-tune the remaining weights using your own training code. For detailed information on this process, please refer to [this tutorial](https://github.com/VainF/Torch-Pruning/blob/master/examples/notebook/1%20-%20Customize%20Your%20Own%20Pruners.ipynb), which shows how to implement a [Network Slimming (ICCV 2017)](https://arxiv.org/abs/1708.06519) pruner from scratch. A more practical example is available in [VainF/Isomorphic-Pruning](https://github.com/VainF/Isomorphic-Pruning) for ViT and ConvNext pruning.

```python
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Importance criterion: here we use the L2 norm of grouped weights as the importance score
imp = tp.importance.GroupMagnitudeImportance(p=2) 

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.BasePruner( # We can always choose BasePruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks
    ignored_layers=ignored_layers,
    round_to=8, # It's recommended to round dims/channels to 4x or 8x for acceleration. See: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
)

# 3. Prune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
tp.utils.print_tool.before_pruning(model) # or print(model)
pruner.step()
tp.utils.print_tool.after_pruning(model) # or print(model); this util shows the difference before and after pruning
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")

# 4. Fine-tune the pruned model using your own code.
# finetune(model)
# ...
```

<details>
  <summary>Output</summary>
  
The model difference before and after pruning will be highlighted by something like `(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)`.
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
...
     (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True) => (fc): Linear(in_features=256, out_features=1000, bias=True)
)

MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M
```
</details>

#### Global Pruning and Isomorphic Pruning

Global pruning performs importance ranking across all layers, which has the potential to find better structures. It can be enabled by setting ``global_pruning=True`` in the pruner. While this strategy can offer performance advantages, it also risks over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm, [Isomorphic Pruning](https://arxiv.org/abs/2407.04616), to alleviate this issue, which can be enabled with ``isomorphic=True``. Comprehensive examples for ViT & ConvNext pruning are available in [this project](https://github.com/VainF/Isomorphic-Pruning).
```python
pruner = tp.pruner.BasePruner(
    ...
    isomorphic=True, # enable isomorphic pruning to improve global ranking
    global_pruning=True, # global pruning
)
```

<div align="center">
<img src="assets/isomorphic_pruning.png" width="96%">
</div>

#### Pruning Ratios

The argument ``pruning_ratio`` determines the default pruning ratio. If you want to customize the pruning ratio for some layers or blocks, use ``pruning_ratio_dict``. The key of the dict can be a single ``nn.Module`` or a tuple of ``nn.Module``s. In the latter case, all modules in the tuple form a "scope", share the user-defined pruning ratio, and compete with each other to be pruned.
```python
pruner = tp.pruner.BasePruner(
    ...
    global_pruning=True,
    pruning_ratio=0.5, # default pruning ratio
    # layer1 & layer2 share a total pruning ratio of 0.4, while layer3 has a pruning ratio of 0.2
    pruning_ratio_dict = {(model.layer1, model.layer2): 0.4, model.layer3: 0.2}, 
)
```

#### Sparse Training (Optional)

Some pruners, such as [BNScalePruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py#L45) and [GroupNormPruner](https://github.com/VainF/Torch-Pruning/blob/dd59921365d72acb2857d3d74f75c03e477060fb/torch_pruning/pruner/algorithms/group_norm_pruner.py#L53), support sparse training. This can be achieved by inserting ``pruner.update_regularizer()`` and ``pruner.regularize(model)`` into your standard training loop. The pruner accumulates the regularization gradients into ``.grad``. Sparse training is optional and does not always guarantee better performance; use it with care.
```python
for epoch in range(epochs):
    model.train()
    pruner.update_regularizer() # <== initialize the regularizer
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward() # after loss.backward()
        pruner.regularize(model) # <== for sparse training
        optimizer.step() # before optimizer.step()
```

#### Interactive Pruning

All high-level pruners support interactive pruning. Use `pruner.step(interactive=True)` to retrieve the groups and prune them interactively by calling `group.prune()`. This feature is particularly useful if you want to control or monitor the pruning process.

```python
for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group) 
        # do whatever you like with the group 
        dep, idxs = group[0] # get the idxs
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # fine-tune your model here
    # finetune(model)
    # ...
```

#### Pruning by Masking

It is possible to implement masking-based pruning by leveraging ``interactive=True``, which zeros out parameters without removing them. An example can be found in [tests/test_soft_pruning.py](https://github.com/VainF/Torch-Pruning/blob/c9cea192a31f64e5ea26c095a70e2e93acf0be77/tests/test_soft_pruning.py#L39).
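As a rough illustration of the idea (a sketch, not the exact code from the test file; it reuses `pruner` from above and the `dep.target.module` / `dep.handler` accessors shown in the interactive example):

```python
# Masking-based ("soft") pruning: zero out the selected channels instead of
# removing them, so all tensor shapes stay unchanged.
for group in pruner.step(interactive=True):
    for dep, idxs in group:
        layer, prune_fn = dep.target.module, dep.handler
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)) and prune_fn in [
            tp.prune_conv_out_channels, tp.prune_linear_out_channels]:
            layer.weight.data[idxs] = 0 # mask the selected output channels
            if layer.bias is not None:
                layer.bias.data[idxs] = 0
    # Note: group.prune() is intentionally NOT called, so nothing is removed.
```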
#### Group-level Pruning

With DepGraph, it is easy to design "group-level" importance scores that estimate the importance of a whole group rather than a single layer. This feature can also be used to sparsify coupled layers, making all the to-be-pruned parameters consistently sparse. In Torch-Pruning, all pruners work at the group level. Check the following results to see how grouping improves pruning performance; a minimal sketch of a group-level importance score follows the plots.

<div align="center">
<img src="assets/group_sparsity.png" width="80%">
</div>

* Pruning a ResNet50 pre-trained on ImageNet-1K without fine-tuning.
<div align="center">
<img src="https://github.com/VainF/Torch-Pruning/assets/18592211/775eb01a-4610-4637-90bd-ff53f7ea2d31" width="45%"></img>
<img src="https://github.com/VainF/Torch-Pruning/assets/18592211/085aa9ec-a520-4939-97f4-46f65b124929" width="45%"></img>
</div>

* Pruning a Vision Transformer pre-trained on ImageNet-1K without fine-tuning.
<div align="center">
<img src="https://github.com/VainF/Torch-Pruning/assets/18592211/6f99aa90-259d-41e8-902a-35675a9c9d90" width="45%"></img>
<img src="https://github.com/VainF/Torch-Pruning/assets/18592211/11473499-d28a-434b-a8d6-1a53c4b3b7c0" width="45%"></img>
</div>
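Here is the promised sketch of a custom group-level importance score, following the pattern from the [Customize Your Own Pruners tutorial](https://github.com/VainF/Torch-Pruning/blob/master/examples/notebook/1%20-%20Customize%20Your%20Own%20Pruners.ipynb). The class and its name are illustrative; check the notebook for the exact interface:

```python
import torch
import torch.nn as nn
import torch_pruning as tp

# Illustrative group-level importance: average the per-channel squared L2 norm
# of the weights over every supported layer in the group.
class MyGroupL2Importance:
    def __call__(self, group, **kwargs):
        group_imp = []
        for dep, idxs in group:
            layer = dep.target.module
            prune_fn = dep.handler
            if isinstance(layer, (nn.Conv2d, nn.Linear)) and prune_fn in [
                tp.prune_conv_out_channels, tp.prune_linear_out_channels]:
                w = layer.weight.data[idxs].flatten(1) # [len(idxs), -1]
                group_imp.append(w.pow(2).sum(1))      # per-channel squared L2 norm
        if len(group_imp) == 0:
            return None # no supported layer in this group
        return torch.stack(group_imp, dim=0).mean(dim=0) # average over the group

# Usage: pruner = tp.pruner.BasePruner(model, example_inputs, importance=MyGroupL2Importance(), ...)
```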
#### Modify static attributes or forward functions

In some implementations, the forward pass may rely on static attributes. For example, [``convformer_s18``](https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/metaformer.py#L107) in timm has a ``self.shape`` attribute that must change after pruning. Such attributes have to be updated manually, since it is impossible for TP to know their purpose.

```python
class Scale(nn.Module):
    """
    Scale vector by element multiplications.
    """

    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning
```

### Save and Load

The following script saves the whole model object (structure + weights) as 'model.pth'. You can load it using the standard PyTorch API. Just remember to save and load the whole model **without** ``.state_dict`` or ``.load_state_dict``, since the pruned structure will differ from the original definition in your ``model.py``.
```python
model.zero_grad() # Remove gradients
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model
# For PyTorch 2.6.0+, you may need weights_only=False to enable model loading
# model = torch.load('model.pth', weights_only=False)
```

### Low-level Pruning Functions

Torch-Pruning also provides a series of low-level pruning functions that prune only a single layer or module. To manually prune ``model.conv1`` of a ResNet-18, the pruning pipeline looks like this:

```python
tp.prune_conv_out_channels(model.conv1, idxs=[2, 6, 9])

# fix the broken dependencies manually
tp.prune_batchnorm_out_channels(model.bn1, idxs=[2, 6, 9])
tp.prune_conv_in_channels(model.layer2[0].conv1, idxs=[2, 6, 9])
...
```

The following [pruning functions](torch_pruning/pruner/function.py) are available:
```python
'prune_conv_out_channels',
'prune_conv_in_channels',
'prune_depthwise_conv_out_channels',
'prune_depthwise_conv_in_channels',
'prune_batchnorm_out_channels',
'prune_batchnorm_in_channels',
'prune_linear_out_channels',
'prune_linear_in_channels',
'prune_prelu_out_channels',
'prune_prelu_in_channels',
'prune_layernorm_out_channels',
'prune_layernorm_in_channels',
'prune_embedding_out_channels',
'prune_embedding_in_channels',
'prune_parameter_out_channels',
'prune_parameter_in_channels',
'prune_multihead_attention_out_channels',
'prune_multihead_attention_in_channels',
'prune_groupnorm_out_channels',
'prune_groupnorm_in_channels',
'prune_instancenorm_out_channels',
'prune_instancenorm_in_channels',
```

### Customized Layers

Please refer to [examples/transformers/prune_hf_swin.py](examples/transformers/prune_hf_swin.py), which implements a new pruner for the customized module ``SwinPatchMerging``. Another simple example is available in [tests/test_customized_layer.py](https://github.com/VainF/Torch-Pruning/blob/master/tests/test_customized_layer.py).
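The general pattern looks roughly like the sketch below. ``MyCustomLayer`` and ``MyCustomPruner`` are our illustrative names, and the base class and registration hook follow the pattern used in the test file; double-check both against your TP version:

```python
import torch
import torch.nn as nn
import torch_pruning as tp

class MyCustomLayer(nn.Module): # a hypothetical module with a prunable `dim`
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.scale = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return x * self.scale

class MyCustomPruner(tp.pruner.BasePruningFunc): # base class name may vary across versions
    def prune_out_channels(self, layer, idxs):
        keep = [i for i in range(layer.dim) if i not in idxs]
        layer.scale = nn.Parameter(layer.scale.data[keep]) # physically shrink the parameter
        layer.dim = len(keep) # also update the static attribute
        return layer
    prune_in_channels = prune_out_channels # in/out dims coincide for this layer
    def get_out_channels(self, layer):
        return layer.dim
    get_in_channels = get_out_channels

DG = tp.DependencyGraph()
DG.register_customized_layer(MyCustomLayer, MyCustomPruner()) # register before build_dependency
```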
### Reproduce Paper Results

Please see [reproduce](reproduce).

#### Our results on {ResNet-56 / CIFAR-10 / 2.00x}

| Method | Base (%) | Pruned (%) | $\Delta$ Acc (%) | Speed Up |
|:--    |:--:  |:--:    |:--: |:--:      |
| NIPS [[1]](#1)  | -    | -      |-0.03 | 1.76x    |
| Geometric [[2]](#2) | 93.59 | 93.26 | -0.33 | 1.70x |
| Polar [[3]](#3)  | 93.80 | 93.83 | +0.03 |1.88x |
| CP  [[4]](#4)   | 92.80 | 91.80 | -1.00 |2.00x |
| AMC [[5]](#5)   | 92.80 | 91.90 | -0.90 |2.00x |
| HRank [[6]](#6) | 93.26 | 92.17 | -1.09 |2.00x |
| SFP  [[7]](#7)  | 93.59 | 93.36 | -0.23 |2.11x |
| ResRep [[8]](#8) | 93.71 | 93.71 | +0.00 |2.12x |
||
| Ours-L1 | 93.53 | 92.93 | -0.60 | 2.12x |
| Ours-BN | 93.53 | 93.29 | -0.24 | 2.12x |
| Ours-Group | 93.53 | 93.77 | +0.38 | 2.13x |

#### Latency

Latency test on ResNet-50, batch size = 64:
```
[Iter 0]        Pruning ratio: 0.00,         MACs: 4.12 G,   Params: 25.56 M,        Latency: 45.22 ms +- 0.03 ms
[Iter 1]        Pruning ratio: 0.05,         MACs: 3.68 G,   Params: 22.97 M,        Latency: 46.53 ms +- 0.06 ms
[Iter 2]        Pruning ratio: 0.10,         MACs: 3.31 G,   Params: 20.63 M,        Latency: 43.85 ms +- 0.08 ms
[Iter 3]        Pruning ratio: 0.15,         MACs: 2.97 G,   Params: 18.36 M,        Latency: 41.22 ms +- 0.10 ms
[Iter 4]        Pruning ratio: 0.20,         MACs: 2.63 G,   Params: 16.27 M,        Latency: 39.28 ms +- 0.20 ms
[Iter 5]        Pruning ratio: 0.25,         MACs: 2.35 G,   Params: 14.39 M,        Latency: 34.60 ms +- 0.19 ms
[Iter 6]        Pruning ratio: 0.30,         MACs: 2.02 G,   Params: 12.46 M,        Latency: 33.38 ms +- 0.27 ms
[Iter 7]        Pruning ratio: 0.35,         MACs: 1.74 G,   Params: 10.75 M,        Latency: 31.46 ms +- 0.20 ms
[Iter 8]        Pruning ratio: 0.40,         MACs: 1.50 G,   Params: 9.14 M,         Latency: 29.04 ms +- 0.19 ms
[Iter 9]        Pruning ratio: 0.45,         MACs: 1.26 G,   Params: 7.68 M,         Latency: 27.47 ms +- 0.28 ms
[Iter 10]       Pruning ratio: 0.50,         MACs: 1.07 G,   Params: 6.41 M,         Latency: 20.68 ms +- 0.13 ms
[Iter 11]       Pruning ratio: 0.55,         MACs: 0.85 G,   Params: 5.14 M,         Latency: 20.48 ms +- 0.21 ms
[Iter 12]       Pruning ratio: 0.60,         MACs: 0.67 G,   Params: 4.07 M,         Latency: 18.12 ms +- 0.15 ms
[Iter 13]       Pruning ratio: 0.65,         MACs: 0.53 G,   Params: 3.10 M,         Latency: 15.19 ms +- 0.01 ms
[Iter 14]       Pruning ratio: 0.70,         MACs: 0.39 G,   Params: 2.28 M,         Latency: 13.47 ms +- 0.01 ms
[Iter 15]       Pruning ratio: 0.75,         MACs: 0.29 G,   Params: 1.61 M,         Latency: 10.07 ms +- 0.01 ms
[Iter 16]       Pruning ratio: 0.80,         MACs: 0.18 G,   Params: 1.01 M,         Latency: 8.96 ms +- 0.02 ms
[Iter 17]       Pruning ratio: 0.85,         MACs: 0.10 G,   Params: 0.57 M,         Latency: 7.03 ms +- 0.04 ms
[Iter 18]       Pruning ratio: 0.90,         MACs: 0.05 G,   Params: 0.25 M,         Latency: 5.81 ms +- 0.03 ms
[Iter 19]       Pruning ratio: 0.95,         MACs: 0.01 G,   Params: 0.06 M,         Latency: 5.70 ms +- 0.03 ms
[Iter 20]       Pruning ratio: 1.00,         MACs: 0.01 G,   Params: 0.06 M,         Latency: 5.71 ms +- 0.03 ms
```
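For reference, latency numbers like these can be measured with a simple timing loop along the following lines. This is our own sketch in plain PyTorch (not the repo's benchmark script), assuming a CUDA device and the `model`/`example_inputs` from the quickstart:

```python
import time
import torch

@torch.no_grad()
def measure_latency(model, example_inputs, warmup=20, repeat=100):
    model = model.eval().cuda()
    x = example_inputs.cuda()
    for _ in range(warmup): # warm-up to exclude CUDA init and cudnn autotuning
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeat):
        model(x)
    torch.cuda.synchronize() # wait for all kernels to finish before stopping the clock
    return (time.perf_counter() - start) / repeat * 1000.0 # ms per batch

# print(f"Latency: {measure_latency(model, torch.randn(64, 3, 224, 224)):.2f} ms")
```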
### Series of Works

> **DepGraph: Towards Any Structural Pruning** [[Project]](https://github.com/VainF/Torch-Pruning) [[Paper]](https://openaccess.thecvf.com/content/CVPR2023/html/Fang_DepGraph_Towards_Any_Structural_Pruning_CVPR_2023_paper.html)   
> *Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang*  
> CVPR 2023

> **Isomorphic Pruning for Vision Models** [[Project]](https://github.com/VainF/Isomorphic-Pruning) [[arXiv]](https://arxiv.org/abs/2407.04616)  
> *Gongfan Fang, Xinyin Ma, Michael Bi Mi, Xinchao Wang*   
> ECCV 2024

> **LLM-Pruner: On the Structural Pruning of Large Language Models** [[Project]](https://github.com/horseee/LLM-Pruner) [[arXiv]](https://arxiv.org/abs/2305.11627)   
> *Xinyin Ma, Gongfan Fang, Xinchao Wang*  
> NeurIPS 2023

> **Structural Pruning for Diffusion Models** [[Project]](https://github.com/VainF/Diff-Pruning) [[arXiv]](https://arxiv.org/abs/2305.10924)  
> *Gongfan Fang, Xinyin Ma, Xinchao Wang*  
> NeurIPS 2023

> **DeepCache: Accelerating Diffusion Models for Free** [[Project]](https://github.com/horseee/DeepCache) [[arXiv]](https://arxiv.org/abs/2312.00858)  
> *Xinyin Ma, Gongfan Fang, Xinchao Wang*   
> CVPR 2024

> **SlimSAM: 0.1% Data Makes Segment Anything Slim** [[Project]](https://github.com/czg1225/SlimSAM) [[arXiv]](https://arxiv.org/abs/2312.05284)    
> *Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang*   
> Preprint 2023

## Citation
```
@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}
```