https://github.com/kelvintechnical/decision-tree-classifier
- Host: GitHub
- URL: https://github.com/kelvintechnical/decision-tree-classifier
- Owner: kelvintechnical
- Created: 2024-11-18T21:01:50.000Z (about 2 months ago)
- Default Branch: main
- Last Pushed: 2024-11-18T22:32:10.000Z (about 2 months ago)
- Last Synced: 2024-11-18T23:31:23.659Z (about 2 months ago)
- Language: Python
- Size: 57.6 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
# Decision Tree Classifier
Welcome to the Decision Tree Classifier project! This repository contains code to implement a basic machine learning classifier using the Decision Tree algorithm. A decision tree classifies data points by splitting the data into branches based on feature values and ultimately assigning a label to each point.
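
To make the idea of splitting on feature values concrete, here is a minimal sketch of how a candidate split can be scored with Gini impurity, the default splitting criterion of scikit-learn's `DecisionTreeClassifier`. The helper names (`gini`, `split_impurity`) and the toy data are illustrative assumptions, not code from this repository.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels (0.0 means perfectly pure)."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def split_impurity(values, labels, threshold):
    """Weighted Gini impurity after splitting on `value <= threshold`."""
    left = [y for x, y in zip(values, labels) if x <= threshold]
    right = [y for x, y in zip(values, labels) if x > threshold]
    n = len(labels)
    return (len(left) / n) * gini(left) + (len(right) / n) * gini(right)


# Hypothetical toy data: one feature and a binary label
feature = [1, 2, 3, 4, 5, 6]
label = [0, 0, 0, 1, 1, 1]

print(gini(label))                        # impurity before any split
print(split_impurity(feature, label, 3))  # impurity after splitting at 3
print(split_impurity(feature, label, 1))  # a worse split leaves more impurity
```

A tree-building algorithm of this kind tries many thresholds and keeps the one with the lowest weighted impurity, then repeats the process inside each branch.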
---
## 📫 How to reach me:
- Email: [email protected]
- LinkedIn: Kelvin R. Tobias
- Bluesky: @kelvintechnical.bsky.social
- Instagram: @kelvinintech
---
## Project Overview
The Decision Tree Classifier is a simple and effective classification algorithm. In this project, we use Python libraries like `pandas`, `scikit-learn`, and `matplotlib` to build and visualize the decision tree. The goal is to classify data points based on specific features and evaluate the model's accuracy.
---
## 5 Things I Learned from This Project
- The Purpose of Imports: Each Python library has specific functionalities that simplify machine learning workflows.
- Data Splitting: Dividing data into training and testing sets makes it possible to check whether the model generalizes to unseen data.
- Model Visualization: Plotting the decision tree helped me understand how the algorithm splits data at each node.
- Parameter Tuning: Adjusting parameters like `max_depth` can simplify the model and reduce overfitting.
- Evaluation Metrics: Accuracy alone is not always enough; additional tools such as a confusion matrix can give a fuller picture of how the model performs.
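
The last two lessons can be sketched together. The example below is a rough illustration, assuming scikit-learn's bundled iris dataset as a stand-in for real project data; it fits trees with different `max_depth` values and prints a confusion matrix for the unrestricted tree. The dataset choice, split size, and variable names are assumptions made for this sketch only.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumption: the bundled iris dataset stands in for the project's own data
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# Fit one tree per max_depth setting and record how deep each tree actually grew
depths = {}
for max_depth in (1, 3, None):
    model = DecisionTreeClassifier(max_depth=max_depth, random_state=42)
    model.fit(X_train, y_train)
    depths[max_depth] = model.get_depth()
    accuracy = accuracy_score(y_test, model.predict(X_test))
    print(f"max_depth={max_depth}: depth={depths[max_depth]}, accuracy={accuracy:.2f}")

# Confusion matrix for the last model fitted (max_depth=None):
# rows are true classes, columns are predicted classes
cm = confusion_matrix(y_test, model.predict(X_test))
print(cm)
```

Off-diagonal entries in the matrix show which classes get confused with each other, which a single accuracy number hides.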
---
## Code Explanation
Below is an overview of the key components of the code:
- `from sklearn.tree import DecisionTreeClassifier, plot_tree`: Imports the Decision Tree Classifier for building the model and the plotting function for visualizing the tree.
- `from sklearn.model_selection import train_test_split`: Imports the function that splits the dataset into training and testing sets.
- `from sklearn.metrics import accuracy_score`: Imports the function that calculates the model's accuracy by comparing predictions to actual labels.
- `import pandas as pd`: Used for handling the dataset as a DataFrame for easier data manipulation.
- `import matplotlib.pyplot as plt`: Used to render the decision tree as a readable plot.
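
As a side note on `train_test_split` from the list above, the following minimal sketch shows what the function returns. It assumes a hypothetical 10-row dataset (not the project's data) and uses the optional `stratify` argument, which keeps class proportions the same in both subsets.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical 10-row dataset with balanced labels
df = pd.DataFrame({
    "Feature1": range(10),
    "Feature2": range(10, 0, -1),
    "Label": [0, 1] * 5,
})
X = df[["Feature1", "Feature2"]]
y = df["Label"]

# stratify=y keeps the class proportions the same in both subsets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(X_train.shape, X_test.shape)  # (8, 2) (2, 2)
print(sorted(y_test.tolist()))      # one sample of each class
```

The full project code follows.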
```python
# Importing necessary libraries
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import matplotlib.pyplot as plt

# Create a simple dataset
data = {
    'Feature1': [1, 2, 3, 4, 5],
    'Feature2': [5, 4, 3, 2, 1],
    'Label': [0, 1, 0, 1, 0]
}

# Convert the dataset into a pandas DataFrame
df = pd.DataFrame(data)

# Separate the features (X) and labels (y)
X = df[['Feature1', 'Feature2']]  # Features
y = df['Label']                   # Labels

# Split the data into training and testing sets.
# With only 5 rows, test_size=0.2 leaves a single test sample,
# so the reported accuracy can only be 0% or 100%.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create the Decision Tree Classifier
classifier = DecisionTreeClassifier(max_depth=3)

# Train the model using the training data
classifier.fit(X_train, y_train)

# Make predictions on the test set
y_pred = classifier.predict(X_test)

# Evaluate the model's accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f'Model Accuracy: {accuracy * 100:.2f}%')

# Visualize the Decision Tree
plt.figure(figsize=(12, 8))  # Adjust the plot size for readability
plot_tree(
    classifier,
    feature_names=['Feature1', 'Feature2'],
    class_names=['Class 0', 'Class 1'],
    filled=True,
    fontsize=10,
    rounded=True
)
plt.show()
```