{"id":13444507,"url":"https://github.com/MzeroMiko/mamba-mini","last_synced_at":"2025-03-20T18:33:04.680Z","repository":{"id":221006348,"uuid":"753075277","full_name":"MzeroMiko/mamba-mini","owner":"MzeroMiko","description":"An efficient pytorch implementation of selective scan in one file, works with both cpu and gpu, with corresponding mathematical derivation. It is probably the code which is the most close to selective_scan_cuda in mamba.","archived":false,"fork":false,"pushed_at":"2024-03-04T10:37:07.000Z","size":1220,"stargazers_count":61,"open_issues_count":4,"forks_count":0,"subscribers_count":3,"default_branch":"main","last_synced_at":"2024-08-01T04:02:11.120Z","etag":null,"topics":["efficient","mamba","pytorch","selective-scan"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/MzeroMiko.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null}},"created_at":"2024-02-05T12:24:37.000Z","updated_at":"2024-07-25T16:02:17.000Z","dependencies_parsed_at":null,"dependency_job_id":"5a85b67f-c49e-435b-b0bb-4cc23daa9f6f","html_url":"https://github.com/MzeroMiko/mamba-mini","commit_stats":null,"previous_names":["mzeromiko/mamba-mini"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/MzeroMiko%2Fmamba-mini","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/MzeroMiko%2Fmamba-mini/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/MzeroMiko%2Fmamba-mini/releases","manifests_url":"https://repos.ecosyste.ms/api/
v1/hosts/GitHub/repositories/MzeroMiko%2Fmamba-mini/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/MzeroMiko","download_url":"https://codeload.github.com/MzeroMiko/mamba-mini/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":221792892,"owners_count":16881289,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["efficient","mamba","pytorch","selective-scan"],"created_at":"2024-07-31T04:00:27.797Z","updated_at":"2024-10-28T06:31:09.171Z","avatar_url":"https://github.com/MzeroMiko.png","language":"Python","funding_links":[],"categories":["Input-dependent gating."],"sub_categories":[],"readme":"# mamba-mini\nAn efficient single-file PyTorch implementation of selective scan that runs on both CPU and GPU, with the corresponding mathematical derivation. It is probably the code closest to `selective_scan_cuda` in mamba.\n\n### update!\n* **`20240304: New implementation with new derivations!`** We now support a new chunk-parallel approach to selective scan: [`selective_scan_easyv3`](./test_selective_scan_easy.py). It is faster than `selective_scan_easy` when `d_state=1`, but still slower than `mamba_ssm` with CUDA. We plan to implement it in `triton` and test its speed in the future. 
\n\n### mathematical derivation of the `chunk-naive version`\nThe code is in [`selective_scan_easy`](./test_selective_scan_easy.py) and [`SelectiveScanEasy`](./test_selective_scan_easy.py).\n![image](./assets/derivation.png)\n\n### mathematical derivation of the `chunk-parallel version`\nThis is the chunk-parallel version of selective scan, with support for several different branches.\nThe code is in [`selective_scan_easyv3`](./test_selective_scan_easy.py).\n![image](./assets/derivation_general.png)\n![image](./assets/derivation_wdk.png)\n![image](./assets/derivation_wdv.png)\n![image](./assets/derivation_dk1.png)\n\n### naive code\n```python\nimport torch\ndef selective_scan_easy(us, dts, As, Bs, Cs, Ds, delta_bias=None, delta_softplus=False, return_last_state=False, chunksize=64):\n    \"\"\"\n    # B: batch_size, G: groups, D: dim, N: state dim, L: seqlen\n    us: B, G * D, L \n    dts: B, G * D, L\n    As: G * D, N\n    Bs: B, G, N, L\n    Cs: B, G, N, L\n    Ds: G * D\n    delta_bias: G * D\n    # chunksize can be any value. But as chunksize grows, hs may become NaN, since exp(sum(delta) A) gets very small\n    \"\"\"\n    def selective_scan_chunk(us, dts, As, Bs, Cs, hprefix):\n        \"\"\"\n        partial(h) / partial(t) = Ah + Bu; y = Ch + Du;\n        =\u003e partial(h*exp(-At)) / partial(t) = Bu*exp(-At);\n        =\u003e h_t = h_0*exp(At) + sum_{0}_{t}_{Bu*exp(A(t-v)) dv};\n        =\u003e h_b = exp(A(dt_a + ... + dt_{b-1})) * (h_a + sum_{a}_{b-1}_{Bu*exp(-A(dt_a + ... 
+ dt_i)) dt_i});\n           y_i = C_i*h_i + D*u_i\n\n        us, dts: (L, B, G, D) # L is chunk_size\n        As: (G, D, N)\n        Bs, Cs: (L, B, G, N)\n        Ds: (G, D)\n        hprefix: (B, G, D, N)\n        \"\"\"\n        # cumulative time within the chunk, then exp(A * t_i) per step\n        ts = dts.cumsum(dim=0)\n        Ats = torch.einsum(\"gdn,lbgd-\u003elbgdn\", As, ts).exp()\n        # rescale Ats by its last step before dividing\n        scale = Ats[-1].detach()\n        rAts = Ats / scale\n        duts = dts * us\n        dtBus = torch.einsum(\"lbgd,lbgn-\u003elbgdn\", duts, Bs)\n        hs_tmp = rAts * (dtBus / rAts).cumsum(dim=0)\n        # add the state carried over from the previous chunk\n        hs = hs_tmp + Ats * hprefix.unsqueeze(0)\n        ys = torch.einsum(\"lbgn,lbgdn-\u003elbgd\", Cs, hs)\n        return ys, hs\n    \n    inp_dtype = us.dtype\n    has_D = Ds is not None\n\n    dts = dts.float()\n    if delta_bias is not None:\n        dts = dts + delta_bias.view(1, -1, 1).float()\n    if delta_softplus:\n        dts = torch.nn.functional.softplus(dts)\n    \n    if len(Bs.shape) == 3:\n        Bs = Bs.unsqueeze(1)\n    if len(Cs.shape) == 3:\n        Cs = Cs.unsqueeze(1)\n    B, G, N, L = Bs.shape\n    us = us.view(B, G, -1, L).permute(3, 0, 1, 2).float()\n    dts = dts.view(B, G, -1, L).permute(3, 0, 1, 2).float()\n    As = As.view(G, -1, N).float()\n    Bs = Bs.permute(3, 0, 1, 2).float()\n    Cs = Cs.permute(3, 0, 1, 2).float()\n    Ds = Ds.view(G, -1).float() if has_D else None\n    D = As.shape[1]\n    \n    oys = []\n    # ohs = []\n    hprefix = us.new_zeros((B, G, D, N), dtype=torch.float)\n    # cover all L steps; the last chunk may be shorter than chunksize\n    for i in range(0, L, chunksize):\n        ys, hs = selective_scan_chunk(\n            us[i:i + chunksize], dts[i:i + chunksize], \n            As, Bs[i:i + chunksize], Cs[i:i + chunksize], hprefix, \n        )\n        oys.append(ys)\n        # ohs.append(hs)\n        hprefix = hs[-1]\n\n    oys = torch.cat(oys, dim=0)\n    # ohs = torch.cat(ohs, dim=0)\n    if has_D:\n        oys = oys + Ds * us\n    oys = oys.permute(1, 2, 3, 0).view(B, -1, L)\n    oys = oys.to(inp_dtype)\n    # 
hprefix = hprefix.to(inp_dtype)\n\n    return oys if not return_last_state else (oys, hprefix.view(B, G * D, N))\n\n```\n\n### to test\n```bash\npytest test_selective_scan.py\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FMzeroMiko%2Fmamba-mini","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FMzeroMiko%2Fmamba-mini","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FMzeroMiko%2Fmamba-mini/lists"}