# PL Crossvalidate
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/SkafteNicki/pl_crossvalidate/blob/master/LICENSE)
[![Tests](https://github.com/SkafteNicki/pl_crossvalidate/actions/workflows/tests.yaml/badge.svg)](https://github.com/SkafteNicki/pl_crossvalidate/blob/master/.github/workflows/tests.yaml)
[![codecov](https://codecov.io/gh/SkafteNicki/pl_crossvalidate/branch/master/graph/badge.svg)](https://codecov.io/gh/SkafteNicki/pl_crossvalidate)

Cross validation in PyTorch Lightning made easy :]

Just import the specialized trainer from `pl_crossvalidate` instead of `pytorch_lightning` and you are set:
```python
# To distinguish from the original trainer the new trainer is called KFoldTrainer by default
from pl_crossvalidate import KFoldTrainer as Trainer

# Normal Lightning module
model = MyModel(...)

# Use a Lightning datamodule or training dataloader
datamodule = MyDatamodule(...)

# The new trainer takes all the original arguments plus three new ones for controlling the cross validation
trainer = Trainer(
    num_folds=5,  # number of folds to do
    shuffle=False,  # if samples should be shuffled before splitting
    stratified=False,  # if splitting should be done in a stratified manner
    accelerator=...,
    callbacks=...,
    # ... plus any other argument accepted by the standard Lightning Trainer
)

# Returns a dict of stats over the different splits
cross_val_stats = trainer.cross_validate(model, datamodule=datamodule)

# Additionally, we can construct an ensemble from the K trained models
ensemble_model = trainer.create_ensemble(model)
```

## 💻 Installation

```bash
pip install pl-crossvalidate
```

Or install the latest version from GitHub:

```bash
pip install https://github.com/SkafteNicki/pl_crossvalidate/archive/master.zip
```

Requires `torch>=2.0`, `lightning>=2.0` and `scikit-learn>=1.0`.

## 🤔 Cross-validation: why?

The core functionality of machine learning algorithms is that they are able to *learn* from data. It is therefore
natural to ask: how *well* do our algorithms actually learn? This is an abstract question, because
it requires us to define what *well* means. One interpretation of this question is an algorithm's ability to
*generalize*, i.e. a model that generalizes well has actually learned something meaningful.

The mathematical definition of the generalization error/expected loss/risk is given by

$$
E[f] = \int L(f(x), y) \, p(x, y) \, dx \, dy
$$

where $f: \mathcal{X} \to \mathcal{Y}$ is some function, $L$ denotes the loss function and $p(x, y)$ is
the joint probability distribution between $x$ and $y$. This is the theoretical error an algorithm will make
on some unobserved dataset. The problem with this definition is that we cannot compute it, due to
$p(x, y)$ being unknown, and even if we knew it
the integral is intractable. The best we therefore can do is an *approximation* of the generalization error:

$$
E_n[f] = \frac{1}{n} \sum_{i=1}^{n} L(f(x_i), y_i)
$$

which measures the error that our function $f$ makes on $n$ datapoints, measured by the loss
function $L$. This function we can
compute (just think of this as your normal loss function) and we even know that

$$
\lim_{n \to \infty} E_n[f] = E[f]
$$

Namely, the approximation of the generalization error will become the true generalization error if we just evaluate it
on enough data. But how does all this relate to cross-validation, you may ask? The problem with the above is that
$f$ is not a fixed function, but a
data-dependent function, i.e. $f = f_{\mathcal{D}_{train}}$.
Thus, the above approximation will only converge if $\mathcal{D}_{train}$
and $\mathcal{D}_{val}$ (the $n$ datapoints we evaluate on) refer to different sets of
data points. This is where cross-validation strategies come into play.
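
As a rough illustration of the convergence claim, the sketch below estimates the risk of a *fixed* predictor on a growing number of freshly sampled points. The toy data distribution, the predictor and all names are assumptions made for this example only; they are not part of `pl_crossvalidate`.

```python
import random


def squared_loss(prediction: float, target: float) -> float:
    return (prediction - target) ** 2


def fixed_predictor(x: float) -> float:
    # A deliberately imperfect but *fixed* function: it never sees any data.
    return x


def empirical_risk(n: int, rng: random.Random) -> float:
    """Average loss of the fixed predictor over n freshly sampled points."""
    total = 0.0
    for _ in range(n):
        x = rng.random()  # x ~ Uniform(0, 1)
        y = 2.0 * x       # toy ground truth (an assumption of this sketch)
        total += squared_loss(fixed_predictor(x), y)
    return total / n


rng = random.Random(0)
# For this toy setup the true risk is E[(x - 2x)^2] = E[x^2] = 1/3.
true_risk = 1.0 / 3.0
estimates = {n: empirical_risk(n, rng) for n in (10, 1_000, 100_000)}
for n, estimate in estimates.items():
    print(f"n={n:>7}: empirical risk = {estimate:.4f} (true risk = {true_risk:.4f})")
```

The estimate only behaves this way because the predictor does not depend on the points it is evaluated on; once the function is fitted to the same data, the average loss is no longer an unbiased estimate of the risk.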

*[Figures: hold-out validation (left) and K-fold cross-validation (right)]*

In general we consider two viable strategies for selecting the
$\mathcal{D}_{val}$ (validation) and
$\mathcal{D}_{train}$ (training) sets: hold-out validation
and K-fold cross validation. In hold-out we create a separate, independent set of data to evaluate our training on. This
is easily done in native pytorch-lightning by implementing the `validation_step` method. For K-fold we cut our data
into K equally large chunks and then iteratively train on K-1 folds and evaluate on the remaining fold, repeating
this K times. In general K-fold gives a better approximation of the generalization error than hold-out, but at the
expense of requiring you to train K models.
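
The K-fold procedure described above can be sketched in a few lines of plain Python. This is a minimal illustration without shuffling or stratification, and the helper name `k_fold_indices` is hypothetical; it is not how `pl_crossvalidate` implements its splits.

```python
def k_fold_indices(num_samples: int, num_folds: int):
    """Yield (train_indices, val_indices) pairs, one pair per fold."""
    if not 2 <= num_folds <= num_samples:
        raise ValueError("num_folds must be between 2 and num_samples")
    indices = list(range(num_samples))
    # Spread the remainder over the first folds so fold sizes differ by at most one.
    base, remainder = divmod(num_samples, num_folds)
    start = 0
    for fold in range(num_folds):
        stop = start + base + (1 if fold < remainder else 0)
        val = indices[start:stop]                  # the single held-out fold
        train = indices[:start] + indices[stop:]   # the remaining K-1 folds
        yield train, val
        start = stop


splits = list(k_fold_indices(num_samples=10, num_folds=5))
for fold, (train, val) in enumerate(splits):
    print(f"fold {fold}: train={train} val={val}")
```

Every sample ends up in the validation set exactly once, which is why K models have to be trained.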

## 🗒️ Some notes

* For the `.cross_validate` method to work, in addition to the standard set of methods in lightning that need
to be implemented (`training_step` and `configure_optimizers`), we also require the `test_step` method to be
implemented, as we use this method for evaluating the hold-out set. We do not rely on the `validation_step` method
as your model's training may depend on the validation set (for example if you use early stopping) and your
validation set will therefore not be truly separated from the training.

* To do the splitting in cross validation we need the total number of data points in your dataset. For this reason,
we require that your dataset implements the `__len__` method.

* Cross validation is always done sequentially, even if the device you are training on could in principle
fit parallel training on multiple folds at the same time. We will try to figure out in the future if we can
parallelize the process.

* Logging can be a bit weird. Logging of training progress is essentially not important to cross-validation,
but that does not mean that it is not interesting to track. The cross-validation method will hijack the `version`
attribute of any logger attached to the trainer and set the logging directory to `f"{version}/fold_{fold_index}"`.

* Stratified splitting assumes that we can extract a 1D label vector from your dataset.

  * If your dataset has a `labels` attribute, we will use that as the labels.
  * If the attribute does not exist, we manually iterate over your dataset trying to extract the labels when creating
    the splits (this is done as part of the `.setup` phase of the datamodule). By default we assume that given a `batch`
    the labels can be found as the second element, e.g. `batch[1]`. You can adjust this by importing the specialized
    `KFoldDataModule` and changing the `label_extractor` attribute. For example, if your batches are dictionaries
    instead, you can do something like this:

```python
from pl_crossvalidate import KFoldDataModule, KFoldTrainer

model = ...

trainer = KFoldTrainer(...)

datamodule = KFoldDataModule(
    num_folds, shuffle, stratified,  # these should match how the trainer is initialized
    train_dataloader=my_train_dataloader,
)
# change the label extractor function, such that it will return the labels for a given batch
datamodule.label_extractor = lambda batch: batch['y']

trainer.cross_validate(model, datamodule=datamodule)
```
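
To make the label lookup described in the stratified-splitting note more concrete, here is a small standalone sketch of that logic. The helper `collect_labels`, the toy datasets and the `DatasetWithLabels` class are hypothetical names introduced for illustration; the actual implementation inside `pl_crossvalidate` may differ.

```python
from typing import Any, Callable, List


def default_label_extractor(batch: Any) -> Any:
    # Assumes (input, label) pairs, i.e. the label is the second element.
    return batch[1]


def collect_labels(
    dataset: Any,
    label_extractor: Callable[[Any], Any] = default_label_extractor,
) -> List[Any]:
    """Return one label per sample (hypothetical helper, not the library's code)."""
    # Fast path: the dataset already exposes its labels.
    if hasattr(dataset, "labels"):
        return list(dataset.labels)
    # Slow path: visit every sample and pull the label out of it.
    return [label_extractor(dataset[i]) for i in range(len(dataset))]


class DatasetWithLabels:
    """Toy dataset exposing a `labels` attribute (assumption for illustration)."""

    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index], self.labels[index]


tuple_dataset = [([0.1, 0.2], 0), ([0.3, 0.4], 1), ([0.5, 0.6], 0)]
dict_dataset = [{"x": [0.1], "y": 1}, {"x": [0.2], "y": 0}]

tuple_labels = collect_labels(tuple_dataset)
dict_labels = collect_labels(dict_dataset, label_extractor=lambda sample: sample["y"])
attribute_labels = collect_labels(DatasetWithLabels([[0.1], [0.2]], [1, 1]))
print(tuple_labels, dict_labels, attribute_labels)
```

The slow path needs `len(dataset)` and indexing, which matches the requirement above that the dataset implements `__len__`.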

## 😃 Bibtex

If you want to cite the framework feel free to use this:

```bibtex
@article{software:pl_crossvalidate,
  title={PL Crossvalidate},
  author={Nicki S. Detlefsen},
  journal={GitHub. Note: https://github.com/SkafteNicki/pl_crossvalidate},
  year={2023}
}
```