Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/5brian/aep-model
Deep learning model using fine-tuned DistilBERT for classifying safety observations
- Host: GitHub
- URL: https://github.com/5brian/aep-model
- Owner: 5brian
- Created: 2024-10-19T17:47:53.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2024-11-30T05:05:30.000Z (about 2 months ago)
- Last Synced: 2024-12-06T04:21:02.714Z (about 2 months ago)
- Language: Python
- Homepage:
- Size: 6.15 MB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
# Safety Classification Project
This project uses machine learning to classify safety comments as either high priority or standard priority. It fine-tunes [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased), a transformer-based neural network, for the task of classifying safety observations.
## Project Structure
- `safety_classification.py`: Script for training the model, evaluating its performance, and generating results for the entire dataset.
- `classify_new_inputs.py`: Interactive script for classifying new safety comments using the trained model.
- `requirements.txt`: List of Python packages required for the project.
## Usage: Training on Your Own Data
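Training expects a CSV with one column of safety comments and one column of priority labels. Before training, it can help to confirm the file has that shape. The following is a minimal sketch of such a check, not part of the project: the helper `check_csv`, the column names `comment` and `priority`, and the sample label value are all hypothetical, so substitute the names that `safety_classification.py` actually uses.

```python
import csv
import io

# Hypothetical column names; use the ones safety_classification.py expects.
TEXT_COLUMN = "comment"
LABEL_COLUMN = "priority"


def check_csv(file_obj, text_column=TEXT_COLUMN, label_column=LABEL_COLUMN):
    """Return the number of data rows, or raise ValueError if the CSV is unusable."""
    reader = csv.DictReader(file_obj)
    header = reader.fieldnames or []
    missing = [name for name in (text_column, label_column) if name not in header]
    if missing:
        raise ValueError(f"missing columns: {missing}")

    rows = 0
    for row in reader:
        rows += 1
        # Flag rows where the comment or the label is blank.
        if not (row[text_column] or "").strip():
            raise ValueError(f"empty comment in data row {rows}")
        if not (row[label_column] or "").strip():
            raise ValueError(f"empty label in data row {rows}")
    return rows


# In-memory sample so the sketch runs without a file on disk.
sample = io.StringIO("comment,priority\nLadder missing a rung,High Priority\n")
print(check_csv(sample))  # 1
```

To check a real file, open it with `open("data/data.csv", newline="")` and pass the file object to `check_csv`.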
### Training the Model and Generating Results
1. Prepare your data:
   - Create a CSV file with at least two columns: one for the safety comments and one for the priority labels.
   - Name your columns exactly as they are in the original dataset or update the column names in the `safety_classification.py` script.
   - Move your CSV file to the `data/` directory.
   - Rename it to `data.csv` or update the file path in `safety_classification.py`.
2. Run the training and classification script:
   ```
   python3 safety_classification.py
   ```
3. This script will:
   - Train the model on the included dataset
   - Evaluate its performance
   - Save the model and tokenizer in the `saved_model` and `saved_tokenizer` directories respectively
   - Classify all comments in the original dataset
   - Save the results (including classifications) to `results/output_data.csv`
### Interactively Classifying New Inputs
1. After training the model, you can use it to classify new safety comments interactively:
   ```
   python3 classify_new_inputs.py
   ```
2. This script will:
   - Load the trained model and tokenizer
   - Prompt you to enter safety comments
   - Display the classification (High Priority or Standard Priority) and confidence score for each entered comment
3. Enter your safety comments one at a time and press `Enter` to see the classification result.
4. Type `quit` to exit the program.
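The interactive flow above can be sketched in a few lines. This is an illustration under stated assumptions, not the code in `classify_new_inputs.py`: it assumes the confidence score is the softmax probability of the predicted class, that label index 1 means High Priority, and that `quit` is matched case-insensitively. The function `stand_in_model` is a placeholder that returns made-up logits in place of the fine-tuned model.

```python
import math
from typing import Callable, Iterable, Iterator, List, Sequence, Tuple

# Assumed label order: index 0 = standard priority, index 1 = high priority.
LABELS = ("Standard Priority", "High Priority")


def softmax(logits: Sequence[float]) -> List[float]:
    """Convert raw scores to probabilities (numerically stable form)."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_stream(
    lines: Iterable[str],
    predict_logits: Callable[[str], Sequence[float]],
) -> Iterator[Tuple[str, str, float]]:
    """Yield (comment, label, confidence) for each comment until 'quit'."""
    for line in lines:
        comment = line.strip()
        if comment.lower() == "quit":
            break
        if not comment:
            continue  # skip blank input
        probs = softmax(predict_logits(comment))
        best = max(range(len(probs)), key=probs.__getitem__)
        yield comment, LABELS[best], probs[best]


def stand_in_model(comment: str) -> List[float]:
    """Placeholder for the fine-tuned model; returns made-up logits."""
    return [0.0, 2.0] if "fall" in comment.lower() else [2.0, 0.0]


entered = ["Unsecured ladder, fall risk", "quit", "never reached"]
for comment, label, confidence in classify_stream(entered, stand_in_model):
    print(f"{label} (confidence {confidence:.2f})")  # High Priority (confidence 0.88)
```

Because `classify_stream` accepts any iterable of strings, `sys.stdin` can be passed as `lines` to read comments from the terminal.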
## Our Training Data Results
```
Evaluation Results:
Accuracy: 0.9910
F1 Score: 0.9833
Precision: 0.9925
Recall: 0.9742
```
## Notes
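- The F1 score is the harmonic mean of precision and recall, so the reported figures can be checked against each other. A short sketch, using a hypothetical helper `f1_from` that is not part of the project, recomputes F1 from the reported precision and recall:
```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Values copied from the evaluation results (already rounded to 4 decimals).
reported_precision = 0.9925
reported_recall = 0.9742

f1 = f1_from(reported_precision, reported_recall)
print(f"{f1:.4f}")  # 0.9833, which agrees with the reported F1 score
```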
- The scripts are set to use the CPU for computations. If you have a GPU and want to use it, set `USE_CPU` to `False` in both scripts.
```python
# Configurable variables
USE_CPU = False
```
- This project uses the DistilBERT base model (uncased) for sequence classification. DistilBERT is a smaller, faster version of BERT developed by Hugging Face. It is a Transformer model pretrained on the same corpus as BERT in a self-supervised fashion, using the BERT base model as a teacher. It is designed for tasks that use the whole sentence to make decisions, such as sequence classification, token classification, or question answering.
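- A minimal sketch of how a `USE_CPU` flag could map to a compute device, assuming the scripts run on PyTorch. The helper `pick_device` is hypothetical and is not taken from the project's scripts; it falls back to the CPU when PyTorch or a CUDA GPU is unavailable.
```python
USE_CPU = True  # same flag name as in the project's scripts


def pick_device(use_cpu: bool) -> str:
    """Return "cpu" or "cuda" (hypothetical helper, not from the project)."""
    if use_cpu:
        return "cpu"
    try:
        import torch  # assumed backend; not confirmed by the README
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device(USE_CPU)
print(device)  # cpu, because USE_CPU is True here
```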
## Troubleshooting
If you encounter any issues:
- Ensure all dependencies are correctly installed.
- Make sure you've run `safety_classification.py` before trying to use `classify_new_inputs.py`.
- Training the model takes a while and is demanding on your machine's RAM. If this is an issue, decrease the `MAX_STEPS` variable or use another model.
## Contributors
- Reid Ammer
- Brian Tan