https://github.com/csteinmetz1/auraloss
Collection of audio-focused loss functions in PyTorch
- Host: GitHub
- URL: https://github.com/csteinmetz1/auraloss
- Owner: csteinmetz1
- License: apache-2.0
- Created: 2020-10-26T01:44:24.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2024-05-22T17:29:41.000Z (6 months ago)
- Last Synced: 2024-07-18T20:34:44.354Z (4 months ago)
- Topics: audio, loss-functions, pytorch
- Language: Python
- Homepage:
- Size: 130 KB
- Stars: 693
- Watchers: 18
- Forks: 68
- Open Issues: 21
Metadata Files:
- Readme: README.md
- License: LICENSE
# auraloss
A collection of audio-focused loss functions in PyTorch.
[[PDF](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf)]
## Setup
```
pip install auraloss
```

If you want to use `MelSTFTLoss()` or `FIRFilter()`, you will need to install the optional extras (librosa and scipy).

```
pip install "auraloss[all]"
```

## Usage
```python
import torch
import auraloss

mrstft = auraloss.freq.MultiResolutionSTFTLoss()

# signals are shaped (batch, channels, samples)
pred = torch.rand(8, 1, 44100)
target = torch.rand(8, 1, 44100)

loss = mrstft(pred, target)
```

**NEW**: Perceptual weighting with mel-scaled spectrograms.
```python
import torch
import auraloss

bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100

# some audio you want to compare
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)

# define the loss function
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    scale="mel",
    n_bins=128,
    sample_rate=sample_rate,
    perceptual_weighting=True,
)

# compute
loss = loss_fn(pred, target)
```
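
To make concrete what a multi-resolution STFT loss computes, here is a minimal sketch in plain PyTorch. It assumes the two terms described for this loss family (spectral convergence plus an L1 distance between log magnitudes, following Arik et al., 2018 and Yamamoto et al., 2019), computed at each resolution and then averaged. The helper names `stft_mag` and `mr_stft_loss_sketch` are hypothetical, and this is not the auraloss implementation: it omits mel scaling and perceptual weighting, and the averaging across resolutions is one reasonable choice among several.

```python
import torch


def stft_mag(x: torch.Tensor, fft_size: int, hop_size: int, win_length: int) -> torch.Tensor:
    """Magnitude STFT of a (batch, channels, samples) tensor, channels folded into the batch."""
    bs, chs, seq_len = x.shape
    window = torch.hann_window(win_length, device=x.device, dtype=x.dtype)
    spec = torch.stft(
        x.reshape(bs * chs, seq_len),
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )
    # clamp so the log-magnitude term never sees log(0)
    return spec.abs().clamp(min=1e-8)


def mr_stft_loss_sketch(
    pred: torch.Tensor,
    target: torch.Tensor,
    fft_sizes=(1024, 2048, 8192),
    hop_sizes=(256, 512, 2048),
    win_lengths=(1024, 2048, 8192),
) -> torch.Tensor:
    """Spectral convergence + log-magnitude L1, averaged over several STFT resolutions."""
    per_resolution = []
    for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths):
        pred_mag = stft_mag(pred, n_fft, hop, win)
        target_mag = stft_mag(target, n_fft, hop, win)
        # spectral convergence: Frobenius norm of the error relative to the target magnitude
        sc = torch.linalg.norm(target_mag - pred_mag) / torch.linalg.norm(target_mag)
        # mean absolute difference between log magnitudes
        log_mag = torch.mean(torch.abs(torch.log(target_mag) - torch.log(pred_mag)))
        per_resolution.append(sc + log_mag)
    return torch.stack(per_resolution).mean()
```

Like the auraloss classes, the sketch returns a scalar that can be backpropagated as a training objective.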
## Citation
If you use this code in your work, please consider citing us.
```bibtex
@inproceedings{steinmetz2020auraloss,
title={auraloss: {A}udio focused loss functions in {PyTorch}},
author={Steinmetz, Christian J. and Reiss, Joshua D.},
booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
year={2020}
}
```

# Loss functions
We categorize the loss functions as either time-domain or frequency-domain approaches.
Additionally, we include perceptual transforms.
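
As a rough guide to what some of the time-domain entries in the table below compute, here is a minimal PyTorch sketch of ESR, DC error, and SI-SDR written from their usual definitions. The function names and the `EPS` constant are hypothetical, signals are assumed to be shaped `(batch, channels, samples)`, and the auraloss classes may differ in details such as reduction, zero-meaning, and epsilon handling. SI-SDR is negated here so that lower is better, which is a common convention for using it as a loss.

```python
import torch

EPS = 1e-8  # small constant to avoid division by zero (an assumption of this sketch)


def esr_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Error-to-signal ratio: error energy relative to target energy."""
    num = torch.sum((target - pred) ** 2, dim=-1)
    den = torch.sum(target ** 2, dim=-1) + EPS
    return torch.mean(num / den)


def dc_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """DC error: squared mean offset of the error relative to target power."""
    num = torch.mean(target - pred, dim=-1) ** 2
    den = torch.mean(target ** 2, dim=-1) + EPS
    return torch.mean(num / den)


def si_sdr_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Negative scale-invariant SDR in dB."""
    # remove the mean of each signal before projecting
    pred = pred - pred.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # optimal scaling of the target toward the prediction
    alpha = torch.sum(pred * target, dim=-1, keepdim=True) / (
        torch.sum(target ** 2, dim=-1, keepdim=True) + EPS
    )
    scaled_target = alpha * target
    noise = pred - scaled_target
    ratio = torch.sum(scaled_target ** 2, dim=-1) / (torch.sum(noise ** 2, dim=-1) + EPS)
    return -torch.mean(10 * torch.log10(ratio + EPS))
```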

| Loss function | Interface | Reference |
| --- | --- | --- |
| **Time domain** | | |
| Error-to-signal ratio (ESR) | `auraloss.time.ESRLoss()` | Wright & Välimäki, 2019 |
| DC error (DC) | `auraloss.time.DCLoss()` | Wright & Välimäki, 2019 |
| Log hyperbolic cosine (Log-cosh) | `auraloss.time.LogCoshLoss()` | Chen et al., 2019 |
| Signal-to-noise ratio (SNR) | `auraloss.time.SNRLoss()` | |
| Scale-invariant signal-to-distortion ratio (SI-SDR) | `auraloss.time.SISDRLoss()` | Le Roux et al., 2018 |
| Scale-dependent signal-to-distortion ratio (SD-SDR) | `auraloss.time.SDSDRLoss()` | Le Roux et al., 2018 |
| **Frequency domain** | | |
| Aggregate STFT | `auraloss.freq.STFTLoss()` | Arik et al., 2018 |
| Aggregate Mel-scaled STFT | `auraloss.freq.MelSTFTLoss(sample_rate)` | |
| Multi-resolution STFT | `auraloss.freq.MultiResolutionSTFTLoss()` | Yamamoto et al., 2019\* |
| Random-resolution STFT | `auraloss.freq.RandomResolutionSTFTLoss()` | Steinmetz & Reiss, 2020 |
| Sum and difference STFT loss | `auraloss.freq.SumAndDifferenceSTFTLoss()` | Steinmetz et al., 2020 |
| **Perceptual transforms** | | |
| Sum and difference signal transform | `auraloss.perceptual.SumAndDifference()` | |
| FIR pre-emphasis filters | `auraloss.perceptual.FIRFilter()` | Wright & Välimäki, 2019 |

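
The two perceptual transforms can be sketched in a few lines of PyTorch. This is a minimal illustration, not the auraloss code: `sum_and_difference` and `first_order_preemphasis` are hypothetical names, the sum and difference signals are left unscaled, and the first-order high-pass form `y[n] = x[n] - coef * x[n-1]` with `coef = 0.85` is an assumed default chosen for illustration. Transforms like these are typically applied to both the prediction and the target before one of the losses is computed.

```python
import torch
import torch.nn.functional as F


def sum_and_difference(x: torch.Tensor):
    """Split a (batch, 2, samples) stereo tensor into sum (mid) and difference (side) signals."""
    left, right = x[:, 0:1, :], x[:, 1:2, :]
    return left + right, left - right


def first_order_preemphasis(x: torch.Tensor, coef: float = 0.85) -> torch.Tensor:
    """Apply y[n] = x[n] - coef * x[n-1] along the last axis of a (batch, channels, samples) tensor."""
    chs = x.shape[1]
    # conv1d computes cross-correlation, so kernel [-coef, 1] weights (x[n-1], x[n])
    kernel = torch.tensor([-coef, 1.0], dtype=x.dtype, device=x.device).repeat(chs, 1, 1)
    # left-pad one zero so the output length matches the input (x[-1] is treated as 0)
    return F.conv1d(F.pad(x, (1, 0)), kernel, groups=chs)
```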
\* [Wang et al., 2019](https://arxiv.org/abs/1904.12088) also propose a multi-resolution spectral loss (which [Engel et al., 2020](https://arxiv.org/abs/2001.04643) follow), but they do not include both the log magnitude (L1 distance) and spectral convergence terms introduced in [Arik et al., 2018](https://arxiv.org/abs/1808.0671) and later extended to the multi-resolution case in [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480).

## Examples
Currently we include an example that uses a set of these loss functions to train a TCN to model an analog dynamic range compressor.
For details, please refer to [`examples/compressor`](examples/compressor).
We provide pre-trained models, evaluation scripts to compute the metrics reported in the [paper](https://www.christiansteinmetz.com/s/DMRN15__auraloss__Audio_focused_loss_functions_in_PyTorch.pdf), and scripts to retrain the models.

There are some more advanced things you can do with the `STFTLoss` class.
For example, you can compute both linear- and log-scaled STFT magnitude errors, as in [Engel et al., 2020](https://arxiv.org/abs/2001.04643).
In this case we do not include the spectral convergence term.
```python
import auraloss

# equal weights on the log- and linear-magnitude terms, spectral convergence disabled
stft_loss = auraloss.freq.STFTLoss(
    w_log_mag=1.0,
    w_lin_mag=1.0,
    w_sc=0.0,
)
```

There is also a Mel-scaled STFT loss, which has some special requirements.
It requires you to set the sample rate as well as specify the correct device.
```python
import auraloss
sample_rate = 44100
# requires the optional extras (pip install "auraloss[all]")
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
```

You can also easily build a multi-resolution Mel-scaled STFT loss with 64 bins.
Make sure you pass the device on which the tensors you are comparing will reside.
```python
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
    scale="mel",
    n_bins=64,
    sample_rate=sample_rate,
    device="cuda",
)
```

If you are computing a loss on stereo audio, you may want to consider the sum and difference (mid/side) loss.
Below is an example of using this loss function with perceptual weighting and mel scaling for
further perceptual relevance.

```python
import torch
import auraloss

# stereo audio: (batch, 2 channels, samples)
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)

loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
    fft_sizes=[1024, 2048, 8192],
    hop_sizes=[256, 512, 2048],
    win_lengths=[1024, 2048, 8192],
    perceptual_weighting=True,
    sample_rate=44100,
    scale="mel",
    n_bins=128,
)

loss = loss_fn(pred, target)
```

# Development
Run the tests locally with pytest:

```
python -m pytest
```