{"id":21813583,"url":"https://github.com/bobmcdear/flaim","last_synced_at":"2026-03-01T03:32:44.903Z","repository":{"id":64974604,"uuid":"580214072","full_name":"BobMcDear/flaim","owner":"BobMcDear","description":"Flax Image Models - State-of-the-art pre-trained vision backbones for Flax.","archived":false,"fork":false,"pushed_at":"2023-06-01T02:14:03.000Z","size":238,"stargazers_count":19,"open_issues_count":0,"forks_count":1,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-03-22T06:04:02.007Z","etag":null,"topics":["computer-vision","deep-learning","flax","jax","machine-learning"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"gpl-3.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/BobMcDear.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE.md","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-12-20T01:59:54.000Z","updated_at":"2025-02-15T12:20:27.000Z","dependencies_parsed_at":"2024-11-27T15:01:16.107Z","dependency_job_id":null,"html_url":"https://github.com/BobMcDear/flaim","commit_stats":{"total_commits":77,"total_committers":1,"mean_commits":77.0,"dds":0.0,"last_synced_commit":"a03aaa9a162519c77cdad7e468e74d3b4a1b29e0"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BobMcDear%2Fflaim","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BobMcDear%2Fflaim/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BobMcDear%2Fflaim/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/BobMcDear%2Fflaim/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/BobMcDear","download_url":"https://codeload.github.com/BobMcDear/flaim/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248796046,"owners_count":21162895,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["computer-vision","deep-learning","flax","jax","machine-learning"],"created_at":"2024-11-27T14:30:17.463Z","updated_at":"2026-03-01T03:32:44.846Z","avatar_url":"https://github.com/BobMcDear.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Flax Image Models\n\n• \u003cstrong\u003e[Introduction](#introduction)\u003c/strong\u003e\u003cbr\u003e\n• \u003cstrong\u003e[Installation](#installation)\u003c/strong\u003e\u003cbr\u003e\n• \u003cstrong\u003e[Usage](#usage)\u003c/strong\u003e\u003cbr\u003e\n• \u003cstrong\u003e[Examples](#examples)\u003c/strong\u003e\u003cbr\u003e\n• \u003cstrong\u003e[Available Architectures](#available-architectures)\u003c/strong\u003e\u003cbr\u003e\n• \u003cstrong\u003e[Contributing](#contributing)\u003c/strong\u003e\u003cbr\u003e\n• \u003cstrong\u003e[Acknowledgements](#acknowledgements)\u003c/strong\u003e\u003cbr\u003e\n\n\n\n## Introduction\n\nflaim is a library of state-of-the-art pre-trained vision models, plus common deep learning modules in computer vision, for Flax.\nIt exposes a host of diverse image models through a straightforward interface with an emphasis on simplicity, leanness, and readability,\nand supplies lower-level modules 
for designing custom architectures.\n\n## Installation\n\nflaim can be installed through ```pip install flaim```. Note that pip installs the CPU version of JAX; to run your programs on a GPU or TPU, you must [install JAX manually](https://github.com/google/jax#installation).\n\n## Usage\n\n```flaim.get_model``` is the central function of flaim and manages model retrieval. It takes a handful\nof arguments:\n* ```model_name``` (```str```): The name of the desired model.\n* ```pretrained``` (```Union[str, int, bool]```): Every flaim network is accompanied by at least one set of pre-trained\nparameters. For example, those of MaxViT-Small (```maxvit_small```) are ```in1k_224```, ```in1k_384```, and ```in1k_512```,\ncorresponding to parameters trained on ImageNet1K at resolutions 224 x 224, 384 x 384, and 512 x 512, respectively. When ```pretrained``` is\n  * A string, the pre-trained parameters of this name are returned, e.g., ```pretrained = 'in1k_224'```. This is the recommended way to load pre-trained parameters because it is the most explicit.\n  * An integer, a set of parameters trained at this resolution is returned. For instance, ```pretrained = 384``` would return a set of parameters\n  trained at a resolution of 384 x 384. Note that some models might not have parameters trained at the requested resolution, in which case an exception is raised.\n  * ```True```, a default collection of pre-trained parameters is returned. Use this option with care: certain models, such as MaxViT, cannot handle variable resolutions, so\n  the returned pre-trained parameters might not be compatible with your input shapes. 
In such cases, passing the desired resolution to ```pretrained``` is the safer choice.\n  * ```False```, the parameters are randomly initialized.\n\u003cbr\u003e\u003cbr\u003e\n\n* ```input_size``` (```int```): When ```pretrained``` is ```False```, ```input_size``` is the input size the model should expect\nand is used to initialize the parameters. Providing the correct value for ```input_size``` is especially important for fixed-resolution\narchitectures such as ViT.\n* ```jit``` (```bool```): Whether to JIT the model's initialization function. The advantage of JITting the initialization function\nis that, unlike in the default configuration, no actual forward pass with real data is performed. On the other hand, JIT compilation\ncan be time-consuming.\n* ```prng``` (```Optional[jax.random.KeyArray]```): PRNG key used for initializing the model. A PRNG key with a seed of 0 is created when ```prng = None```.\n* ```n_classes``` (```int```): The number of output classes. Its value falls into one of three groups:\n  * 0: The model outputs the raw final feature maps. For instance, a ResNet is composed of a stem and four stages, followed\n  by a head consisting of global average pooling and a fully-connected layer. When the value of this argument is 0, the output of\n  the fourth stage is returned, and the head is discarded.\n  * -1: Every part of the head except the linear layer is applied, and the result is returned. In the ResNet example, the output of\n  the pooling layer is returned.\n  * Positive integers: ```n_classes``` is interpreted as the desired number of output categories.\n\u003cbr\u003e\u003cbr\u003e\n\n```flaim.get_model``` returns the model, its parameters, and the normalization statistics associated with the parameters. When ```pretrained``` is ```False```, these statistics are simply an empty dictionary. 
The snippet below constructs an ImageNet1K-trained ResNet-50 with 10 output classes.\n\n```python\nimport flaim\n\nmodel, variables, norm_stats = flaim.get_model(\n    model_name='resnet50',\n    pretrained='in1k_224',\n    n_classes=10,\n)\n```\n\nPerforming a forward pass with a flaim model works the same as with any other Flax model. However, networks\nthat behave differently during training and inference, e.g., due to batch normalization,\nreceive a ```training``` argument indicating whether the model should run in training mode. Furthermore, as with\nany other Flax module incorporating batch normalization, ```batch_stats``` must be passed to ```mutable```\nto update batch normalization's running statistics during training.\n\n```python\nfrom jax import numpy as jnp\n\n# The input should be normalized using norm_stats beforehand\nx = jnp.ones((2, 224, 224, 3))\n\n# Training: the updated running statistics are returned alongside the output\noutput, new_batch_stats = model.apply(variables, x, training=True, mutable=['batch_stats'])\n# Inference\noutput = model.apply(variables, x, training=False, mutable=False)\n```\n\nFinally, the model's intermediate activations can be captured by passing ```intermediates``` to ```mutable```.\n\n```python\noutput, intermediates = model.apply(variables, x, training=False, mutable=['intermediates'])\n```\n\nIf the model is hierarchical, the entries of ```intermediates``` are the outputs of the network's stages and can be looked up through\n```intermediates['intermediates']['stage_ind']```, where ```ind``` is the index of the desired stage, with 0 being reserved for the stem. 
For isotropic models, the output of every block is returned, accessible via ```intermediates['intermediates']['block_ind']```, where ```ind``` is the index of the desired block and 0 is once again reserved for the stem.\n\nNote that Flax's [```sow```](https://flax-linen.readthedocs.io/en/v0.10.3/api_reference/flax.linen/module.html#flax.linen.Module.sow) API, which flaim uses, appends the intermediate activations to a tuple; that is, if _n_ forward passes are performed, ```intermediates['intermediates']['stage_ind']``` or ```intermediates['intermediates']['block_ind']``` would be tuples of length _n_, with the *i*\u003csup\u003eth\u003c/sup\u003e item corresponding to the *i*\u003csup\u003eth\u003c/sup\u003e forward pass.\n\n## Examples\n\n[```examples/```](https://github.com/bobmcdear/flaim/blob/main/examples/) contains a series of annotated notebooks that use flaim to solve various vision problems, such as object classification.\n\n## Available Architectures\n\nAll available architectures and their pre-trained parameters, plus short descriptions and references, are listed [here](https://github.com/bobmcdear/flaim/blob/main/ARCHITECTURES.md).\n\nAvailable models can also be listed with ```flaim.list_models```, which returns a list of (model name, pre-trained parameters name) pairs, e.g., (```resnet50```, ```in1k_224```), and takes two arguments:\n\n* ```model_pattern``` (```str```): A regex pattern. If it is a non-empty string, only pairs where the model name contains this pattern are returned.\n* ```params_pattern``` (```Union[str, int]```): If ```params_pattern``` is a non-empty string, only pairs where the pre-trained parameters' name contains this regex pattern are returned. 
When it is an integer, only pairs whose pre-trained parameters were trained on images of this resolution are returned.\n\nThis function is demonstrated below.\n\n```python\nimport flaim\n\n# Every model\nprint(flaim.list_models())\n\n# ResNeXt-based networks of depth 50\nprint(flaim.list_models(model_pattern='resnext50'))\n\n# Models trained on ImageNet22K\nprint(flaim.list_models(params_pattern='in22k'))\n\n# ViTs of input size 384 x 384\nprint(flaim.list_models(model_pattern='^vit', params_pattern=384))\n```\n\n## Contributing\n\nCode contributions are currently not accepted; however, there are three other ways to help flaim evolve:\n\n* Bugs: If you discover any bugs, please [open an issue](https://github.com/BobMcDear/flaim/issues/new?assignees=BobMcDear\u0026labels=bug\u0026template=bug_report.md\u0026title=%5BBug+report%5D), specify your system information, and provide a description of the problem along with a reproducible example.\u003cbr\u003e\n* Feature requests: If there are particular architectures or modules that you think would be beneficial additions to flaim, please request them in an [Ideas discussion thread](https://github.com/BobMcDear/flaim/discussions/new?category=ideas).\u003cbr\u003e\n* Questions: If you have questions regarding a model, a segment of code, or anything else, please ask them by creating a [Q\u0026A discussion thread](https://github.com/BobMcDear/flaim/discussions/new?category=q-a).\u003cbr\u003e\n\n\n## Acknowledgements\n\nMany thanks to Ross Wightman for the amazing timm package, which was an inspiration for flaim and has been an indispensable guide during development. Additionally, the pre-trained parameters are stored on the Hugging Face Hub; big thanks to Hugging Face for offering this service gratis. 
Thanks also to Google's [TPU Research Cloud (TRC) program](https://sites.research.google/trc/about/) for providing hardware used to accelerate the development of this project.\n\nReferences for ```flaim.models``` can be found [here](https://github.com/bobmcdear/flaim/blob/main/ARCHITECTURES.md), and those for ```flaim.layers``` are in the source code.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fbobmcdear%2Fflaim","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fbobmcdear%2Fflaim","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fbobmcdear%2Fflaim/lists"}