{"id":28690877,"url":"https://github.com/weltxing/pydt","last_synced_at":"2025-07-09T17:31:21.062Z","repository":{"id":112060269,"uuid":"405339586","full_name":"WeltXing/PyDT","owner":"WeltXing","description":"Implementation and visualization of decision tree classification and regression models","archived":false,"fork":false,"pushed_at":"2021-10-13T02:55:46.000Z","size":644,"stargazers_count":16,"open_issues_count":0,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-06-04T19:59:53.366Z","etag":null,"topics":["decision-trees","machine-learning","prune","visualization"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/WeltXing.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2021-09-11T09:40:11.000Z","updated_at":"2024-10-28T14:41:04.000Z","dependencies_parsed_at":null,"dependency_job_id":"83acccb4-d307-4749-bdd1-3a183ffbb85a","html_url":"https://github.com/WeltXing/PyDT","commit_stats":null,"previous_names":["weltxing/pydt","kaslanarian/pydt"],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/WeltXing/PyDT","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/WeltXing%2FPyDT","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/WeltXing%2FPyDT/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/WeltXing%2FPyDT/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/WeltXing%2FPyDT/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/WeltXing","download
_url":"https://codeload.github.com/WeltXing/PyDT/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/WeltXing%2FPyDT/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":259768621,"owners_count":22908232,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["decision-trees","machine-learning","prune","visualization"],"created_at":"2025-06-14T06:07:57.560Z","updated_at":"2025-06-14T06:07:58.176Z","avatar_url":"https://github.com/WeltXing.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# PyDecisionTree\n\nDecision tree classification and regression models, with visualization\n\n- [x] ID3 classification tree\n- [x] C4.5 classification tree\n- [x] CART classification tree\n- [x] CART regression tree\n- [x] Decision tree visualization\n- [x] REP pruning\n- [x] PEP pruning\n- [x] CCP pruning\n\n## ID3\n\nID3 is the most basic decision tree classifier:\n\n- no pruning in the classic algorithm\n- discrete attributes only\n- splits chosen by the information gain criterion\n\n`data.py` contains a small watermelon dataset for binary classification with discrete attributes. An ID3 classifier can be trained on it as follows:\n\n```python\nfrom ID3 import ID3Classifier\nfrom data import load_watermelon2\nimport numpy as np\n\nX, y = load_watermelon2(return_X_y=True) # arguments mimic sklearn.datasets loaders\nmodel = ID3Classifier()\nmodel.fit(X, y)\npred = model.predict(X)\nprint(np.mean(pred == y))\n```\n\nThis prints 1.0, i.e. the tree classifies every training sample correctly.\n\n## C4.5\n\nThe C4.5 classifier improves on ID3 in several ways:\n\n- it selects split features with the information gain ratio heuristic;\n- it handles both discrete and continuous attributes, discretizing the continuous ones;\n- it prunes the tree;\n- it can train on data with missing attribute values.\n\nThis implementation covers the first two points, plus the pre-pruning part of the third (via hyperparameters); PEP post-pruning is described in the Pruning section.\n\n`data.py` also contains a watermelon dataset with mixed continuous and discrete features, which we use to test the C4.5 tree:\n\n```python\nfrom C4_5 import C4_5Classifier\nfrom data import load_watermelon3\nimport numpy as np\n\nX, y = load_watermelon3(return_X_y=True) # arguments mimic sklearn.datasets loaders\nmodel = C4_5Classifier()\nmodel.fit(X, 
y)\npred = model.predict(X)\nprint(np.mean(pred == y))\n```\n\nThis also prints 1.0: the tree fits the training set perfectly.\n\n## CART\n\n### Classification\n\nCART (Classification and Regression Tree) supports both classification and regression. The CART classification tree selects features by the Gini index, and it always splits each node in two, including on discrete features, which can help reduce overfitting.\n\nHere we test it on the Iris dataset from `sklearn`:\n\n```python\nfrom CART import CARTClassifier\nfrom sklearn.datasets import load_iris\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import accuracy_score\n\nX, y = load_iris(return_X_y=True)\ntrain_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)\nmodel = CARTClassifier()\nmodel.fit(train_X, train_y)\npred = model.predict(test_X)\nprint(accuracy_score(test_y, pred))\n```\n\nThe test accuracy is 95.55%. Since no `random_state` is passed to `train_test_split`, the split, and hence the score, varies between runs.\n\n### Regression\n\nThe `CARTRegressor` class implements regression trees. We use the Boston housing dataset from `sklearn` as an example. Note that `load_boston` was removed in scikit-learn 1.2, so this example requires an older scikit-learn version:\n\n```python\nfrom CART import CARTRegressor\nfrom sklearn.datasets import load_boston\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import mean_squared_error\n\nX, y = load_boston(return_X_y=True)\ntrain_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)\nmodel = CARTRegressor()\nmodel.fit(train_X, train_y)\npred = model.predict(test_X)\nprint(mean_squared_error(test_y, pred))\n```\n\nThis prints an MSE of 26.352171052631576. The sklearn regression tree baseline is 22.46; the two are comparable, which suggests the implementation behaves as expected.\n\n## Plotting Decision Trees\n\n### Classification Trees\n\nUsing the third-party `graphviz` Python package together with [Graphviz](https://graphviz.org/) (which must be installed separately), we can visualize a decision tree:\n\n```python\nfrom plot import tree_plot\nfrom CART import CARTClassifier\nfrom sklearn.datasets import load_iris\n\nX, y = load_iris(return_X_y=True)\nmodel = CARTClassifier()\nmodel.fit(X, y)\ntree_plot(model)\n```\n\nRunning this generates `tree.png` in the current folder:\n\n![iris_tree](src/iris_tree.png)\n\nThe tree is easier to read when feature names and class names are provided:\n\n```python\nfrom plot import tree_plot\nfrom CART import CARTClassifier\nfrom sklearn.datasets import load_iris\n\niris = load_iris()\nmodel = CARTClassifier()\nmodel.fit(iris.data, iris.target)\ntree_plot(model,\n          filename=\"tree2\",\n          feature_names=iris.feature_names,\n          
target_names=iris.target_names)\n```\n\n![iris_tree2](src/iris_tree2.png)\n\nPlotting the ID3 tree for watermelon dataset 2:\n\n```python\nfrom plot import tree_plot\nfrom ID3 import ID3Classifier\nfrom data import load_watermelon2\n\nwatermelon = load_watermelon2()\nmodel = ID3Classifier()\nmodel.fit(watermelon.data, watermelon.target)\ntree_plot(\n    model,\n    filename=\"tree\",\n    font=\"SimHei\",\n    feature_names=watermelon.feature_names,\n    target_names=watermelon.target_names,\n)\n```\n\nA custom font (here `SimHei`) must be specified; otherwise the Chinese feature and class names cannot be rendered:\n\n![watermelon](src/watermelon_tree.png)\n\n### Regression Trees\n\nRegression trees are plotted the same way:\n\n```python\nfrom plot import tree_plot\nfrom CART import CARTRegressor\nfrom sklearn.datasets import load_boston\n\nboston = load_boston()\nmodel = CARTRegressor(max_depth=5)\nmodel.fit(boston.data, boston.target)\ntree_plot(\n    model,\n    feature_names=boston.feature_names,\n)\n```\n\nBecause the full regression tree is very large, its maximum depth is limited to 5 before plotting:\n\n![regression](src/boston_tree.png)\n\n## Hyperparameter Tuning\n\nCART and C4.5 both have hyperparameters. Because these classes derive from `sklearn.base.BaseEstimator`, they can be tuned with `sklearn`'s `GridSearchCV`:\n\n```python\nfrom plot import tree_plot\nfrom CART import CARTClassifier\nfrom sklearn.datasets import load_wine\nfrom sklearn.model_selection import train_test_split, GridSearchCV\n\nwine = load_wine()\ntrain_X, test_X, train_y, test_y = train_test_split(\n    wine.data,\n    wine.target,\n    train_size=0.7,\n)\nmodel = CARTClassifier()\ngrid_param = {\n    'max_depth': [2, 4, 6, 8, 10],\n    'min_samples_leaf': [1, 3, 5, 7],\n}\n\nsearch = GridSearchCV(model, grid_param, n_jobs=4, verbose=5)\nsearch.fit(train_X, train_y)\nbest_model = search.best_estimator_\nprint(search.best_params_, best_model.score(test_X, test_y))\ntree_plot(\n    best_model,\n    feature_names=wine.feature_names,\n    target_names=wine.target_names,\n)\n```\n\nThis prints the best parameters and the best model's accuracy on the test set:\n\n```text\n{'max_depth': 4, 'min_samples_leaf': 3} 0.8518518518518519\n```\n\nThe corresponding decision tree:\n\n![wine](src/wine_tree.png)\n\n## 
Pruning\n\nID3 and the CART regression tree support REP (reduced error pruning), C4.5 supports PEP (pessimistic error pruning), and the CART classification tree uses CCP (cost-complexity pruning).\n\nPruning reference: \u003chttps://welts.xyz/2021/09/27/prune/\u003e\n\n### PEP Pruning\n\nPEP pruning of a decision tree trained on the Iris dataset:\n\n```python\nfrom C4_5 import C4_5Classifier\nfrom plot import tree_plot\nfrom sklearn.datasets import load_iris\nfrom sklearn.model_selection import train_test_split\n\niris = load_iris()\nmodel = C4_5Classifier()\nX, y = iris.data, iris.target\ntrain_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)\nmodel.fit(train_X, train_y)\nprint(model.score(test_X, test_y))  # accuracy before pruning\ntree_plot(model,\n          filename=\"src/pre_prune\",\n          feature_names=iris.feature_names,\n          target_names=iris.target_names)\nmodel.pep_pruning()\nprint(model.score(test_X, test_y))  # accuracy after pruning\ntree_plot(model,\n          filename=\"src/post_prune\",\n          feature_names=iris.feature_names,\n          target_names=iris.target_names)\n```\n\nTest accuracy is 97.78% before pruning and 100% after, i.e. pruning improved generalization on this particular split:\n\n\u003cimg src=\"src/pre_prune.png\" alt=\"pre\" style=\"zoom:60%;\" /\u003e![post](src/post_prune.png)\n\n### CCP Pruning\n\nCCP pruning of a decision tree trained on the Iris dataset, before and after pruning:\n\n\u003cimg src=\"src/pre_ccp.png\" alt=\"pre\" style=\"zoom:40%;\" /\u003e\u003cimg src=\"src/post_ccp.png\" alt=\"post\" style=\"zoom:50%;\" /\u003e\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fweltxing%2Fpydt","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fweltxing%2Fpydt","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fweltxing%2Fpydt/lists"}
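
The README's ID3 section states that splits are chosen by information gain. Below is a minimal, self-contained sketch of that criterion for a single discrete feature; the helper names `entropy` and `information_gain` are hypothetical and are not taken from the repository's `ID3` module.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels."""
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def information_gain(feature, labels):
    """Entropy reduction from a multiway split of `labels` on the values in `feature`."""
    n = len(labels)
    groups = {}
    for value, label in zip(feature, labels):
        groups.setdefault(value, []).append(label)
    # Weighted entropy of the child nodes, one child per distinct feature value.
    children = sum(len(g) / n * entropy(g) for g in groups.values())
    return entropy(labels) - children


# Toy data (not from data.py): a feature that perfectly separates the classes.
color = ["green", "green", "black", "black"]
ripe = [1, 1, 0, 0]
print(information_gain(color, ripe))  # 1.0
```

ID3 evaluates this gain for every remaining attribute at a node and splits on the attribute with the largest value.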
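
The README's C4.5 section mentions the gain ratio and the discretization of continuous attributes. The sketch below shows both, assuming the common textbook formulation: gain divided by the entropy of the feature's own value distribution, and binary thresholds placed at midpoints between sorted distinct values. `gain_ratio` and `best_threshold` are illustrative names, not the repository's API, and `entropy` is redefined so the block runs on its own.

```python
from collections import Counter
from math import log2


def entropy(labels):
    n = len(labels)
    return -sum((c / n) * log2(c / n) for c in Counter(labels).values())


def gain_ratio(feature, labels):
    """C4.5 gain ratio for a discrete feature: information gain / split information."""
    n = len(labels)
    groups = {}
    for value, label in zip(feature, labels):
        groups.setdefault(value, []).append(label)
    gain = entropy(labels) - sum(len(g) / n * entropy(g) for g in groups.values())
    split_info = entropy(feature)  # entropy of the feature's own values
    return gain / split_info if split_info > 0 else 0.0


def best_threshold(values, labels):
    """Binary split `x <= t` of a continuous feature with the highest information gain."""
    n = len(labels)
    base = entropy(labels)
    distinct = sorted(set(values))
    best_t, best_gain = None, -1.0
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2  # candidate threshold: midpoint of adjacent values
        left = [y for x, y in zip(values, labels) if x <= t]
        right = [y for x, y in zip(values, labels) if x > t]
        gain = base - len(left) / n * entropy(left) - len(right) / n * entropy(right)
        if gain > best_gain:
            best_t, best_gain = t, gain
    return best_t, best_gain


# Toy values only; they are not read from data.py.
density = [0.697, 0.774, 0.634, 0.243, 0.245, 0.343]
ripe = [1, 1, 1, 0, 0, 0]
print(best_threshold(density, ripe))  # threshold between 0.343 and 0.634, gain 1.0
print(gain_ratio(["a", "a", "b", "b"], [1, 1, 0, 0]))  # 1.0
print(gain_ratio(["a", "b", "c", "d"], [1, 1, 0, 0]))  # 0.5
```

The last two lines show why the ratio is used: a feature with a distinct value per sample has the same gain (1.0) but twice the split information, so its gain ratio is halved.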
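
The README's CART sections describe binary splits chosen by the Gini index for classification, plus a regression variant. A sketch of the split search, assuming numeric features and the usual impurity measures: the Gini index for classification, and the mean squared deviation from the node mean for regression (weighting it by node size ranks splits exactly as the sum of squared errors does). The function names are illustrative only, not the repository's code.

```python
from collections import Counter


def gini(labels):
    """Gini index 1 - sum_k p_k^2, the CART classification impurity."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def mse(targets):
    """Mean squared deviation from the node mean, the CART regression impurity."""
    mean = sum(targets) / len(targets)
    return sum((t - mean) ** 2 for t in targets) / len(targets)


def best_binary_split(values, targets, impurity):
    """Return (threshold, weighted impurity) of the best split `x <= threshold`."""
    n = len(targets)
    distinct = sorted(set(values))
    best = (None, float("inf"))
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(values, targets) if x <= t]
        right = [y for x, y in zip(values, targets) if x > t]
        # Children impurities weighted by the fraction of samples they receive.
        score = len(left) / n * impurity(left) + len(right) / n * impurity(right)
        if score < best[1]:
            best = (t, score)
    return best


x = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
print(best_binary_split(x, [0, 0, 0, 1, 1, 1], gini))  # (6.5, 0.0)
print(best_binary_split(x, [1.0, 1.0, 1.0, 5.0, 5.0, 5.0], mse))  # (6.5, 0.0)
```

The same search serves both tree types; only the impurity function changes.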
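
The README's Pruning section says ID3 and the CART regression tree use REP. The sketch below illustrates reduced error pruning in general, not the repository's code. It uses a hypothetical nested-dict tree format: every node stores its majority training `label`, and internal nodes also store `feature`, `threshold`, `left`, and `right`. The tree is pruned bottom-up, and a subtree collapses into a leaf whenever that does not increase the error on a held-out validation set. This version counts misclassifications; a regression variant would compare squared errors instead.

```python
def predict(node, x):
    """Follow thresholds down to a leaf and return its label."""
    while "feature" in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["label"]


def errors(node, X, y):
    """Number of misclassified samples in (X, y)."""
    return sum(predict(node, xi) != yi for xi, yi in zip(X, y))


def rep_prune(node, X, y):
    """Reduced Error Pruning against a validation set (X, y); returns a new tree."""
    if "feature" not in node:
        return node
    f, t = node["feature"], node["threshold"]
    left = [(xi, yi) for xi, yi in zip(X, y) if xi[f] <= t]
    right = [(xi, yi) for xi, yi in zip(X, y) if xi[f] > t]
    pruned = dict(node)  # shallow copy so the input tree is not modified
    # Prune the children first (bottom-up), each on the samples routed to it.
    pruned["left"] = rep_prune(node["left"], [a for a, _ in left], [b for _, b in left])
    pruned["right"] = rep_prune(node["right"], [a for a, _ in right], [b for _, b in right])
    # Collapse to a leaf if that does not increase validation error.
    # A node that receives no validation samples is therefore always collapsed.
    leaf_err = sum(yi != node["label"] for yi in y)
    if leaf_err <= errors(pruned, X, y):
        return {"label": node["label"]}
    return pruned


tree = {
    "feature": 0, "threshold": 5, "label": 0,
    "left": {"feature": 0, "threshold": 2, "label": 0,
             "left": {"label": 0}, "right": {"label": 1}},
    "right": {"label": 1},
}
X_val, y_val = [[1], [3], [7], [8]], [0, 0, 1, 1]
print(rep_prune(tree, X_val, y_val))
# {'feature': 0, 'threshold': 5, 'label': 0, 'left': {'label': 0}, 'right': {'label': 1}}
```

Here the inner split misclassifies the validation sample `[3]`, so it is replaced by a leaf, while the root split is kept because collapsing it would raise the validation error from 0 to 2.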
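
For the README's PEP section, the pruning test for a single node can be sketched with the pessimistic error estimate as it is commonly stated for Quinlan's method: add 0.5 to the error count of every leaf, and collapse the node when its corrected leaf error is within one standard error of the subtree's corrected error. The function `pep_should_prune` and its arguments are hypothetical; the repository's `pep_pruning()` may differ in details.

```python
from math import sqrt


def pep_should_prune(leaf_errors, subtree_leaf_errors, n_samples):
    """Pessimistic Error Pruning test for one internal node.

    leaf_errors: training errors if the node were collapsed into a leaf.
    subtree_leaf_errors: training error count of each leaf under the node.
    n_samples: number of training samples that reach the node.
    """
    # Continuity correction of 0.5 per leaf.
    subtree_err = sum(subtree_leaf_errors) + 0.5 * len(subtree_leaf_errors)
    leaf_err = leaf_errors + 0.5
    # Standard error of the corrected subtree error (binomial approximation).
    std_err = sqrt(subtree_err * (n_samples - subtree_err) / n_samples)
    return leaf_err <= subtree_err + std_err


# A node reached by 20 samples whose 4-leaf subtree misclassifies 2 of them.
print(pep_should_prune(5, [0, 1, 0, 1], 20))  # True: 5.5 <= 4.0 + 1.79
print(pep_should_prune(6, [0, 1, 0, 1], 20))  # False: 6.5 > 4.0 + 1.79
```

Unlike REP, this test uses only the training data, which is why PEP needs no separate validation set.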
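
For the README's CCP section, cost-complexity pruning repeatedly removes the weakest link: the internal node t minimizing g(t) = (R(t) - R(T_t)) / (|leaves(T_t)| - 1), where R(t) is the error if t were a leaf and R(T_t) is the total leaf error of its subtree. The sketch below finds that node in a hypothetical nested-dict tree where each node stores `error` (misclassified training samples if it were a leaf) and optional `children`. It uses counts instead of error rates, which scales every g(t) by the same constant N and leaves the ranking unchanged. It is not the repository's implementation.

```python
def subtree_stats(node):
    """Return (total leaf error, number of leaves) of the subtree rooted at `node`."""
    children = node.get("children", [])
    if not children:
        return node["error"], 1
    total_err, total_leaves = 0, 0
    for child in children:
        err, leaves = subtree_stats(child)
        total_err += err
        total_leaves += leaves
    return total_err, total_leaves


def weakest_link(node, path="root"):
    """Return (g(t), path) for the internal node with the smallest g(t), or None for a leaf."""
    if not node.get("children"):
        return None
    subtree_err, leaves = subtree_stats(node)
    best = ((node["error"] - subtree_err) / (leaves - 1), path)
    for i, child in enumerate(node["children"]):
        candidate = weakest_link(child, f"{path}.{i}")
        if candidate is not None and candidate[0] < best[0]:
            best = candidate
    return best


tree = {
    "error": 50,
    "children": [
        {"error": 10, "children": [{"error": 5}, {"error": 0}]},
        {"error": 0},
    ],
}
print(weakest_link(tree))  # (5.0, 'root.0')
```

Collapsing the weakest link and repeating the search yields the nested sequence of subtrees from which CCP selects one, typically by validation or cross-validation.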