{"id":18293457,"url":"https://github.com/archinetai/a-unet","last_synced_at":"2025-04-30T22:14:14.078Z","repository":{"id":64912551,"uuid":"578613050","full_name":"archinetai/a-unet","owner":"archinetai","description":"A toolbox that provides hackable building blocks for generic 1D/2D/3D UNets, in PyTorch.","archived":false,"fork":false,"pushed_at":"2023-06-12T22:27:06.000Z","size":23,"stargazers_count":85,"open_issues_count":1,"forks_count":9,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-04-30T22:14:09.118Z","etag":null,"topics":["deep-learning","machine-learning","unet"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/archinetai.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-12-15T13:21:45.000Z","updated_at":"2025-02-17T11:17:04.000Z","dependencies_parsed_at":"2024-06-21T17:50:54.717Z","dependency_job_id":"75877255-be51-46d3-b5af-011e535bd009","html_url":"https://github.com/archinetai/a-unet","commit_stats":{"total_commits":32,"total_committers":2,"mean_commits":16.0,"dds":0.0625,"last_synced_commit":"e0933e73a60cf9e7d42fcb68c173b23d903f258d"},"previous_names":[],"tags_count":16,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/archinetai%2Fa-unet","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/archinetai%2Fa-unet/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/archinetai%2Fa-unet/releases","manifests_url":"https://rep
os.ecosyste.ms/api/v1/hosts/GitHub/repositories/archinetai%2Fa-unet/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/archinetai","download_url":"https://codeload.github.com/archinetai/a-unet/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":251789618,"owners_count":21644086,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","machine-learning","unet"],"created_at":"2024-11-05T14:24:39.123Z","updated_at":"2025-04-30T22:14:14.056Z","avatar_url":"https://github.com/archinetai.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# A-UNet\n\nA toolbox that provides hackable building blocks for generic 1D/2D/3D UNets, in PyTorch.\n\n## Install\n```bash\npip install a-unet\n```\n\n[![PyPI - Python Version](https://img.shields.io/pypi/v/a-unet?style=flat\u0026colorA=black\u0026colorB=black)](https://pypi.org/project/a-unet/)\n\n\n## Usage\n\n### Basic UNet\n\n\u003cdetails\u003e \u003csummary\u003e (Code): A convolutional only UNet generic to any dimension. 
\u003c/summary\u003e\n\n```py\nfrom typing import List\n\nimport torch\nfrom a_unet import T, Downsample, Repeat, ResnetBlock, Skip, Upsample\nfrom torch import nn\n\ndef UNet(\n    dim: int,\n    in_channels: int,\n    channels: List[int],\n    factors: List[int],\n    blocks: List[int],\n) -\u003e nn.Module:\n    # Check lengths\n    n_layers = len(channels)\n    assert n_layers == len(factors) and n_layers == len(blocks), \"lengths must match\"\n\n    # Resnet stack\n    def Stack(channels: int, n_blocks: int) -\u003e nn.Module:\n        # The T function creates a type template that pre-fills the given parameters; the module is built when the template is called\n        Block = T(ResnetBlock)(dim=dim, in_channels=channels, out_channels=channels)\n        resnet = Repeat(Block, times=n_blocks)\n        return resnet\n\n    # Build UNet recursively\n    def Net(i: int) -\u003e nn.Module:\n        if i == n_layers: return nn.Identity()\n        in_ch, out_ch = (channels[i - 1] if i \u003e 0 else in_channels), channels[i]\n        factor = factors[i]\n        # Wrap the modules in a skip connection that merges both paths with torch.add\n        return Skip(torch.add)(\n            Downsample(dim=dim, factor=factor, in_channels=in_ch, out_channels=out_ch),\n            Stack(channels=out_ch, n_blocks=blocks[i]),\n            Net(i + 1),\n            Stack(channels=out_ch, n_blocks=blocks[i]),\n            Upsample(dim=dim, factor=factor, in_channels=out_ch, out_channels=in_ch),\n        )\n    return Net(0)\n```\n\n\u003c/details\u003e\n\n```py\nunet = UNet(\n  dim=2,\n  in_channels=8,\n  channels=[256, 512],\n  factors=[2, 2],\n  blocks=[2, 2]\n)\nx = torch.randn(1, 8, 16, 16)\ny = unet(x) # [1, 8, 16, 16]\n```\n\n\n### ApeX UNet\n\n\u003cdetails\u003e \u003csummary\u003e (Code): ApeX is a UNet template complete with tools for easy customizability. 
The following example UNet includes multiple features: (1) custom item arrangement for resnets, modulation, attention, and cross attention, (2) custom skip connection with concatenation, (3) time conditioning (usually used for diffusion), (4) classifier free guidance. \u003c/summary\u003e\n\n```py\nfrom typing import Sequence, Optional, Callable\n\nfrom a_unet import TimeConditioningPlugin, ClassifierFreeGuidancePlugin\nfrom a_unet.apex import (\n    XUNet,\n    XBlock,\n    ResnetItem as R,\n    AttentionItem as A,\n    CrossAttentionItem as C,\n    ModulationItem as M,\n    SkipCat\n)\n\ndef UNet(\n    dim: int,\n    in_channels: int,\n    channels: Sequence[int],\n    factors: Sequence[int],\n    items: Sequence[int],\n    attentions: Sequence[int],\n    cross_attentions: Sequence[int],\n    attention_features: int,\n    attention_heads: int,\n    embedding_features: Optional[int] = None,\n    skip_t: Callable = SkipCat,\n    resnet_groups: int = 8,\n    modulation_features: int = 1024,\n    embedding_max_length: int = 0,\n    use_classifier_free_guidance: bool = False,\n    out_channels: Optional[int] = None,\n):\n    # Check lengths\n    num_layers = len(channels)\n    sequences = (channels, factors, items, attentions, cross_attentions)\n    assert all(len(sequence) == num_layers for sequence in sequences)\n\n    # Define UNet type with time conditioning and CFG plugins\n    UNet = TimeConditioningPlugin(XUNet)\n    if use_classifier_free_guidance:\n        UNet = ClassifierFreeGuidancePlugin(UNet, embedding_max_length)\n\n    return UNet(\n        dim=dim,\n        in_channels=in_channels,\n        out_channels=out_channels,\n        blocks=[\n            XBlock(\n                channels=channels,\n                factor=factor,\n                items=([R, M] + [A] * n_att + [C] * n_cross) * n_items,\n            ) for channels, factor, n_items, n_att, n_cross in zip(*sequences)\n        ],\n        skip_t=skip_t,\n        
attention_features=attention_features,\n        attention_heads=attention_heads,\n        embedding_features=embedding_features,\n        modulation_features=modulation_features,\n        resnet_groups=resnet_groups\n    )\n```\n\n\u003c/details\u003e\n\n```py\nimport torch\n\nunet = UNet(\n    dim=2,\n    in_channels=2,\n    channels=[128, 256, 512, 1024],\n    factors=[2, 2, 2, 2],\n    items=[2, 2, 2, 2],\n    attentions=[0, 0, 0, 1],\n    cross_attentions=[1, 1, 1, 1],\n    attention_features=64,\n    attention_heads=8,\n    embedding_features=768,\n    use_classifier_free_guidance=False\n)\nx = torch.randn(2, 2, 64, 64)\ntime = [0.2, 0.5] # one time value per batch element\nembedding = torch.randn(2, 512, 768) # [batch, tokens, embedding_features]\ny = unet(x, time=time, embedding=embedding) # [2, 2, 64, 64]\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Farchinetai%2Fa-unet","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Farchinetai%2Fa-unet","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Farchinetai%2Fa-unet/lists"}