# eformer (EasyDel Former)

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.8%2B-blue)](https://www.python.org/)
[![JAX](https://img.shields.io/badge/JAX-Compatible-brightgreen)](https://github.com/google/jax)

**eformer** (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX.
It provides a collection of tools for sharding, custom PyTrees, quantization, mixed precision training, and optimized operations, making it easier to build and scale models efficiently.

## Features

- **Mixed Precision Training (`mpric`)**: Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling.
- **Sharding Utilities (`escale`)**: Tools for efficient sharding and distributed computation in JAX.
- **Custom PyTrees (`jaximus`)**: Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox.
- **Custom Calling (`callib`)**: A tool for custom function calls and direct integration with Triton kernels in JAX.
- **Optimizer Factory**: A flexible factory for creating and configuring optimizers such as AdamW, Adafactor, Lion, and RMSProp.
- **Custom Operations and Kernels**:
  - Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
  - 8-bit and NF4 quantization for memory-efficient models.
  - More operations to be added.
- **Quantization Support**: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.

## Installation

You can install `eformer` via pip:

```bash
pip install eformer
```

## Quick Start

### Mixed Precision Handler with mpric

```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8 (e4m3), output in f32
    use_dynamic_scale=True,
)
```

### Customizing Arrays With ArrayValue

```python
import jax

from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import (
    dequantize_row_q8_0,
    quantize_row_q8_0,
)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")  # float16 input


class Array8B(ArrayValue):
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        # Quantize once on construction; keep the 8-bit weights and their scales.
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        # Rebuild a dense array from the quantized weights and scales.
        return dequantize_row_q8_0(self.weight, self.scale)


qarray = Array8B(array)


# `implicit` lets `sqrt` accept ArrayValue instances such as `qarray`.
@jax.jit
@implicit
def sqrt(x):
    return jax.numpy.sqrt(x)


print(sqrt(qarray))
print(qarray)
```

### Optimizer Factory

```python
from eformer.optimizers import AdamWConfig, OptimizerFactory, SchedulerConfig

# Create an AdamW optimizer with a cosine learning-rate schedule
scheduler_config = SchedulerConfig(scheduler_type="cosine", learning_rate=1e-3, steps=1000)
optimizer, scheduler = OptimizerFactory.create("adamw", scheduler_config, AdamWConfig())
```

### Quantization

```python
import jax

from eformer.quantization import Array8B, ArrayNF4

# Quantize an array to 8-bit
qarray = Array8B(jax.random.normal(jax.random.key(0), (256, 64), "f2"))

# Quantize an array to NF4
n4array = ArrayNF4(jax.random.normal(jax.random.key(0), (256, 64), "f2"), 64)
```

### Advanced Mixed Precision Configuration

```python
import jax.numpy as jnp

from eformer.mpric import LossScaleConfig, Policy, PrecisionHandler

# Create a custom precision policy
policy = Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.float32,
)

# Configure loss scaling
loss_config = LossScaleConfig(
    initial_scale=2**15,
    growth_interval=2000,
    scale_factor=2,
    min_scale=1.0,
)

# Create a handler with the custom configuration
handler = PrecisionHandler(
    policy=policy,
    use_dynamic_scale=True,
    loss_scale_config=loss_config,
)
```

## Contributing

We welcome contributions! Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started.

## License

This project is licensed under the Apache License 2.0.
See the [LICENSE](LICENSE) file for details.
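The `use_dynamic_scale=True` flag and the `LossScaleConfig` fields in the Quick Start examples refer to dynamic loss scaling. As a rough illustration of that general technique (not eformer's implementation, whose exact update rule may differ), the plain-JAX sketch below scales the loss before differentiation and unscales the gradients. It then divides the scale by `scale_factor` (down to `min_scale`) whenever a gradient is non-finite, and multiplies it by `scale_factor` after `growth_interval` consecutive finite steps. `LossScaleState`, `all_finite`, `update_scale`, and `scaled_grad` are hypothetical names chosen for this example; the defaults mirror the `LossScaleConfig` values shown in the Advanced Mixed Precision Configuration example.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class LossScaleState(NamedTuple):
    # Hypothetical state for this sketch; not part of eformer's API.
    scale: jax.Array       # current loss scale (float32)
    good_steps: jax.Array  # consecutive steps with finite gradients (int32)


def all_finite(tree) -> jax.Array:
    # True only if no leaf of the pytree contains inf or nan.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.all(jnp.stack([jnp.all(jnp.isfinite(x)) for x in leaves]))


def update_scale(state, grads_finite, growth_interval=2000, scale_factor=2.0, min_scale=1.0):
    # Finite step: count it, and grow the scale after `growth_interval` good steps in a row.
    # Non-finite step: shrink the scale (never below `min_scale`) and reset the counter.
    good_steps = jnp.where(grads_finite, state.good_steps + 1, 0)
    grow = good_steps >= growth_interval
    scale = jnp.where(
        grads_finite,
        jnp.where(grow, state.scale * scale_factor, state.scale),
        jnp.maximum(state.scale / scale_factor, min_scale),
    )
    return LossScaleState(scale=scale, good_steps=jnp.where(grow, 0, good_steps))


def scaled_grad(loss_fn, params, batch, state):
    # Multiply the loss by the scale so small low-precision gradients do not underflow,
    # then divide the gradients by the same scale to recover their true values.
    grads = jax.grad(lambda p: loss_fn(p, batch) * state.scale)(params)
    grads = jax.tree_util.tree_map(lambda g: g / state.scale, grads)
    finite = all_finite(grads)
    return grads, finite, update_scale(state, finite)


# Usage with a toy least-squares loss.
def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params["w"] - y) ** 2)


params = {"w": jnp.ones((3,), jnp.float32)}
batch = (jnp.ones((4, 3), jnp.float32), jnp.zeros((4,), jnp.float32))
state = LossScaleState(scale=jnp.float32(2.0**15), good_steps=jnp.int32(0))
grads, finite, state = scaled_grad(loss_fn, params, batch, state)
```

In a real mixed-precision step the forward pass would run in `float16` or `bfloat16`; the toy loss here stays in `float32` only to keep the example small. A training loop would typically skip the parameter update on steps where `finite` is `False`.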