{"id":16191534,"url":"https://github.com/witmemtech/witin-nn-tool-","last_synced_at":"2026-02-03T23:04:30.026Z","repository":{"id":253201675,"uuid":"842756207","full_name":"witmemtech/Witin-NN-Tool-","owner":"witmemtech","description":"The \"witin_nn\" framework, based on PyTorch, maps neural networks to chip computations and supports operators including Linear, Conv2d, and GruCell. It enables 8-12 bit quantization for inputs/outputs and weights, implementing QAT.","archived":false,"fork":false,"pushed_at":"2024-08-15T08:34:13.000Z","size":20534,"stargazers_count":0,"open_issues_count":1,"forks_count":0,"subscribers_count":0,"default_branch":"main","last_synced_at":"2025-02-13T17:18:08.629Z","etag":null,"topics":["computinginmemory","conv2d","convtranspose2d","grucell","linear","nat","natural-language-processing","nerual-network","pytorch","qat"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/witmemtech.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-08-15T02:38:37.000Z","updated_at":"2024-09-27T07:52:53.000Z","dependencies_parsed_at":"2024-09-22T06:01:50.337Z","dependency_job_id":"83c8fd88-802e-43ff-ac64-ee8e99aaafd1","html_url":"https://github.com/witmemtech/Witin-NN-Tool-","commit_stats":{"total_commits":6,"total_committers":1,"mean_commits":6.0,"dds":0.0,"last_synced_commit":"dc6f7c34a14da2726a067cb0360103385cc28b99"},"previous_names":["witmem/witin-nn-tool-"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/ho
sts/GitHub/repositories/witmemtech%2FWitin-NN-Tool-","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/witmemtech%2FWitin-NN-Tool-/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/witmemtech%2FWitin-NN-Tool-/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/witmemtech%2FWitin-NN-Tool-/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/witmemtech","download_url":"https://codeload.github.com/witmemtech/Witin-NN-Tool-/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247675633,"owners_count":20977376,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["computinginmemory","conv2d","convtranspose2d","grucell","linear","nat","natural-language-processing","nerual-network","pytorch","qat"],"created_at":"2024-10-10T08:00:51.288Z","updated_at":"2026-02-03T23:04:29.964Z","avatar_url":"https://github.com/witmemtech.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Witin_nn V2.1.0 **User Manual**\n\n## 1. witin_nn Overview\n\n### **1.1** Background\n\nWitmem Technology's compute-in-memory solution introduces analog noise. As a result, a neural network model trained purely in floating point often loses performance after it is deployed to the chip. Noise-aware training addresses this by exposing the network to the chip's noise characteristics during training, so that the model performs better once deployed.\n\nThe witin_nn framework is built on PyTorch. It mainly implements quantization-aware training (QAT) and noise-aware training (NAT) adapted to Witmem chips. It currently supports operators such as Linear, Conv2d, ConvTranspose2d and GruCell. The framework injects noise into the inputs, weights, biases and outputs along the forward pass, which shapes back-propagation (the parameter updates) and improves the network's generalization. Concretely, witin_nn simulates how a neural network is mapped onto Witmem's compute-in-memory chip. It implements QAT by supporting 8-bit to 12-bit quantization of inputs and outputs and 8-bit quantization of weights. It implements NAT by injecting simulated analog circuit noise.\n\nTake the floating-point software performance of a floating-point trained model as the baseline. Adding QAT and NAT usually brings the on-chip performance closer to that baseline.\n\n### **1.2** Computation Details\n\n1. Table 1 lists the correspondence between witin_nn operators and torch operators.\n\n| Operator type | witin_nn operator | Corresponding PyTorch operator | Chip computation formula |\n| --- | --- | --- | --- |\n| Compute-in-memory | witin_nn.WitinLinear | torch.nn.Linear | `output = torch.nn.functional.linear(input, weight, bias) / g_value` |\n| Compute-in-memory | witin_nn.WitinConv2d | torch.nn.Conv2d | `output = torch.nn.functional.conv2d(input, weight, bias, stride, padding, dilation, groups) / g_value` |\n| Compute-in-memory | witin_nn.WitinConvTranspose2d | torch.nn.ConvTranspose2d | `output = torch.nn.functional.conv_transpose2d(input, weight, bias, stride, padding, output_padding, groups, dilation) / g_value` |\n| Compute-in-memory | witin_nn.WitinGruCell | torch.nn.GRUCell | `output = torch._VF.gru_cell(input, hx, weight_ih, weight_hh, bias_ih, bias_hh) / g_value` |\n| Digital | witin_nn.WitinGELU | torch.nn.GELU | / |\n| Digital | witin_nn.WitinSigmoid | torch.nn.Sigmoid | / |\n| Digital | witin_nn.WitinTanh | torch.nn.Tanh | / |\n| Digital | witin_nn.WitinPReU | torch.nn.PReLU | / |\n| Digital | witin_nn.WitinElementAdd | Addition | / |\n| Digital | witin_nn.WitinElementDivide | Division | / |\n| Digital | witin_nn.WitinElementMul | Multiplication | / |\n| Digital | witin_nn.WitinSqrt | torch.sqrt | / |\n| Digital | witin_nn.WitinMean | torch.mean | / |\n| Digital | witin_nn.WitinCat | torch.cat | / |\n| Digital | witin_nn.WitinBatchNorm2d | torch.nn.BatchNorm2d | / |\n\n**Table 1**\n\n2. The following uses witin_nn.WitinLinear as an example to outline the QAT and NAT computation, with inputs and outputs both quantized to 8 bits.\n\n![image-20240815110907574](https://github.com/user-attachments/assets/d1a661c1-41f4-4165-8cb7-9f3a015ea73e)\n\nAs shown above, the input x is quantized to the uint8 value NPU_x and the weight is quantized to the int8 value NPU_weight. The bias is quantized to an integer multiple of 128, giving NPU_bias. NPU_y' is computed from NPU_x, NPU_weight and NPU_bias. Simulated analog circuit noise is then added to obtain NPU_y, which is finally quantized to int8. The output of witin_nn.WitinLinear is NPU_y / y_scale, dequantized back to the floating-point domain.\n\n3. 
Mathematical equivalence analysis:\n\n![image-20240815111009737](https://github.com/user-attachments/assets/76fec77b-978c-479d-ae00-5b11cdfcf0a0)\n\n## **2. Development Guide**\n\n### **2.1** **Environment Setup**\n\npython \u003e= 3.7\n\ntorch == 1.13\n\n### **2.2** **Operator Parameters**\n\nEach witin_nn operator wraps the corresponding torch.nn operator. It keeps all of that operator's parameters and adds QAT- and NAT-related parameters on top of them. To build a network, replace each torch operator with its witin_nn counterpart and configure the corresponding parameters. For the retained parameters, refer to the official PyTorch documentation. All witin_nn operators accept the extended parameters below. However, not every parameter takes effect for every operator, as the Applicable operators column shows.\n\n| **Parameter** | Type | Default | Description | Applicable operators |\n| --- | --- | --- | --- | --- |\n| target_platform | `class TargetPlatform(Enum): WTM2101 = 1` | `TargetPlatform.WTM2101` | Selects the chip platform. | All |\n| Hardware | `class HardwareType(Enum): ARRAY = 1, VPU = 2` | `HardwareType.ARRAY` | Selects the compute platform. | All |\n| w_clip | float or None | None | If w_clip = None, the weights are left unchanged. Otherwise the weights are clipped to the range [-w_clip, w_clip]. | All compute-in-memory operators |\n| bias_row_N | int | 8 | Number of NPU array rows used for the bias computation. Only takes effect when use_quantization = True. | All compute-in-memory operators |\n| use_quantization | bool | False | use_quantization = True enables quantization-aware training. use_quantization = False performs floating-point training. | All |\n| noise_model | `class NoiseModel(Enum): NORMAL = 1, ARRMDL = 2, MBS = 3, SIMPLE = 4` | `NoiseModel.NORMAL` | Noise model type. Only the NORMAL noise model is currently supported. | All compute-in-memory operators |\n| noise_level | int | 0 | noise_level = 0 adds no noise. For 0 \u003c= noise_level \u003c 10, the value sets the noise level of the NORMAL noise model, and larger values mean stronger noise. | All compute-in-memory operators |\n| to_linear | bool | False | Whether to replace Conv2d, Conv1d and ConvTranspose2d operators with an equivalent linear operator for computation. Keep to_linear = False during training. | WitinLinear, WitinConv2d, WitinConvTranspose2d |\n| use_auto_scale | bool | True | Whether to compute scale_x, scale_y and scale_weight automatically. | All |\n| scale_x | int | 1 | Input quantization scale. Only takes effect when use_quantization == True. | All |\n| scale_y | int | 1 | Output quantization scale. Only takes effect when use_quantization == True. | All |\n| scale_weight | int | 1 | Weight quantization scale. Only takes effect when use_quantization == True. | All operators with weights |\n| handle_neg_in | `class HandleNegInType(Enum): FALSE = 1, PN = 2, Shift = 3` (FALSE: negative inputs are not handled; PN: the input sign is moved into the weights; Shift: the whole input is shifted) | `HandleNegInType.FALSE` | How negative inputs are handled. Only takes effect when use_quantization == True. | WitinLinear, WitinConv2d, WitinConvTranspose2d |\n| shift_num | float | 1 | Offset to configure when HandleNegInType.Shift is selected. | WitinLinear, WitinConv2d, WitinConvTranspose2d |\n| x_quant_bits | int | 8 | Input quantization bit width. | All |\n| y_quant_bits | int | 8 | Output quantization bit width. | All |\n| weight_quant_bits | int | 8 | Weight quantization bit width. | All operators with weights |\n| bias_d | torch.Tensor | `torch.tensor(0)` | Part of the bias that is split out to digital computation. | WitinLinear, WitinConv2d, WitinConvTranspose2d |\n| Conv2d_split_N | / | / | Reserved; not yet available. | |\n\n**Table 2: Parameter list**\n\n### **2.3** **Configuration Files**\n\nThere are two config types:\n\n• WitinGlobalConfig: the global configuration, which supplies the default settings for all operators.\n\n• WitinLayerConfig: parameter settings for one specific operator.\n\nSeveral standard configuration presets are defined in interface/ConfigFactory.py.\n\n### **2.4** **Usage Examples**\n\n#### **2.4.1** **Define a simple torch neural network**\n\n![image-20240815111245738](https://github.com/user-attachments/assets/24cf7a5e-65e4-4050-b78e-d2f6470b16ca)\n\n#### **2.4.2 witin_nn** **floating-point training example**\n\n![image-20240815111308627](https://github.com/user-attachments/assets/6ccc93ef-bd31-434d-bb17-688f47816872)\n\n![image-20240815111326196](https://github.com/user-attachments/assets/c0578c1f-e2e2-4dc8-9d02-f7935776ba90)\n\n#### **2.4.3 witin_nn** **quantized training example**\n\n![image](https://github.com/user-attachments/assets/45448065-04a1-4121-b818-eac6441fffea)\n\n#### **2.4.4 witin_nn** **quantized and noise-injected training example**\n\n![image-20240815111445712](https://github.com/user-attachments/assets/94ae1e79-ac3c-4421-be60-f8bd4d7ea01d)\n\n### **2.5** **Guidance for quantization bit widths greater than 8 bits**\n\nThe compute-in-memory core computes on 8-bit data. To improve precision, however, inputs may be quantized to more than 8 bits. witin_nn simulates how such inputs are split when the computation is mapped onto the chip: the low 8 bits are computed in the analog domain and the high bits in the digital domain. The bias may also need to be split, so that the analog output after mapping saturates as little as possible. The extra parameter bias_d (d stands for digital) holds the part of the bias that is moved to digital computation.\n\nThe following example uses witin_nn.WitinLinear with a 10-bit input and a 10-bit output to illustrate the process.\n\n![image-20240815111530187](https://github.com/user-attachments/assets/44a31ba8-74d4-40f6-a239-2ded8b4f1e11)\n\nAs shown in the figure above:\n\n(1) The input x is quantized to uint10 (0 to 1023) and the weight to int8 (-128 to 127). The bias is quantized to an integer multiple of 128.\n\n(2) The quantized x is split into the low 8 bits, NPU_x, and the high 2 bits, NPU_x_d. The quantized bias is split into the analog-part bias NPU_bias and the digital-part bias NPU_bias_d. NPU_weight is the quantized weight.\n\n(3) The computation produces the analog output NPU_y and the digital output NPU_y_d.\n\n(4) To obtain the final output y, NPU_y and NPU_y_d are summed and the sum is quantized to int10. The result is then divided by y_scale to dequantize it back to the floating-point domain.\n\n### **2.6 auto-scale** 
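**strategy**\n\nQuantization is symmetric. The quantization parameters are derived from the data range (specifically its maximum absolute value) and are used to quantize each operator's inputs, outputs and weights (if any).\n\nFor example, the quantization parameter for quantizing a set of data to int8 is determined as follows:\n\n```python\nimport torch\n\n# Quantization bit width: int8\nx_quant_bits = 8\n\n# One sample of torch.randn(1, 10), fixed here so that the printed values are reproducible\nx = torch.tensor([[0.1875, -1.3344, 0.5350, 1.5472, -0.9712,\n                   -1.4459, 0.1024, -0.8054, -1.7309, -0.8548]])\n\n# Largest absolute value in the data\nx_max = x.abs().max()\n\n# Power-of-two scale: 2^(bits - 1) / 2^ceil(log2(x_max))\nscale_x = 2 ** (x_quant_bits - 1) / 2 ** (torch.log2(x_max).ceil())\n\nprint(x_max)    # tensor(1.7309)\nprint(scale_x)  # tensor(64.)\n```\n\nHere ceil(log2(1.7309)) = 1, so scale_x = 128 / 2 = 64. The largest value then maps to about 1.7309 × 64 ≈ 110.8, which stays inside the int8 range.\n\n• During training, set use_auto_scale = True. Assume training runs for M epochs, each containing N iterations.\n\n(1) At the start of training, the first n iterations run with data_scale set to the user-provided initial values (scale_x, scale_weight, scale_y). During these iterations the maximum absolute value of the data, data_max, is recorded. The user sets n through the parameter auto_scale_updata_step.\n\n(2) After the first n iterations, data_scale is computed from data_max and updated. The remaining N - n iterations of the epoch all use this data_scale.\n\n(3) At the start of the next epoch, steps (1) and (2) are repeated: data_max is collected again and data_scale is recomputed.\n\n(4) After training, each layer in the saved model file contains the parameter io_max, which is that layer's data_max.\n\n• During inference, set use_auto_scale = True. witin_nn then reads the io_max parameters from the model and computes the quantization parameters automatically.\n\n• If use_auto_scale = False, the quantization parameters stay fixed at the user-configured scale_x, scale_weight and scale_y.\n\n• To obtain the quantization parameters, first extract io_max and then compute the quantization parameters manually.\n\n• **When auto-scale is enabled, choose the initial values of the quantization parameters carefully. Values that are too small or too large affect the final scale.**\n\n### **2.7 witin_nn** **training recommendations**\n\nThis section explains how to train models with this framework.\n\nStart from a floating-point trained model and retrain it while introducing constraints and noise step by step. The recommended training order is:\n\n![image-20240815111906865](https://github.com/user-attachments/assets/dfa385f9-2554-4834-b47f-cede6add339c)\n\nThe training flow is as follows:\n\n![image-20240815111929663](https://github.com/user-attachments/assets/ae6a4fd7-a2ec-4076-8483-b0c10b0fb727)\n\nThe accuracy of the three stages above generally follows these rules:\n\n**step1: Set use_quantization = False for floating-point training. w_clip may need to be specified to constrain the weights, which yields a pre-trained model better suited to chip deployment.**\n\nAfter training, evaluate the loss and the model evaluation metric on the test set under each of the three conditions below. The metric depends on the task, for example recognition rate or PSNR.\n\n• With use_quantization = False, denote the loss as Lf1 and the metric as Pf1.\n\n• With use_quantization = True, scale_x, scale_y and scale_weight specified, and bias_row_N (=8) specified, denote the loss as Lf2 and the metric as Pf2.\n\n• With use_quantization = True, use_noise = True, scale_x, scale_y and scale_weight specified, and bias_row_N (=8) specified, denote the loss as Lf3 and the metric as Pf3. In general, Lf1 \u003c Lf2 \u003c Lf3, and Pf1 is better than Pf2, which in turn is better than Pf3. The size of these gaps reflects the impact of quantization and of noise.\n\n**step2: Set use_quantization = True, specify scale_x, scale_y, scale_weight and bias_row_N (=8), load the floating-point model from step1, and perform QAT retraining. This step can be skipped if the quantization loss is small.**\n\nAfter training, evaluate the loss and the model evaluation metric on the test set under each of the three conditions below. The metric depends on the task, for example recognition rate or PSNR.\n\n• With use_quantization = False, denote the loss as Lq1 and the metric as Pq1.\n\n• With use_quantization = True, scale_x, scale_y and scale_weight specified, and bias_row_N (=8) specified, denote the loss as Lq2 and the metric as Pq2.\n\n• With use_quantization = True, use_noise = True, scale_x, scale_y and scale_weight specified, and bias_row_N (=8) specified, denote the loss as Lq3 and the metric as Pq3. After QAT, the goal is Lf1 ≈ Lq2 \u003c Lq3, with Pf1 ≈ Pq2 and both better than Pq3. Actual results depend on the specific problem.\n\n**step3: Set use_quantization = True and use_noise = True, specify scale_x, scale_y, scale_weight and bias_row_N (=8), and perform NAT retraining.**\n\nAfter training, evaluate the loss and the model evaluation metric on the test set under each of the three conditions below. The metric depends on the task, for example recognition rate or PSNR.\n\n• With use_quantization = False, denote the loss as Ln1 and the metric as Pn1.\n\n• With use_quantization = True, scale_x, scale_y and scale_weight specified, and bias_row_N (=8) specified, denote the loss as Ln2 and the metric as Pn2.\n\n• With use_quantization = True, use_noise = True, scale_x, scale_y and scale_weight specified, and bias_row_N (=8) specified, denote the loss as Ln3 and the metric as Pn3. After NAT, the goal is Lf1 ≈ Lq2 ≈ Ln2 ≈ Ln3 and Pf1 ≈ Pq2 ≈ Pn2 ≈ Pn3. Actual results depend on the specific problem.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fwitmemtech%2Fwitin-nn-tool-","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fwitmemtech%2Fwitin-nn-tool-","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fwitmemtech%2Fwitin-nn-tool-/lists"}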