https://github.com/asem000/serket
The ✨Magical✨ JAX ML Library.
- Host: GitHub
- URL: https://github.com/asem000/serket
- Owner: ASEM000
- License: apache-2.0
- Created: 2022-08-20T16:50:45.000Z
- Default Branch: main
- Last Pushed: 2025-01-25T18:53:51.000Z
- Last Synced: 2025-04-13T11:11:52.879Z
- Topics: jax, machine-learning, neural-network
- Language: Python
- Homepage: https://serket.rtfd.io
- Size: 46.3 MB
- Stars: 17
- Watchers: 1
- Forks: 0
- Open Issues: 4
Metadata Files:
- Readme: README.md
- License: LICENSE
README
\*Serket is the goddess of magic in Egyptian mythology.
[**Installation**](#Installation) | [**Description**](#Description) | [**Documentation**](#Documentation) | [**Quick Example**](#QuickExample)



[Codecov](https://codecov.io/gh/ASEM000/serket) | [Documentation status](https://serket.readthedocs.io/?badge=latest) | [DOI](https://zenodo.org/badge/latestdoi/526985786) | [CodeFactor](https://www.codefactor.io/repository/github/asem000/serket)
## 🛠️ Installation

**Install development version**
```bash
pip install git+https://github.com/ASEM000/serket
```
## 📖 Description and motivation
- `serket` aims to be the most intuitive and easy-to-use machine learning library built on `jax`.
- `serket` is fully transparent to `jax` transformations (e.g. `vmap`, `grad`, `jit`, ...), so models can be passed to them directly.
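To make "transparent to `jax` transformations" concrete: a model that is a pytree of arrays can be handed to `vmap`, `grad` and `jit` without adapters. The sketch below illustrates the idea in plain JAX with a hypothetical hand-written `Linear` class registered through `jax.tree_util.register_pytree_node_class`; it is a minimal illustration, not how serket defines its layers.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical minimal layer: a pytree whose leaves are its arrays."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # x has shape (in_features,); there is no batch dimension
        return x @ self.weight + self.bias

    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


key = jax.random.key(0)
layer = Linear(jax.random.normal(key, (3, 2)), jnp.zeros(2))
xs = jnp.ones((5, 3))  # a batch of 5 inputs

# vmap: map the unbatched layer over the leading batch axis
ys = jax.vmap(layer)(xs)


def loss(layer, xs):
    return jnp.mean(jax.vmap(layer)(xs) ** 2)


# grad: the gradient has the same pytree structure as the layer
grads = jax.grad(loss)(layer, xs)

# jit: the layer can be passed straight into a compiled function
fast_loss = jax.jit(loss)
```

Because the gradient comes back as a `Linear` with the same structure, an update such as `jax.tree_util.tree_map(lambda p, g: p - lr * g, layer, grads)` needs no special handling.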
## 📙 Documentation
- [Full documentation](https://serket.readthedocs.io/)
- [Train MNIST, UNet, ConvLSTM, PINN](https://serket.readthedocs.io/training_guides.html)
- [Model surgery, Parallelism, Mixed precision](https://serket.readthedocs.io/core_guides.html)
- [Optimizers, Augmentation composition](https://serket.readthedocs.io/other_guides.html)
- [Interoperability with Keras and TensorFlow](https://serket.readthedocs.io/interoperability.html)
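## ⏩ Quick Example

```python
import functools as ft

import jax
import jax.numpy as jnp
import serket as sk

# placeholders: iterables of batches, x of shape (batch, 28, 28), y integer labels of shape (batch,)
x_train, y_train = ..., ...
k1, k2 = jax.random.split(jax.random.key(0))

# tree_mask hides non-differentiable leaves (e.g. `jnp.ravel`, `jax.nn.relu`)
# so the network can be passed through `jax.jit` and `jax.grad`
net = sk.tree_mask(sk.Sequential(
    jnp.ravel,
    sk.nn.Linear(28 * 28, 64, key=k1),
    jax.nn.relu,
    sk.nn.Linear(64, 10, key=k2),
))

def softmax_cross_entropy(logits, onehot):
    # per-example cross-entropy between logits and one-hot targets
    return -jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1)

def accuracy_func(logits, y):
    return jnp.mean(jnp.argmax(logits, axis=-1) == y)

@ft.partial(jax.grad, has_aux=True)
def loss_func(net, x, y):
    # unmask before calling, then vmap the single-example network over the batch
    logits = jax.vmap(sk.tree_unmask(net))(x)
    onehot = jax.nn.one_hot(y, 10)
    loss = jnp.mean(softmax_cross_entropy(logits, onehot))
    return loss, (loss, logits)

@jax.jit
def train_step(net, x, y):
    grads, (loss, logits) = loss_func(net, x, y)
    # plain SGD step with learning rate 1e-3
    net = jax.tree_util.tree_map(lambda p, g: p - g * 1e-3, net, grads)
    return net, (loss, logits)

for xb, yb in zip(x_train, y_train):
    net, (loss, logits) = train_step(net, xb, yb)
    accuracy = accuracy_func(logits, yb)

# unmask to recover the plain network
net = sk.tree_unmask(net)
```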
## 📚 Layers catalog
#### 🔗 Common API
| Group | Layers |
| ---------- | -------------------------------- |
| Containers | - `Sequential`, `Random{Choice}` |
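The containers compose callables: the quick example passes plain functions such as `jnp.ravel` and `jax.nn.relu` to `sk.Sequential` alongside `sk.nn.Linear` layers. A rough plain-JAX sketch of both ideas follows. `sequential` and `random_choice` are hypothetical helpers, not serket's API, and the assumption that `RandomChoice` applies one randomly selected layer per call is inferred from its name rather than from serket's documentation.

```python
import jax
import jax.numpy as jnp


def sequential(*fns):
    """Hypothetical helper: apply fns left to right, like a Sequential container."""

    def apply(x):
        for fn in fns:
            x = fn(x)
        return x

    return apply


def random_choice(*fns):
    """Hypothetical helper: apply one of fns, picked at random from `key`."""

    def apply(x, *, key):
        index = jax.random.randint(key, (), 0, len(fns))
        # lax.switch traces every branch, so the choice also works under jit;
        # all branches must return arrays of the same shape and dtype
        return jax.lax.switch(index, fns, x)

    return apply


net = sequential(jnp.ravel, jax.nn.relu)
augment = jax.jit(random_choice(jnp.flip, lambda v: v, jnp.negative))

x = jnp.arange(-2.0, 2.0).reshape(2, 2)
y = net(x)
z = augment(x, key=jax.random.key(1))
```

Using `jax.lax.switch` instead of a Python `if` keeps the selection traceable, so a fresh choice is made on every call of the jitted function rather than once at trace time.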
#### 🧠 Neural network package: `serket.nn`
| Group           | Layers |
| --------------- | ------ |
| Attention       | - `MultiHeadAttention` |
| Convolution     | - `{FFT,_}Conv{1D,2D,3D}`<br>- `{FFT,_}Conv{1D,2D,3D}Transpose`<br>- `Depthwise{FFT,_}Conv{1D,2D,3D}`<br>- `Separable{FFT,_}Conv{1D,2D,3D}`<br>- `Conv{1D,2D,3D}Local`<br>- `SpectralConv{1D,2D,3D}` |
| Dropout         | - `Dropout`<br>- `Dropout{1D,2D,3D}`<br>- `RandomCutout{1D,2D,3D}` |
| Linear          | - `Linear`, `MLP`, `Identity` |
| Normalization   | - `{Layer,Instance,Group,Batch}Norm` |
| Pooling         | - `{Avg,Max,LP}Pool{1D,2D,3D}`<br>- `Global{Avg,Max}Pool{1D,2D,3D}`<br>- `Adaptive{Avg,Max}Pool{1D,2D,3D}` |
| Reshaping       | - `Upsample{1D,2D,3D}`<br>- `{Random,Center}Crop{1D,2D,3D}` |
| Recurrent cells | - `{SimpleRNN,LSTM,GRU,Dense}Cell`<br>- `{Conv,FFTConv}{LSTM,GRU}{1D,2D,3D}Cell` |
| Activations     | - `Adaptive{LeakyReLU,ReLU,Sigmoid,Tanh}`<br>- `CeLU`, `ELU`, `GELU`, `GLU`<br>- `Hard{SILU,Shrink,Sigmoid,Swish,Tanh}`<br>- `Soft{Plus,Sign,Shrink}`<br>- `LeakyReLU`, `LogSigmoid`, `LogSoftmax`, `Mish`, `PReLU`<br>- `ReLU`, `ReLU6`, `SeLU`, `Sigmoid`<br>- `Swish`, `Tanh`, `TanhShrink`, `ThresholdedReLU`, `Snake` |
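Several layers in the table come in `FFT` and spatial (`_`) variants. Both compute the same convolution: the FFT route multiplies spectra instead of sliding the kernel, which costs O(N log N) rather than O(N·K) and tends to pay off for large kernels. The 1D sketch below checks that identity against `jnp.convolve`; `fft_convolve` is a hypothetical helper illustrating the technique, not serket's layer code. (Deep-learning "convolutions" are usually cross-correlations; the same trick applies after flipping the kernel.)

```python
import jax.numpy as jnp


def fft_convolve(x, k):
    """Full linear convolution of 1D arrays x and k, computed via the FFT."""
    n = x.shape[0] + k.shape[0] - 1  # length of the full convolution
    # zero-pad both signals to length n so the circular convolution
    # computed by the FFT equals the linear convolution
    spectrum = jnp.fft.rfft(x, n=n) * jnp.fft.rfft(k, n=n)
    return jnp.fft.irfft(spectrum, n=n)


x = jnp.array([1.0, 2.0, 3.0, 4.0])
k = jnp.array([0.5, 0.0, -0.5])

direct = jnp.convolve(x, k, mode="full")
via_fft = fft_convolve(x, k)
```

Both `direct` and `via_fft` hold `[0.5, 1.0, 1.0, 1.0, -1.5, -2.0]` up to floating-point error.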
#### 🖼️ Image package: `serket.image`
| Group     | Layers |
| --------- | ------ |
| Filter    | - `{FFT,_}{Avg,Box,Gaussian,Motion}Blur2D`<br>- `{JointBilateral,Bilateral,Median}Blur2D`<br>- `{FFT,_}{UnsharpMask}2D`<br>- `{FFT,_}{Sobel,Laplacian}2D`<br>- `{FFT,_}BlurPool2D` |
| Augment   | - `Adjust{Sigmoid,Log}2D`<br>- `{Adjust,Random}{Brightness,Contrast,Hue,Saturation}2D`<br>- `RandomJigSaw2D`, `PixelShuffle2D`<br>- `Pixelate2D`, `Posterize2D`, `Solarize2D`<br>- `FourierDomainAdapt2D` |
| Geometric | - `{Random,_}{Horizontal,Vertical}{Translate,Flip,Shear}2D`<br>- `{Random,_}{Rotate}2D`<br>- `RandomPerspective2D`<br>- `{FFT,_}ElasticTransform2D` |
| Color     | - `RGBToGrayscale2D`, `GrayscaleToRGB2D`<br>- `RGBToHSV2D`, `HSVToRGB2D` |
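The `Filter` group applies fixed kernels to each channel. As an illustration, here is a plain-JAX Gaussian blur on a single unbatched image, assuming a channel-first `(channels, height, width)` layout and zero padding at the borders. `gaussian_kernel` and `gaussian_blur` are hypothetical helpers; serket's `GaussianBlur2D` may differ in its parameters and padding.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel(size, sigma):
    """Normalized 2D Gaussian kernel of shape (size, size)."""
    r = jnp.arange(size) - (size - 1) / 2  # offsets centred on zero
    g = jnp.exp(-(r**2) / (2 * sigma**2))
    g = g / g.sum()
    # separable: the 2D kernel is the outer product of two 1D kernels
    return jnp.outer(g, g)


def gaussian_blur(image, size=5, sigma=1.0):
    """Blur a channel-first image of shape (channels, height, width)."""
    kernel = gaussian_kernel(size, sigma)
    # apply the same 2D kernel to every channel independently;
    # mode="same" zero-pads so the output keeps the input's spatial size
    return jax.vmap(lambda ch: convolve2d(ch, kernel, mode="same"))(image)


image = jnp.ones((3, 16, 16))
blurred = gaussian_blur(image)
```

On a constant image the blur leaves interior pixels unchanged (the kernel sums to one), while pixels near the border darken because of the zero padding. Since the kernel is separable, a faster variant would convolve with the 1D kernel along each axis in turn.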