https://github.com/mddct/losses
CTC-related
- Host: GitHub
- URL: https://github.com/mddct/losses
- Owner: Mddct
- Created: 2022-05-09T09:21:24.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2023-06-01T01:29:04.000Z (about 3 years ago)
- Last Synced: 2025-05-07T06:44:58.760Z (about 1 year ago)
- Language: Python
- Homepage:
- Size: 98.6 KB
- Stars: 5
- Watchers: 1
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
# losses
Loss functions associated with CTC and CIF.

Note: the CTC decoder is a binding of the WeNet runtime decoder.
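To make the decoder note concrete, here is a minimal pure-Python sketch of CTC prefix beam search, the kind of N-best search a CTC decoder runs. To our understanding the WeNet runtime decoder is based on prefix beam search, but this sketch is not that code and not this repo's `CTCDecoder` API. It assumes blank id 0 and a single utterance given as a T x V list of frame log-posteriors; `ctc_prefix_beam_search` and `log_add` are hypothetical names.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

NEG_INF = float("-inf")


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def ctc_prefix_beam_search(
    log_probs: Sequence[Sequence[float]], beam_size: int, blank: int = 0
) -> List[Tuple[Tuple[int, ...], float]]:
    """Return up to `beam_size` (prefix, log-score) pairs, best first.

    Hypothetical helper. Each prefix keeps two log-scores: paths ending in
    blank (pb) and paths ending in a non-blank symbol (pnb). Keeping them
    apart is what lets repeated labels merge or split correctly.
    """
    beams: Dict[Tuple[int, ...], Tuple[float, float]] = {(): (0.0, NEG_INF)}
    for frame in log_probs:
        nxt: Dict[Tuple[int, ...], List[float]] = defaultdict(lambda: [NEG_INF, NEG_INF])
        for prefix, (pb, pnb) in beams.items():
            total = log_add(pb, pnb)
            for s, lp in enumerate(frame):
                if s == blank:
                    # Blank keeps the prefix unchanged and ends it in blank.
                    nxt[prefix][0] = log_add(nxt[prefix][0], total + lp)
                elif prefix and s == prefix[-1]:
                    # A repeat with no blank in between collapses into the same prefix...
                    nxt[prefix][1] = log_add(nxt[prefix][1], pnb + lp)
                    # ...while a repeat after a blank emits a new label.
                    ext = prefix + (s,)
                    nxt[ext][1] = log_add(nxt[ext][1], pb + lp)
                else:
                    ext = prefix + (s,)
                    nxt[ext][1] = log_add(nxt[ext][1], total + lp)
        # Prune to the `beam_size` prefixes with the highest total score.
        ranked = sorted(nxt.items(), key=lambda kv: log_add(*kv[1]), reverse=True)
        beams = {p: (v[0], v[1]) for p, v in ranked[:beam_size]}
    return [(p, log_add(pb, pnb)) for p, (pb, pnb) in beams.items()]


if __name__ == "__main__":
    # The README's example posteriors (blank assumed to be 0), in log space.
    probs = [[0.25, 0.40, 0.35], [0.40, 0.35, 0.25], [0.10, 0.50, 0.40]]
    log_probs = [[math.log(p) for p in row] for row in probs]
    for prefix, score in ctc_prefix_beam_search(log_probs, beam_size=3):
        print(prefix, f"{math.exp(score):.4f}")
```

With `beam_size=3` on those posteriors, the sketch ranks the prefixes (2, 1), (1, 2) and (1,), with probabilities of about 0.2185, 0.155 and 0.1525. That is the same order as the commented output of the README example. The agreement is only a sanity check. It is not evidence that the binding is implemented this way.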
TODO:
- [ ] entmax and entmax losses
- [x] KD CTC decoder n-best strategy
- [x] support batch CTC decode (not parallel)
- [x] support batch CTC decode (parallel)
- [ ] support chunk-state CTC decode
- [ ] support torch sparse tensors
- [x] sequence focal loss
- [x] cross-entropy focal loss
- [x] sigmoid focal loss
- [ ] focal logits for MWER
- [x] MWER loss support
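Several of the items above are focal-loss variants. Below is a minimal PyTorch sketch of two of them, assuming the standard focal form -(1 - p_t)^γ · log p_t for token-level cross entropy. For the sequence-level case it assumes a focal-CTC style weighting (1 - e^(-L))^γ · L of a per-utterance loss L. The function names are hypothetical, and the repo's own implementation may differ, for example in α-balancing, detaching the modulating factor, or reduction. For the sigmoid variant, torchvision ships `torchvision.ops.sigmoid_focal_loss`.

```python
import torch
import torch.nn.functional as F


def ce_focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Multi-class focal loss (hypothetical helper): mean of -(1 - p_t)^gamma * log p_t.

    logits: (N, C) unnormalized scores; targets: (N,) class indices.
    With gamma = 0 this is ordinary cross entropy.
    """
    log_pt = F.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    pt = log_pt.exp()
    return (-(1.0 - pt).pow(gamma) * log_pt).mean()


def sequence_focal_loss(seq_nll: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Sequence-level focal weighting (hypothetical helper).

    seq_nll: (B,) per-utterance negative log-likelihoods, e.g. from
    F.ctc_loss(..., reduction="none"). p = exp(-nll) is the sequence
    probability, so well-recognized utterances (p near 1) are down-weighted.
    """
    p = torch.exp(-seq_nll)
    return ((1.0 - p).pow(gamma) * seq_nll).mean()


if __name__ == "__main__":
    T, B, C, S = 50, 2, 20, 10
    log_probs = torch.randn(T, B, C).log_softmax(-1)  # (T, B, C) as F.ctc_loss expects
    targets = torch.randint(1, C, (B, S))  # id 0 is left for the CTC blank
    nll = F.ctc_loss(
        log_probs,
        targets,
        input_lengths=torch.full((B,), T, dtype=torch.long),
        target_lengths=torch.full((B,), S, dtype=torch.long),
        reduction="none",
    )
    print(sequence_focal_loss(nll, gamma=2.0))
```

Setting `gamma=0` recovers the unweighted losses in both functions, which makes it easy to check against `F.cross_entropy` or a plain mean of the CTC losses.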
```python
import torch
from torch.nn.utils.rnn import pad_sequence
from ctcdecoder import CTCDecoder
from edit_distance import edit_distance
# NOTE: CTCMWERLoss (used below) must also be imported from this repo;
# its module path is not shown in the original example.

# Frame-level posteriors for one utterance: (batch=1, T=3, vocab=3).
inputs = torch.tensor(
    [[[0.25, 0.40, 0.35],
      [0.40, 0.35, 0.25],
      [0.10, 0.50, 0.40]]])
inputs = inputs.log()  # convert to log-probabilities before decoding
seq_len = torch.tensor([3])

# CTC n-best decoding
decoder = CTCDecoder(3, 3)
print(decoder.decode(inputs, seq_len))
# print(pad_sequence(decoder.decode(inputs, seq_len), batch_first=True, padding_value=-1))
# tensor([[ 2,  1],
#         [ 1,  2],
#         [ 1, -1]])

# Batched edit distance between hypotheses and references
hyp = torch.tensor([[1, 2, 3], [1, 2, 3]])
hyp_lens = torch.tensor([3, 3])
truth = torch.tensor([[4, 5, 6], [4, 5, 6]])
t_lens = torch.tensor([3, 3])
print(edit_distance(hyp, hyp_lens, truth, t_lens))

# MWER loss
mwer = CTCMWERLoss(8)
labels = torch.tensor([[1, 0, 2]])
labels_length = torch.tensor([3])
print(mwer.forward(inputs, labels, labels_length, torch.tensor(3)))
# tensor(0.0136)
```
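The example combines an N-best list, an edit distance, and an MWER loss. Below is a minimal sketch of the commonly used MWER objective. It renormalizes the N-best sequence scores with a softmax, then takes the expectation of each hypothesis's edit distance to the reference minus the N-best average. The sketch assumes a fixed N per utterance and sequence log-scores that are already computed. `levenshtein` and `mwer_loss` are hypothetical names, not the repo's `edit_distance` or `CTCMWERLoss` API, and the toy N-best list and scores are made up for illustration.

```python
from typing import List, Sequence

import torch


def levenshtein(hyp: Sequence[int], ref: Sequence[int]) -> int:
    """Token-level edit distance (substitutions, insertions, deletions)."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, start=1):
        cur = [i] + [0] * len(ref)
        for j, r in enumerate(ref, start=1):
            cur[j] = min(
                prev[j] + 1,                # delete h
                cur[j - 1] + 1,             # insert r
                prev[j - 1] + int(h != r),  # match or substitute
            )
        prev = cur
    return prev[-1]


def mwer_loss(
    nbest_logp: torch.Tensor, nbest: List[List[List[int]]], refs: List[List[int]]
) -> torch.Tensor:
    """MWER over an N-best list (hypothetical helper).

    nbest_logp: (B, N) sequence log-scores; nbest[b][n]: token ids of
    hypothesis n for utterance b; refs[b]: reference token ids.
    """
    errors = torch.tensor(
        [[float(levenshtein(h, r)) for h in hyps] for hyps, r in zip(nbest, refs)],
        dtype=nbest_logp.dtype,
        device=nbest_logp.device,
    )
    probs = torch.softmax(nbest_logp, dim=-1)               # renormalize over the N-best
    relative = errors - errors.mean(dim=-1, keepdim=True)   # subtract the average error
    return (probs * relative).sum(dim=-1).mean()


if __name__ == "__main__":
    # Toy N-best list (B=1, N=3) with made-up scores and reference [1, 2].
    logp = torch.log(torch.tensor([[0.2185, 0.155, 0.1525]])).requires_grad_()
    loss = mwer_loss(logp, [[[2, 1], [1, 2], [1]]], [[1, 2]])
    loss.backward()
    print(f"{loss.item():.4f}")  # prints 0.1207
    print(logp.grad)             # negative for the error-free hypothesis [1, 2]
```

Subtracting the average error does not change the expected gradient direction, but it reduces variance. It also gives a negative gradient to hypotheses that are better than average, so minimizing the loss shifts probability mass toward them.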