https://github.com/esceptico/squeezer
Lightweight knowledge distillation pipeline
- Host: GitHub
- URL: https://github.com/esceptico/squeezer
- Owner: esceptico
- License: mit
- Created: 2021-10-17T14:23:51.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2021-11-29T08:30:34.000Z (over 4 years ago)
- Last Synced: 2024-10-19T17:30:36.203Z (over 1 year ago)
- Topics: distillation, knowledge-distillation, model-compression, pytorch
- Language: Jupyter Notebook
- Homepage:
- Size: 116 KB
- Stars: 28
- Watchers: 4
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Squeezer (WIP)
## Usage
### Step 1: Define a Distiller subclass
Implement the `teacher_forward` and `student_forward` methods,
and override `move_batch_to_device` if your batches need custom device handling.
```python
from squeezer import Distiller


class CustomDistiller(Distiller):
    def teacher_forward(self, batch):
        return self.teacher(batch['data'])

    def student_forward(self, batch):
        return self.student(batch['data'])
```
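If your batches are dicts of tensors mixed with non-tensor values, a device-moving helper might look like the sketch below. This is an illustration of what such an override could do, not the library's actual default implementation:

```python
import torch


def move_batch_to_device(batch, device):
    # Move every tensor in a dict-shaped batch to the target device;
    # non-tensor values (e.g. lists of raw strings) pass through unchanged.
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in batch.items()
    }
```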
### Step 2: Define a loss policy
```python
from torch.nn.functional import mse_loss

from squeezer import AbstractDistillationPolicy


class DistillationPolicy(AbstractDistillationPolicy):
    def forward(self, teacher_output, student_output, batch, epoch):
        loss_mse = mse_loss(student_output, teacher_output)
        loss_dict = {'mse': loss_mse.item()}
        return loss_mse, loss_dict
```
### Step 3: Fit
```python
from torch import optim
from squeezer.logging import TensorboardLogger
train_loader = ...
teacher = Teacher()
student = Student()
logger = TensorboardLogger('runs', 'experiment')
optimizer = optim.AdamW(student.parameters(), lr=3e-4)
policy = DistillationPolicy()
distiller = CustomDistiller(teacher, student, policy, optimizer=optimizer, logger=logger)
distiller(train_loader, n_epochs=10)
distiller.save('path_to_some_directory')
```
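`Teacher`, `Student`, and `train_loader` in the snippet above are placeholders. A minimal toy setup consistent with the dict-style batches from Step 1 might look like this; the sizes and architectures are arbitrary illustrations:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class Teacher(nn.Module):
    """A larger network (pretrained, in practice)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))

    def forward(self, x):
        return self.net(x)


class Student(nn.Module):
    """A smaller network trained to mimic the teacher."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 10)

    def forward(self, x):
        return self.net(x)


class DictDataset(Dataset):
    """Yields dict-shaped samples so batches arrive as ``batch['data']``."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return {'data': self.data[idx]}


train_loader = DataLoader(DictDataset(torch.randn(256, 16)), batch_size=32)
```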