https://github.com/renovamen/torchmasked
Masked tensor operations for PyTorch.
- Host: GitHub
- URL: https://github.com/renovamen/torchmasked
- Owner: Renovamen
- License: mit
- Created: 2021-11-25T04:10:44.000Z (almost 4 years ago)
- Default Branch: main
- Last Pushed: 2021-11-25T09:31:53.000Z (almost 4 years ago)
- Last Synced: 2025-02-07T22:04:55.907Z (8 months ago)
- Topics: mask, pytorch
- Language: Python
- Homepage:
- Size: 10.7 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# torchmasked
Tensor operations with masks for PyTorch.
[PyPI](https://pypi.org/project/torchmasked/) | [License](https://github.com/Renovamen/torchmasked/blob/main/LICENSE) | [Unit tests](https://github.com/Renovamen/torchmasked/actions/workflows/unittest.yaml)
Sometimes you need to perform operations on PyTorch tensors while ignoring masked elements, for example:
```python
>>> import torch
>>> import torchmasked
>>> input = torch.tensor([1., 2., 3.])
>>> result = torch.sum(input)
>>> print(result)
tensor(6.)
>>> mask = torch.tensor([1, 1, 0]).byte()
>>> masked_result = torchmasked.masked_sum(input, mask)  # input[2] is masked and ignored
>>> print(masked_result)
tensor(3.)
```

In such cases, this package can help.
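To illustrate the idea behind the example, a masked sum or mean can be emulated in plain PyTorch by filling ignored positions with 0 before reducing. This is a minimal sketch, assuming a mask value of 1 means "keep" and 0 means "ignore" (matching the example); `masked_sum_sketch` and `masked_mean_sketch` are hypothetical helpers, not torchmasked functions, and the package itself may be implemented differently.

```python
import torch


def masked_sum_sketch(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sum of x over positions where mask is nonzero."""
    keep = mask.bool()
    # Replace ignored positions with 0, the identity element for addition.
    return x.masked_fill(~keep, 0.0).sum()


def masked_mean_sketch(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: masked sum divided by the number of kept elements."""
    keep = mask.bool()
    return x.masked_fill(~keep, 0.0).sum() / keep.sum()


x = torch.tensor([1., 2., 3.])
mask = torch.tensor([1, 1, 0]).byte()
print(masked_sum_sketch(x, mask))   # tensor(3.)
print(masked_mean_sketch(x, mask))  # tensor(1.5000)
```

Note that in this sketch, if every element is masked, the mean becomes 0/0 and returns `nan`; a real implementation would need to decide how to handle that case.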
## Requirements
Tested on Python 3.6+ and PyTorch 1.4+.
## Installation
From PyPI:
```bash
pip install torchmasked
```

From source:
```bash
pip install git+https://github.com/Renovamen/torchmasked.git --upgrade
# or, from a local clone of the repository:
python setup.py install
```
## Supported Operations
Each function is used like its PyTorch counterpart, with an additional `mask` argument (as in `torchmasked.masked_sum(input, mask)` above). Refer to the [PyTorch documentation](https://pytorch.org/docs/stable/index.html) or the [source code](torchmasked) for details.
- [`torchmasked.masked_max`](torchmasked/functional.py) (masked version of [`torch.max`](https://pytorch.org/docs/stable/generated/torch.max.html))
- [`torchmasked.masked_min`](torchmasked/functional.py) ([`torch.min`](https://pytorch.org/docs/stable/generated/torch.min.html))
- [`torchmasked.masked_sum`](torchmasked/functional.py) ([`torch.sum`](https://pytorch.org/docs/stable/generated/torch.sum.html))
- [`torchmasked.masked_mean`](torchmasked/functional.py) ([`torch.mean`](https://pytorch.org/docs/stable/generated/torch.mean.html))
- [`torchmasked.masked_softmax`](torchmasked/functional.py) ([`torch.nn.functional.softmax`](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html)) / [`torchmasked.nn.MaskedSoftmax`](torchmasked/nn.py) ([`torch.nn.Softmax`](https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html))
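As a rough mental model for the masked max and softmax variants listed above, masked entries can be filled with `-inf` before calling the ordinary PyTorch op (a masked min would use `+inf` instead). The helpers below, `masked_max_sketch` and `masked_softmax_sketch`, are hypothetical names for illustration only; they assume floating-point inputs and a mask where 1 means "keep" and 0 means "ignore", and they are not torchmasked's API or necessarily how it is implemented.

```python
import torch
import torch.nn.functional as F


def masked_max_sketch(x: torch.Tensor, mask: torch.Tensor, dim: int):
    """Hypothetical helper: max along `dim`, ignoring positions where mask == 0."""
    keep = mask.bool()
    # -inf never wins a max, so masked entries cannot be selected.
    filled = x.masked_fill(~keep, float("-inf"))
    return filled.max(dim=dim)  # (values, indices), like torch.max(x, dim)


def masked_softmax_sketch(x: torch.Tensor, mask: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: softmax along `dim` with zero probability on masked positions."""
    keep = mask.bool()
    # exp(-inf) == 0, so masked logits contribute nothing to the normalization.
    return F.softmax(x.masked_fill(~keep, float("-inf")), dim=dim)


x = torch.tensor([[1., 5., 3.],
                  [4., 2., 6.]])
mask = torch.tensor([[1, 0, 1],
                     [1, 1, 0]]).byte()

values, indices = masked_max_sketch(x, mask, dim=1)
print(values)   # tensor([3., 4.])
print(indices)  # tensor([2, 0])
print(masked_softmax_sketch(x, mask, dim=1))
```

With this approach, a row whose entries are all masked yields `-inf` from the max and `nan` from the softmax, so callers would need to guard against fully masked slices.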
## License
[MIT](LICENSE)