https://github.com/srush/parallax
- Host: GitHub
- URL: https://github.com/srush/parallax
- Owner: srush
- License: mit
- Created: 2020-05-19T01:49:51.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2020-05-25T18:24:30.000Z (over 4 years ago)
- Last Synced: 2024-08-01T13:31:30.459Z (6 months ago)
- Language: Python
- Size: 64.5 KB
- Stars: 157
- Watchers: 6
- Forks: 4
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-jax - Parallax - Immutable Torch Modules for JAX. (Libraries / New Libraries)
README
# Parallax - Immutable Torch Modules for JAX
Parallax is a prototype for a pure module system for JAX implemented by Sabrina Mielke (@sjmielke) and Sasha Rush (@srush).
Main ideas:
* Make param modules immutable trees.
* Replace all imperative style coding and init.
* Avoid tracking state for most applications by first distributing seeds / globals through the tree.

```python
import jax
# Assumption: `init` refers to jax.nn.initializers; the original README does not show this import.
from jax.nn import initializers as init

from parallax import Module, Parameter, ParamInit


class Dense(Module):
    # All parameter-holders are explicitly declared.
    weight : Parameter
    bias : Parameter

    # Setup happens in __init__: it creates shapes and binds lazy initializers.
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = ParamInit((out_size, in_size), init.xavier_normal())
        self.bias = ParamInit((out_size,), init.normal())

    # Forward is just like standard PyTorch.
    def forward(self, input):
        return self.weight @ input + self.bias

    # Hook for pretty printing
    def extra_repr(self):
        return "%d, %d" % (self.weight.shape[1], self.weight.shape[0])


class Dropout(Module):
    # Arbitrary constants allowed.
    rate : float

    def __init__(self, rate):
        super().__init__()
        self.rate = rate

    def forward(self, input):
        # RNG state is use-once or split. Attached to tree.
        state = self.rng

        if self.mode == "train":
            # As written, `rate` acts as the keep probability.
            keep = jax.random.bernoulli(state, self.rate, input.shape)
            return jax.numpy.where(keep, input / self.rate, 0)
        else:
            return input


class BinaryNetwork(Module):
    # No difference between modules and parameters
    dense1 : Dense
    dense2 : Dense
    dense3 : Dense
    dropout : Dropout

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.dense1 = Dense(input_size, hidden_size)
        self.dense2 = Dense(hidden_size, hidden_size)
        self.dense3 = Dense(hidden_size, 1)
        self.dropout = Dropout(0.2)

    def forward(self, input):
        # Standard usage works out of the box.
        x = jax.numpy.tanh(self.dense1(input))

        # Stochastic modules (have random seed already)
        x = self.dropout(x)

        # Shared params / recurrence only requires split to change RNG
        x = jax.numpy.tanh(self.dense2(x))
        x = jax.numpy.tanh(self.dense2(x))

        return jax.nn.sigmoid(self.dense3(jax.numpy.tanh(x)))[0]


# Setup param tree -> declarative, immutable
layer = BinaryNetwork(5, 10)
print(layer)
print(layer.dense1)

# Initialize parameters -> stateful, hidden
rng = jax.random.PRNGKey(0)
layer = layer.initialized(rng)
print(layer)
print(layer.dense1)

initial_loss = None
for i in range(10):
    # Thread state through parameters -> functor, hidden
    rng, iter_rng = jax.random.split(rng)
    layer = layer.new_state(iter_rng, mode="train")

    # Jax style grad compute -> tree-shaped immutable
    x = jax.numpy.zeros(5)
    loss = layer(x)
    if initial_loss is None:
        initial_loss = loss
    print(loss)
    grad = layer.grad(x)

    # Grad Update -> tree-shaped
    # Note: newer JAX releases removed tree_multimap; jax.tree_util.tree_map
    # accepts multiple trees there.
    layer = jax.tree_util.tree_multimap(lambda p, g: p - 0.3 * g, layer, grad)
```
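
The main ideas listed earlier (parameters as an immutable tree, seeds distributed through the tree up front, tree-shaped gradient updates) can also be sketched in plain JAX without Parallax. This is a minimal illustration under stated assumptions, not Parallax's implementation: the helper names `init_dense`, `distribute_keys`, and `loss_fn`, the nested-dict layout, and the 0.1 init scale are hypothetical choices made for this example.

```python
import jax
import jax.numpy as jnp


def init_dense(key, in_size, out_size):
    """Build the parameter tree (a plain dict) for one dense layer.

    Hypothetical helper for this sketch; not part of Parallax.
    """
    return {
        "weight": 0.1 * jax.random.normal(key, (out_size, in_size)),
        "bias": jnp.zeros((out_size,)),
    }


def distribute_keys(key, tree):
    """Split `key` once and attach one fresh PRNG key to every leaf of `tree`."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    keys = jax.random.split(key, len(leaves))
    return jax.tree_util.tree_unflatten(treedef, list(keys))


def loss_fn(params, x):
    """Tiny two-layer network written as a pure function of the parameter tree."""
    h = jnp.tanh(params["dense1"]["weight"] @ x + params["dense1"]["bias"])
    out = params["dense2"]["weight"] @ h + params["dense2"]["bias"]
    return jax.nn.sigmoid(out)[0]


key = jax.random.PRNGKey(0)
init_key, state_key = jax.random.split(key)
k1, k2 = jax.random.split(init_key)
params = {"dense1": init_dense(k1, 5, 10), "dense2": init_dense(k2, 10, 1)}

# One key per leaf, mirroring the shape of the parameter tree.
leaf_keys = distribute_keys(state_key, params)

# Gradients come back as a tree with the same structure as `params`.
x = jnp.ones(5)
grads = jax.grad(loss_fn)(params, x)

# The update builds a new tree; `params` itself is left untouched.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.3 * g, params, grads)
```

Because every step returns a new tree instead of mutating the old one, the same `params` value can be reused, compared, or passed to transformed functions without hidden state.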