https://github.com/srush/anynp
Proof-of-concept of global switching between numpy/jax/pytorch in a library.
- Host: GitHub
- URL: https://github.com/srush/anynp
- Owner: srush
- Created: 2024-06-18T15:35:51.000Z (almost 2 years ago)
- Default Branch: master
- Last Pushed: 2024-06-18T15:54:53.000Z (almost 2 years ago)
- Last Synced: 2025-04-12T15:13:23.536Z (about 1 year ago)
- Language: Python
- Size: 7.81 KB
- Stars: 18
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# anynp
Proof-of-concept of global switching between numpy/jax/pytorch in a library.
This is a wrapper around the Array API and the `array_api_compat` library. It adds a stub for the `np` type so that mypy doesn't complain, plus a context manager for switching the active backend.
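```python
from switcher import switch, XP
from numpy.typing import ArrayLike


def my_fun(x) -> ArrayLike:
    # `switch.xp` is the array namespace of the currently active backend.
    array = switch.xp.asarray(x)
    print(type(array))
    return array


x = my_fun([10])  # default backend

with switch.set_context(XP.Jax):
    x = my_fun([10])  # a JAX array
    # JAX arrays are immutable; .at[...].set(...) returns an updated copy.
    print(x.at[0].set(0))

with switch.set_context(XP.Torch):
    my_fun([10])  # a torch.Tensor

my_fun([20])  # back to the default backend once the context exits
```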
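To make the "wrapper plus context manager" idea concrete, here is a minimal sketch of how such a switcher could be built. It is not the `switcher` module from this repo: `Backend`, `Switch`, and their members are hypothetical names. The sketch assumes that `array_api_compat.numpy` and `array_api_compat.torch` (the Array API wrappers shipped by `array_api_compat`) are used for NumPy and PyTorch, and that a recent JAX release is installed where `jax.numpy` itself follows the Array API.

```python
import contextlib
import importlib
from enum import Enum
from types import ModuleType
from typing import Iterator


class Backend(Enum):
    # Hypothetical enum standing in for the README's XP.Jax / XP.Torch.
    # Each value is the module path of an Array API namespace.
    NUMPY = "array_api_compat.numpy"
    JAX = "jax.numpy"  # assumes a JAX version whose jax.numpy follows the Array API
    TORCH = "array_api_compat.torch"


class Switch:
    """Hypothetical stand-in for the README's global `switch` object."""

    def __init__(self, default: Backend = Backend.NUMPY) -> None:
        # A stack of active backends; the top entry is the current one.
        self._stack: list[Backend] = [default]

    @property
    def backend(self) -> Backend:
        return self._stack[-1]

    @property
    def xp(self) -> ModuleType:
        # Import lazily so only the backend actually in use must be installed.
        return importlib.import_module(self.backend.value)

    @contextlib.contextmanager
    def set_context(self, backend: Backend) -> Iterator[None]:
        # Push the requested backend for the duration of the `with` block.
        self._stack.append(backend)
        try:
            yield
        finally:
            # Always restore the enclosing backend, even if the body raises.
            self._stack.pop()


# Module-level global, mirroring `from switcher import switch`.
switch = Switch()
```

Because the active backend lives on a stack, nested `with` blocks restore whichever backend enclosed them, and the `try`/`finally` restores it even when the body raises. A plain global stack is not thread-safe; a `contextvars.ContextVar` would be one way to make the switch per-thread or per-task.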