# Optax-GaLore

Optax-GaLore is an implementation of the Gradient Low-Rank Projection (GaLore) algorithm for memory-efficient training of Large Language Models (LLMs). This project extends the Optax optimization library with GaLore functionality.

## Features

- Memory-efficient optimization for large-scale models
- Compatible with existing Optax optimizers
- Flexible projection specifications for different layer types
- Support for convolutional layers and other multi-dimensional tensors

## Installation

To install Optax-GaLore, clone this repository and install the required dependencies:

```bash
git clone https://github.com/EleutherAI/optax-galore.git
cd optax-galore
pip install -r requirements.txt
```

To make `optax_galore` importable, run your scripts from the repository root or add the repository to your `PYTHONPATH`.

## Usage

Here's a basic example of how to use Optax-GaLore in your project:

```python
import jax
import optax
import optax_galore.optax_galore as og

# Define your model, its parameters (`params`) and a data loader (`data_loader`)
# ...

# Create a GaLore optimizer
learning_rate = 0.001
rank = 64                    # rank of the low-rank gradient projection
subspace_change_freq = 1000  # number of steps between projection subspace updates

optimizer = og.galore(
    learning_rate=learning_rate,
    rank=rank,
    subspace_change_freq=subspace_change_freq,
)

# Initialize optimizer state
opt_state = optimizer.init(params)

# Define the loss function
def loss_fn(params, batch):
    loss = ...  # your model's scalar loss calculation
    return loss

# Define the update function; jitting it compiles the whole step into one XLA program
@jax.jit
def update(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss

# In your training loop:
for batch in data_loader:
    params, opt_state, loss = update(params, opt_state, batch)
```

This example wraps the loss calculation, gradient computation, optimizer update and parameter update in a single jitted function. Compiling these steps together gives the XLA compiler the opportunity to avoid keeping the full-rank (unprojected) gradients in memory once they have been projected. Whether it actually does so depends on the compiler and is not guaranteed.

### Advanced Usage

For more control over the projection dimensions, you can use the `dimension_pytree` parameter. In the example below, its keys mirror the model's parameter tree.

```python
import optax_galore.optax_galore as og
from optax_galore.optax_galore import ProjectionSpec

# Convolution kernels ('w') get a ProjectionSpec; biases ('b') are None (not projected)
dimension_pytree = {
    'conv1': {'w': ProjectionSpec(2, 3), 'b': None},
    'conv2': {'w': ProjectionSpec(2, 3), 'b': None},
}

# learning_rate, rank and subspace_change_freq as in the basic example
optimizer = og.galore(
    learning_rate=learning_rate,
    rank=rank,
    subspace_change_freq=subspace_change_freq,
    dimension_pytree=dimension_pytree,
)
```

You can also wrap other Optax optimizers with GaLore:

```python
import optax
import optax_galore.optax_galore as og

base_optimizer = optax.adam(learning_rate=0.001)
galore_optimizer = og.galore_wrapper(
    base_optimizer,
    rank=64,
    subspace_change_freq=1000,
)
```

## Testing

To run the tests, use pytest:

```bash
pytest tests/
```

## Contributing

Contributions to Optax-GaLore are welcome! Please feel free to submit a Pull Request.
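To put the "memory-efficient" claim in numbers, the helper below counts the floats Adam keeps for one m × n weight matrix, with and without a rank-r left projection. This is a back-of-the-envelope sketch under assumptions. It counts only Adam's two moment estimates plus the (m, r) projection matrix, and it ignores weights, gradients and activations. `adam_state_floats` is a hypothetical helper, not part of Optax-GaLore. The example uses the `rank = 64` from the Usage example and a 4096 × 4096 weight:

```python
from typing import Optional


def adam_state_floats(m: int, n: int, rank: Optional[int] = None) -> int:
    """Floats of Adam state for one (m, n) weight (hypothetical helper).

    rank=None: plain Adam, two full-size (m, n) moment estimates.
    rank=r:    left projection, two (r, n) moment estimates plus an (m, r) projector.
    """
    if rank is None:
        return 2 * m * n
    return 2 * rank * n + m * rank


full = adam_state_floats(4096, 4096)
low = adam_state_floats(4096, 4096, rank=64)
print(full, low, f"{low / full:.1%}")  # prints: 33554432 786432 2.3%
```

Under these assumptions, the optimizer state for that matrix shrinks to roughly 2% of full-rank Adam. The ratio, r/m + r/(2n), falls further as the matrix grows relative to the chosen rank.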