https://github.com/byte-sourcerer/array_contract
Contract programming for PyTorch and NumPy
contracts ndarray numpy python python3 pytorch tensor
Last synced: 5 months ago
- Host: GitHub
- URL: https://github.com/byte-sourcerer/array_contract
- Owner: byte-sourcerer
- Created: 2020-08-08T14:30:52.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2020-08-14T06:53:09.000Z (over 5 years ago)
- Last Synced: 2025-09-14T09:02:19.013Z (5 months ago)
- Topics: contracts, ndarray, numpy, python, python3, pytorch, tensor
- Language: Python
- Homepage: https://pypi.org/project/arraycontract/
- Size: 8.79 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# ArrayContract
```python
from arraycontract import shape, _
import torch


# 'N' names a dimension that must have the same size in x and y;
# `_` matches a dimension of any size.
@shape(x=(_, 'N'), y=('N', _))
def matrix_dot(x, y):
    return x @ y


matrix_dot(torch.rand(3, 4), torch.rand(4, 5))  # OK
matrix_dot(torch.rand(3, 4), torch.rand(3, 5))  # raises AssertionError
```
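To make the idea behind the named-dimension contract concrete, here is a minimal, self-contained sketch of how such a check could be written. It is not the implementation used by `arraycontract`; the names `check_shape`, `ANY`, and `FakeArray` are hypothetical and exist only for this illustration. The sketch assumes each checked argument exposes a `.shape` attribute, as PyTorch tensors and NumPy arrays do.

```python
import functools
import inspect

ANY = object()  # hypothetical wildcard, playing the role of `_`


def check_shape(**specs):
    """Hypothetical decorator: assert that named arguments match shape specs."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            named = {}  # dimension name -> size seen so far in this call
            for arg_name, spec in specs.items():
                actual = tuple(bound.arguments[arg_name].shape)
                assert len(actual) == len(spec), (
                    f"{arg_name}: expected {len(spec)} dims, got {len(actual)}"
                )
                for want, got in zip(spec, actual):
                    if want is ANY:
                        continue  # wildcard matches any size
                    if isinstance(want, str):
                        # A named dimension must have one size across all arguments.
                        assert named.setdefault(want, got) == got, (
                            f"{arg_name}: dimension {want!r} is {got}, "
                            f"expected {named[want]}"
                        )
                    else:
                        assert want == got, f"{arg_name}: expected {want}, got {got}"
            return func(*args, **kwargs)
        return wrapper
    return decorator


class FakeArray:
    """Stand-in for a tensor or ndarray: it only carries a shape."""
    def __init__(self, *shape):
        self.shape = shape


@check_shape(x=(ANY, 'N'), y=('N', ANY))
def matrix_dot_shape(x, y):
    # Returns the shape that x @ y would have.
    return (x.shape[0], y.shape[1])
```

Calling `matrix_dot_shape(FakeArray(3, 4), FakeArray(4, 5))` passes the check, while `matrix_dot_shape(FakeArray(3, 4), FakeArray(3, 5))` fails with an `AssertionError` because `'N'` would have to be both 4 and 3.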
```python
from arraycontract import shape, _
import torch
from torch import nn

linear = nn.Linear(3, 4)


# `...` matches any number of leading dimensions.
@shape((..., 3))
def forward_linear(x):
    """
    requires x.shape[-1] == 3
    """
    return linear(x)


forward_linear(torch.rand(4, 5, 3))  # OK
forward_linear(torch.rand(4, 4))  # raises AssertionError
```
```python
from arraycontract import dtype
from arraycontract import ndim
import torch


@ndim(x=3, y=4)
def ndim_contract(x, y):
    print("requires x.ndim == 3 and y.ndim == 4")


@dtype(x=torch.long)
def dtype_contract(x):
    print("requires x.dtype == torch.long")
```
```python
from arraycontract import Trigger
from arraycontract import dtype
import torch

# Turn off dtype checking globally.
Trigger.dtype_check_trigger = False


@dtype(x=torch.long)
def dtype_contract(x):
    print("does not require x.dtype == torch.long")


dtype_contract(torch.rand(3, 4).float())  # OK
```
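The switch pattern above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the code behind `Trigger`: the names `Switch`, `check_dtype`, and `FakeArray` are hypothetical, and the README does not say whether `arraycontract` reads its flag at decoration time or at call time. This sketch reads it at call time.

```python
import functools
import inspect


class Switch:
    """Hypothetical global switch, playing the role of `Trigger`."""
    dtype_check = True


def check_dtype(**expected):
    """Hypothetical decorator: assert argument dtypes unless the switch is off."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # The flag is read on every call, so toggling it takes effect
            # even for functions that were decorated earlier.
            if Switch.dtype_check:
                bound = sig.bind(*args, **kwargs)
                for name, want in expected.items():
                    got = bound.arguments[name].dtype
                    assert got == want, f"{name}: expected dtype {want}, got {got}"
            return func(*args, **kwargs)
        return wrapper
    return decorator


class FakeArray:
    """Stand-in for a tensor or ndarray: it only carries a dtype."""
    def __init__(self, dtype):
        self.dtype = dtype


@check_dtype(x="long")
def identity(x):
    return x
```

With `Switch.dtype_check = False`, `identity(FakeArray("float"))` returns its argument unchecked; with the switch back on, the same call fails with an `AssertionError`. A global flag like this keeps contracts cheap to disable in production, at the cost of shared mutable state.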