{"id":24886363,"url":"https://github.com/mzguntalan/zephyr","last_synced_at":"2025-05-09T00:04:40.542Z","repository":{"id":258094562,"uuid":"873620580","full_name":"mzguntalan/zephyr","owner":"mzguntalan","description":"Zephyr is a declarative neural network library on top of JAX allowing for easy and fast neural network designing, creation, and manipulation","archived":false,"fork":false,"pushed_at":"2024-12-09T04:27:45.000Z","size":217,"stargazers_count":35,"open_issues_count":0,"forks_count":0,"subscribers_count":5,"default_branch":"main","last_synced_at":"2025-05-09T00:04:34.602Z","etag":null,"topics":["deep-learning","deep-neural-networks","jax","machine-learning","neural-network"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/mzguntalan.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-10-16T13:23:33.000Z","updated_at":"2025-03-02T05:12:22.000Z","dependencies_parsed_at":"2024-12-09T05:32:04.276Z","dependency_job_id":null,"html_url":"https://github.com/mzguntalan/zephyr","commit_stats":null,"previous_names":["mzguntalan/zephyr"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mzguntalan%2Fzephyr","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mzguntalan%2Fzephyr/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mzguntalan%2Fzephyr/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mzguntalan%2Fzephyr/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/mzguntalan","download_url":"https://codeload.github.com/mzguntalan/zephyr/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253166514,"owners_count":21864475,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","deep-neural-networks","jax","machine-learning","neural-network"],"created_at":"2025-02-01T15:14:50.987Z","updated_at":"2025-05-09T00:04:40.522Z","avatar_url":"https://github.com/mzguntalan.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# zephyr\n\n![Version 0.0.12](https://img.shields.io/badge/version-0.0.12-green)\n\nZephyr makes coding your machine learning ideas short, fast, and to the point. Do a lot more while writing less, and keep your code readable.\n\n- fast: it is built on JAX\n- easy: declarative syntax makes code a lot shorter. If you know math, Python, and (JAX) NumPy, then you can write zephyr\n- short: no boilerplate; focus on computations, not on initializing modules.\n- precise: tags make it possible to target groups of weights with nuanced update rules\n- generic but still easy to write: use whatever you want; even (fixed-size) recursions like `f(params, x, i) = f(params, x-1, i+1)` are welcome\n\n## Overview\n\nAt its core, Zephyr is just the `trace` and `validate` functions plus extra utilities. `trace` gives you the initial parameters. 
`validate` checks expressions related to parameters, such as their shapes.\n\nThe mindset for writing zephyr is to think in a functional (FP), declarative manner. Think of composable transformations instead of methods: transformations of both data (arrays) AND functions. The examples below progressively rewrite procedural/imperative code into function transformations. This puts the focus on what the transformation is, rather than on what the arrays become after each step.\n\nBefore we start: a neural network is just a function, usually of `params`, `x`, and hyperparameters: `f(params, x, **hyperparameters)`. To get a function without the hyperparameters (since those never change), we can use Python's `partial` and write `new_f = partial(f, **hyperparameters)`, then call `new_f(params, x)`. However, `partial` can get tedious because it doesn't give you signature hints in your editor. Instead, you can use zephyr's more readable `_` notation. `_` is an alias for `placeholder_hole`; zephyr nets accept it and automatically partialize the function. So we could write `new_f = f(_, _, **hyperparameters)`, where each `_` stands in for a value we pass in later. To make your own function accept `_` holes, use the `flexible` decorator.\n\nOne more thing: this library was heavily inspired by Haiku, so, as in Haiku, `params` is a (nested) dictionary whose leaves are Arrays.\n\n## New Features Not Yet in the README\n\n- tags: a way to update weights more precisely. For example: update weights differently depending on how deep their layer is; update different subnetworks differently; and so on. 
(This is rarely needed, so I don't have an example here, but it is possible in zephyr.)\n\n## Installation\n\n```bash\npip install z-zephyr --upgrade\n```\n\n## Contents\n\n[Examples](#examples) | [Sharp Bits](#gotchas) | [Direction](#direction) | [Motivation](#motivation)\n\n## Examples\u003ca id=\"examples\"\u003e\u003c/a\u003e\n\nLook at the following examples:\n\n0. [Imports](#imports): the imports shared by all examples.\n1. [Encoder and Decoder](#ende): This example shows some of the layers in `zephyr.nets`, and uses zephyr's `chain` function to chain functions (neural networks) together.\n2. [Parameter Creation](#parameters): This example shows how to create and use custom parameters in your functions/nets.\n3. [Dealing with random keys](#thread): This example shows that keys are just Arrays and part of your input. Still, there are some zephyr utilities you can use to transform functions in ways that are useful for dealing with keys.\n\n### Imports \u003ca id=\"imports\"\u003e\u003c/a\u003e\n\nThese are the imports for all the examples:\n\n```python\nfrom zephyr.functools.composition import thread_key, thread_params\nfrom jax import numpy as jnp, random, jit, nn\nfrom zephyr import nets, trace, validate # assumes validate is exported alongside trace\nfrom zephyr.nets import chain\nfrom zephyr.functools.partial import placeholder_hole as _, flexible\n```\n\n### Example: Encoder and Decoder\u003ca id=\"ende\"\u003e\u003c/a\u003e\n\nLet's write an arbitrary encoder and decoder. Notice that we access `params` as if we already had a `params` made. This declarative style is something you will have to get used to. Don't worry: zephyr creates these parameters for you.\n\nFor each of `encoder`, `decoder`, and `model`, we offer 2 versions: one focusing on `x`, and the other building the transformation and then applying it to `x`. These 2 versions are the extremes: the first spans several lines of code, and the second is a single expression (broken across lines). 
 The next examples use other, less extreme rewrites.\n\nEncoder: Notice that these neural networks are used just like normal functions. In each call, we can see everything explicitly: the params, the input(s), and the hyperparameters. This keeps the code short and explicit.\n\n```python\n@flexible\ndef encoder(params, x):\n    x = nets.mlp(params[\"mlp\"], x, [256,256,256]) # b 256\n    x = nets.layer_norm(params[\"ln\"], x, -1)\n    x = nets.branch_linear(params[\"br\"], x, 64) # b 64 256\n\n    for i in range(3):\n        x = nets.conv_1d(params[\"conv\"][i], x, 64, 5)\n        x = nn.relu(x)\n        x = nets.max_pool(params, x, (3,3), 2)\n\n    x = jnp.reshape(x, [x.shape[0], -1]) # b 256\n    x = nets.linear(params[\"linear\"], x, 4) # b 4\n    return x\n\n\n@flexible\ndef encoder(params, x):\n    return chain([\n        nets.mlp(params[\"mlp\"], _, [256, 256, 256]),\n        nets.layer_norm(params[\"ln\"], _, -1),\n        nets.branch_linear(params[\"br\"], _, 64),\n        * [\n            chain([\n                nets.conv_1d(params[\"conv\"][i], _, 64, 5),\n                nn.relu,\n                nets.max_pool(params, _, (3,3), 2),\n            ]) for i in range(3)\n        ],\n        lambda x: jnp.reshape(x, [x.shape[0], -1]),\n        nets.linear(params[\"linear\"], _, 4)\n    ])(x)\n```\n\nDecoder: Notice that a skip (residual) connection can be expressed with the `skip` function/network, which adds the input back to the output: `skip(f)(x) = f(x) + x`.\n\n```python\n@flexible\ndef decoder(params, z):\n    x = nets.mlp(params[\"mlp\"], z, [256,256,256]) # b 256\n    x = nets.branch_linear(params[\"br\"], x, 64) # b 64 256\n\n    for i in range(3):\n        x = nets.multi_head_self_attention(params[\"mha\"][i], x, 64, 5)\n        x = x + nets.mlp(params[\"attn_mlp\"][i], x, [256, 256]) # skip connection\n        x = nets.layer_norm(params[\"attn_ln\"][i], x, -1)\n\n    x = jnp.reshape(x, [x.shape[0], -1]) # b (64 * 256) = b 16384\n    x = 
nets.linear(params[\"linear\"], x, 128) # b 128\n    return x\n\n@flexible\ndef decoder(params, z):\n    return chain([\n        nets.mlp(params[\"mlp\"], _, [256, 256, 256]),\n        nets.branch_linear(params[\"br\"], _, 64),\n        *[\n            chain([\n                nets.multi_head_self_attention(params[\"mha\"][i], _, 64, 5),\n                nets.skip(nets.mlp(params[\"attn_mlp\"][i], _, [256,256])),\n                nets.layer_norm(params[\"attn_ln\"][i], _, -1),\n            ]) for i in range(3)\n        ],\n        lambda x: jnp.reshape(x, [x.shape[0], -1]),\n        nets.linear(params[\"linear\"], _, 128) # b 128\n    ])(z)\n```\n\nModel:\n\n```python\ndef model(params, x):\n    z = encoder(params[\"encoder\"], x)\n    reconstructed_x = decoder(params[\"decoder\"], z)\n    return reconstructed_x\n\ndef model(params, x):\n    return chain([\n        encoder(params[\"encoder\"], _),\n        decoder(params[\"decoder\"], _),\n    ])(x)\n```\n\nTo get an initial `params`, pass the model, a key, and a sample input to `trace`:\n\n```python\nkey = random.PRNGKey(0) # needed to randomly initialize weights\nx = jnp.ones([64, 8]) # sample input batch\n\nparams = trace(model, key, x)\n\nfast_model = jit(model) # `trace` cannot trace a jit-ed function, so pass the non-jit-ed version to `trace`\nsample_outputs = fast_model(params, x) # b 128, since the decoder's last layer outputs 128 features\n```\n\nFor model surgery or study: to use just the encoder, call `z = encoder(params[\"encoder\"], x)`. You can do the same with any function/layer.\n\n### Example: Making your own parameters\u003ca id=\"parameters\"\u003e\u003c/a\u003e\n\nTo illustrate this, we will make our own `linear` layer using zephyr. In line with declarative thinking, we specify what the parameters' shapes should be. Ideally, we would put this in a type annotation, but Python ignores annotations at runtime, so we use zephyr's `validate` instead. 
One main use of `validate` is to specify a parameter's shape, initializer, and other relationships it might have with hyperparameters.\n\n```python\n@flexible\ndef linear(params, x, out_target):\n    validate(params[\"weights\"], (x.shape[-1], out_target))\n    validate(params[\"bias\"], (out_target,))\n    x = x @ params[\"weights\"] + params[\"bias\"]\n    return x\n```\n\nAs mentioned earlier, we show rewrites; which one you use is up to you. This is just to show what is possible. One rewrite resembles a pattern from other FP languages, where you assume some variables exist and bind them afterwards with a `where` keyword, similar to math statements.\n\n```python\n@flexible\ndef linear(params, x, out_target):\n    return (lambda w, b: x @ w + b)(\n        validate(params[\"weights\"], (x.shape[-1], out_target)),\n        validate(params[\"bias\"], (out_target,)),\n    )\n```\n\nNotice the use of `validate` here. `validate` is just a way to enforce \"type annotations\" (dependent types, really, since we are specifying shapes), which have to be specified somewhere for zephyr to trace the function. Otherwise, `validate` acts like the identity function and returns its first argument unchanged.\n\nTo use it, call `trace` to get the parameters, then call the model normally:\n\n```python\nkey = random.PRNGKey(0)\nx = jnp.ones([64, 8]) # sample input batch\nmodel = linear(_, _, 256)\nparams = trace(model, key, x)\nmodel(params, x) # use it like this\n\n# or jit it\nfast_model = jit(model)\nfast_model(params, x)\n```\n\n### Dealing with random keys \u003ca id=\"thread\"\u003e\u003c/a\u003e\n\nRandom keys (RNGs) are usually an unfamiliar concept, since in FP you have to be explicit with them. When you try to hide them with OO abstractions, they tend to stick out like a sore thumb in the end. 
In zephyr, we embrace this and treat the key like anything else: it is just input data.\n\nHere is a simple model using dropout.\n\n```python\ndef model(params, x, key):\n    for i in range(3):\n        x = nets.mlp(params[\"mlp\"][i], x, [256, 256])\n        key, subkey = random.split(key)\n        x = nets.dropout(params, subkey, x, 0.2)\n    x = nn.sigmoid(x)\n    return x\n```\n\nAs with previous examples, we offer rewrites of this, none of which is \"more elegant\". Choose the one that best suits you.\n\nZephyr has a `thread` function with specific variants, such as `thread_key`, `thread_params`, and `thread_identity`, which should cover most cases.\n\nOne rewrite factors out the repeating block into its own function:\n\n```python\ndef block(params, key, x):\n    return chain([\n        nets.mlp(params[\"mlp\"], _, [256,256]),\n        nets.dropout(params, key, _, 0.2)\n    ])(x)\n\ndef model(params, x, key):\n    blocks = thread_params([block for i in range(3)], params) # each block is block(key, x)\n    blocks = thread_key(blocks, key) # each block is block(x)\n\n    return chain(blocks + [nn.sigmoid])(x)\n```\n\nTo use it, call `trace` to get the parameters, then call the model normally:\n\n```python\nkey = random.PRNGKey(0)\nx = jnp.ones([64, 8]) # sample input batch\ntrace_key, apply_key_1, apply_key_2, key = random.split(key, 4) # one key for tracing, two for applying, one to keep\n\nparams = trace(model, trace_key, x, apply_key_1)\nmodel(params, x, apply_key_2) # use it like this\n\n# or jit it\nfast_model = jit(model)\nfast_model(params, x, apply_key_2)\n```\n\n## Sharp Bits \u003ca id=\"gotchas\"\u003e\u003c/a\u003e\n\n1. Docstrings are sparse: I'll add them soon :3.\n2. JAX Sharp Bits: you will sometimes run into JAX sharp bits, such as errors like \"str and int can't be compared\", because Zephyr is a very thin library on top of JAX (it isn't even a thin wrapper). If you have any trouble, open an issue and I'll help.\n3. 
Bugs: if you use it, you will probably run into bugs; if you report them, I'll work on them immediately.\n4. Missing nets: some nets, like RNNs, are not implemented yet; I'll add them when I need them or when they are requested.\n5. Instability: things are still changing a lot. I might reimplement nets/layers differently, rename them, or move them.\n\n## Direction \u003ca id=\"direction\"\u003e\u003c/a\u003e\n\nI would like to provide more FP tooling for Python in zephyr, so that zephyr nets can be written in a more FP style. Zephyr's core, mainly `trace` and `validate`, is probably close to stable; everything else is there to make coding easier or shorter.\n\n## Motivation and Inspiration\u003ca id=\"motivation\"\u003e\u003c/a\u003e\n\nThis library is heavily inspired by [Haiku](https://github.com/google-deepmind/dm-haiku)'s `transform` function, which eventually converts impure functions/class-method calls into a pure function paired with an initialized `params` PyTree. This is my favorite approach so far because it is closest to pure functional programming. Zephyr tries to push this to its simplest form and make a neural network just a function.\n\nThis library is also inspired by other frameworks I have tried in the past: TensorFlow and PyTorch. TensorFlow lets shape inference happen on the first pass of inputs, while PyTorch (before Lazy Modules) needs the input shapes at layer creation. Zephyr aims to be as easy as possible and will strive to always infer shapes from the actual inputs and to use relative axis positions whenever possible.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmzguntalan%2Fzephyr","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmzguntalan%2Fzephyr","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmzguntalan%2Fzephyr/lists"}