https://github.com/michedev/tensorguard
TensorGuard helps to guard against bad Tensor shapes in any tensor based library
deep-learning numpy pytorch tensorflow
- Host: GitHub
- URL: https://github.com/michedev/tensorguard
- Owner: Michedev
- License: apache-2.0
- Created: 2021-03-04T08:45:16.000Z (over 3 years ago)
- Default Branch: master
- Last Pushed: 2021-05-22T15:29:25.000Z (over 3 years ago)
- Last Synced: 2024-10-17T12:41:04.626Z (30 days ago)
- Topics: deep-learning, numpy, pytorch, tensorflow
- Language: Python
- Size: 81.1 KB
- Stars: 2
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Tensor Guard
[![PyPI version fury.io](https://badge.fury.io/py/tensorguard.svg)](https://pypi.python.org/pypi/tensorguard/)
[![PyPI pyversions](https://img.shields.io/pypi/pyversions/tensorguard.svg)](https://pypi.python.org/pypi/tensorguard/)
[![PyPI download month](https://img.shields.io/pypi/dm/tensorguard.svg)](https://pypi.python.org/pypi/tensorguard/)
[![GitHub followers](https://img.shields.io/github/followers/Michedev.svg?style=social&label=Follow&maxAge=2592000)](https://github.com/Michedev?tab=followers)

TensorGuard helps to guard against bad tensor shapes in any tensor-based library (e.g. NumPy, PyTorch, TensorFlow) using an intuitive symbolic syntax.
## Installation

`pip install tensorguard`

## Basic Usage
```python
import numpy as np  # could be tensorflow or torch as well
import tensorguard as tg

# tensorguard = tg.TensorGuard()  # could be done in an OOP fashion

img = np.ones([64, 32, 32, 3])
flat_img = np.ones([64, 1024])
labels = np.ones([64])

# check shape consistency
tg.guard(img, "B, H, W, C")
tg.guard(labels, "B, 1")  # raises error because of rank mismatch
tg.guard(flat_img, "B, H*W*C")  # raises error because 1024 != 32*32*3

# guard also returns the tensor, so it can be inlined
mean_img = tg.guard(np.mean(img, axis=0), "H, W, C")

# more readable reshapes
flat_img = tg.reshape(img, 'B, H*W*C')

# evaluate templates
assert tg.get_dims('H, W*C+1') == [32, 97]
```
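The `tg.get_dims` call above evaluates a template against sizes that earlier `guard` calls have bound (here `H=32`, `W=32`, `C=3` from `img`). As a rough illustration of that arithmetic only, the sketch below evaluates a template with Python's standard `ast` module against an explicit dictionary of sizes. `get_dims_sketch` is a hypothetical name, not a tensorguard function, and the sketch assumes that `/` must divide exactly; tensorguard's own rules may differ.

```python
import ast
import operator

# Arithmetic allowed in templates: + - * / (as listed in the README).
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _eval_entry(node, dims):
    """Recursively evaluate one parsed template entry."""
    if isinstance(node, ast.Constant) and isinstance(node.value, int):
        return node.value
    if isinstance(node, ast.Name):
        if node.id not in dims:
            raise KeyError(f"dimension {node.id!r} has not been bound yet")
        return dims[node.id]
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval_entry(node.left, dims),
                                   _eval_entry(node.right, dims))
    raise ValueError(f"unsupported template syntax: {ast.dump(node)}")


def get_dims_sketch(template, dims):
    """Hypothetical stand-in for tg.get_dims that takes explicit bindings."""
    sizes = []
    for entry in template.split(","):
        value = _eval_entry(ast.parse(entry.strip(), mode="eval").body, dims)
        # Assumption: a dimension must be a whole number, so W/2 needs an even W.
        if value != int(value):
            raise ValueError(f"{entry.strip()!r} evaluates to {value}")
        sizes.append(int(value))
    return sizes


# Sizes bound by tg.guard(img, "B, H, W, C") in the example above.
bound = {"B": 64, "H": 32, "W": 32, "C": 3}
print(get_dims_sketch("H, W*C+1", bound))  # [32, 97]
```

Walking the parsed tree instead of calling `eval` keeps a template limited to names, integers, and the four arithmetic operators, so a malformed template fails with an error rather than executing arbitrary code.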
## Shape Template Syntax
The shape template mini-DSL supports many different ways of specifying shapes:

* numbers: `"64, 32, 32, 3"`
* named dimensions: `"B, width, height2, channels"`
* wildcards: `"B, *, *, *"`
* ellipsis: `"B, ..., 3"`
* addition, subtraction, multiplication, division: `"B*N, W/2, H*(C+1)"`
* dynamic dimensions: `"?, H, W, C"` *(only matches `[None, H, W, C]`)*

Original repo link: https://github.com/Qwlouse/shapeguard
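The template forms listed above can be combined into a single shape matcher. The following is a minimal sketch, not tensorguard's implementation: `match_shape` is a hypothetical function that takes a shape as a tuple of ints (with `None` for a dynamic dimension), binds each new name to the size it first meets, and checks every later use of that name, every number, and every arithmetic expression against those bindings. It assumes at most one `...` per template, that expressions only use names that are already bound, and that `/` must divide exactly.

```python
import ast
import operator

_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}


def _evaluate(expr, bindings):
    """Evaluate an arithmetic entry such as "H*(C+1)" from bound names."""
    def walk(node):
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        if isinstance(node, ast.Name):
            if node.id not in bindings:
                raise ValueError(f"{node.id!r} is unbound in {expr!r}")
            return bindings[node.id]
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        raise ValueError(f"unsupported entry: {expr!r}")
    return walk(ast.parse(expr, mode="eval").body)


def match_shape(shape, template, bindings=None):
    """Check a shape tuple against a template and return the updated bindings.

    Hypothetical helper: raises ValueError on any mismatch.
    """
    bindings = dict(bindings or {})
    entries = [e.strip() for e in template.split(",")]
    if entries.count("...") > 1:
        raise ValueError("at most one '...' per template")
    if "..." in entries:
        i = entries.index("...")
        head, tail = entries[:i], entries[i + 1:]
        if len(shape) < len(head) + len(tail):
            raise ValueError(f"rank {len(shape)} too small for {template!r}")
        # The ellipsis absorbs every dimension between head and tail.
        pairs = list(zip(head, shape[:len(head)]))
        pairs += list(zip(tail, shape[len(shape) - len(tail):]))
    else:
        if len(shape) != len(entries):
            raise ValueError(f"rank mismatch: {shape} vs {template!r}")
        pairs = list(zip(entries, shape))

    for entry, size in pairs:
        if entry == "*":
            continue  # wildcard: any size, including None
        if entry == "?":
            if size is not None:
                raise ValueError(f"'?' only matches None, got {size}")
            continue
        if entry.isidentifier() and entry not in bindings:
            bindings[entry] = size  # first occurrence binds the name
            continue
        expected = _evaluate(entry, bindings)
        if size != expected:
            raise ValueError(f"{entry!r}: expected {expected}, got {size}")
    return bindings


# Mirrors the Basic Usage example with plain shape tuples.
b = match_shape((64, 32, 32, 3), "B, H, W, C")
match_shape((64, 3072), "B, H*W*C", b)  # passes: 32*32*3 == 3072
match_shape((64, 16, 128), "B, W/2, H*(C+1)", b)
match_shape((None, 32, 32, 3), "?, H, W, C", b)
match_shape((64, 7, 7, 3), "B, ..., 3", b)
```

Binding a name on its first occurrence and checking it on every later one is what turns the repeated `B` in the Basic Usage example into a consistency check across tensors; passing the returned bindings into the next call plays the role that a shared guard object would.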