{"id":28398181,"url":"https://github.com/nx-ai/mlstm_kernels","last_synced_at":"2025-06-28T14:31:41.580Z","repository":{"id":267575961,"uuid":"901383965","full_name":"NX-AI/mlstm_kernels","owner":"NX-AI","description":"Tiled Flash Linear Attention library for fast and efficient mLSTM Kernels.","archived":false,"fork":false,"pushed_at":"2025-05-18T20:56:10.000Z","size":11147,"stargazers_count":56,"open_issues_count":6,"forks_count":3,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-06-01T13:01:35.722Z","etag":null,"topics":["deep-learning","llm","rnn","triton-kernels","xlstm"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2503.14376","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/NX-AI.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2024-12-10T15:00:21.000Z","updated_at":"2025-05-18T20:56:13.000Z","dependencies_parsed_at":"2024-12-11T06:30:48.045Z","dependency_job_id":"061e0eb3-ed8b-4538-8f52-bfc81fffce9a","html_url":"https://github.com/NX-AI/mlstm_kernels","commit_stats":null,"previous_names":["nx-ai/mlstm_kernels"],"tags_count":5,"template":false,"template_full_name":null,"purl":"pkg:github/NX-AI/mlstm_kernels","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/NX-AI%2Fmlstm_kernels","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/NX-AI%2Fmlstm_kernels/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/NX-AI%2Fmlstm_kernels/releases","manifests_url":"https://repos.ecosyste.m
s/api/v1/hosts/GitHub/repositories/NX-AI%2Fmlstm_kernels/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/NX-AI","download_url":"https://codeload.github.com/NX-AI/mlstm_kernels/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/NX-AI%2Fmlstm_kernels/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":262444795,"owners_count":23312229,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","llm","rnn","triton-kernels","xlstm"],"created_at":"2025-06-01T04:12:02.356Z","updated_at":"2025-06-28T14:31:41.574Z","avatar_url":"https://github.com/NX-AI.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Tiled Flash Linear Attention - mLSTM Kernels\n\n\u003cimg src=\"./res/Figure_1-7.svg\" width=\"350px\" alt=\"xLSTM Figure 1\"\u003e \u003cimg src=\"./res/Figure 2 - paper.svg\" width=\"400px\" alt=\"xLSTM Figure 2\"\u003e\n\n\u003ePaper: [https://arxiv.org/abs/2503.14376](https://arxiv.org/abs/2503.14376)\n\u003e\n\u003eAuthors: Maximilian Beck, Korbinian Pöppel, Phillip Lippe, Sepp Hochreiter\n\n\n## About\nThis library provides fast and efficient Triton kernels for mLSTM training and inference.\nThe chunkwise-parallel mLSTM kernels are built on Tiled Flash Linear Attention (TFLA).\n\nThis repository also contains an easy-to-extend library for runtime benchmarks, which we use to benchmark our mLSTM kernels as well as full mLSTM Hugging Face models.\n\n## mLSTM Kernel Library Overview\n\nAt its core, the 
mLSTM Kernel library contains several implementations of the mLSTM in JAX and PyTorch, as well as kernels in Triton,\nwhich form three top-level modules within the `mlstm_kernels` library:\n\n- `jax`: Contains JAX-native implementations of the mLSTM, as well as JAX Triton integrations.\n- `torch`: Contains PyTorch-native implementations of the mLSTM, as well as the Triton integrations for PyTorch. It also contains the configurable PyTorch backend module for simple integration of the mLSTM kernels into your models (see below for further details).\n- `triton`: Contains the Triton kernels for the mLSTM, as well as kernel launch parameter heuristics.\n\nThe `utils` module contains code for unit tests, additional analyses (such as the transfer behavior analysis from the TFLA paper), and the benchmark library, which is discussed in detail below.\n\nEach of the three top-level modules contains three different types of implementations and kernels for the mLSTM:\n\n- `chunkwise`: Chunkwise kernels that process chunks of the sequence in parallel. These include the TFLA kernels.\n- `parallel`: Parallel kernels that process the whole sequence in parallel (like attention). The runtime of these kernels scales quadratically with sequence length.\n- `recurrent`: Recurrent step kernels for text generation during inference.\n\n## Benchmark of TFLA mLSTM kernels\n\nRuntime comparison of mLSTM chunkwise kernels against other baselines on an NVIDIA H100 GPU with a constant number of tokens.\nThis means that as we increase the sequence length on the x-axis, we proportionally decrease the batch size to keep the overall number of tokens constant. 
This is the same setup as used in, for example, FlashAttention 3.\n\n![Kernel Benchmark](./res/plot_tfla_mlstm_kernel_benchmark--paper-rerun.svg)\n\n**Left**: Forward pass. **Right**: Forward and backward pass.\n\n### Kernel description\n\nWe benchmark two mLSTM versions: the mLSTM with exponential input gate (mLSTMexp) and the mLSTM with sigmoid input gate (mLSTMsig).\n\n- **mLSTMexp (limit chunk)**: mLSTMexp kernel with limited chunk size (`chunk_size=64`).\n- **mLSTMexp (TFLA XL chunk)**: mLSTMexp TFLA kernel with unlimited chunk size (in this benchmark `chunk_size=128`).\n- **mLSTMsig (TFLA XL chunk)**: mLSTMsig TFLA kernel with unlimited chunk size (in this benchmark `chunk_size=128`).\n\n\u003e In the following, `limit_chunk` refers to chunkwise kernels with a limited chunk size, and `xl_chunk` refers to TFLA kernels.\n\nFor more details, we refer to the TFLA paper.\n\n\n## Installation\n\nYou can find the conda environment files in the `envs/` folder. We recommend using the latest file, i.e., `environment_pt251cu124.yaml`.\n\nThen you can install the mLSTM kernels via pip (`pip install mlstm_kernels`)\nor by cloning the repository.\n\n\n## How to use and integrate our mLSTM kernels\n\nIn this library, we provide PyTorch, JAX, and Triton implementations of the mLSTM.\nFor the Triton kernels, we provide wrappers in PyTorch and JAX.\n\nThere are three options to use our implementations and kernels:\n\n### Option 1 (Recommended): Use via backend module\nThis is the recommended option if you want to use our mLSTM kernels in your own (language) model.\nThe backend module is implemented in `mlstm_kernels/torch/backend_module.py` and provides a configurable wrapper around all our mLSTM implementations and kernels.\n\n\u003eNote: This is also how these kernels are integrated into our official implementation of the xLSTM 7B model (see [xLSTM 7B model.py](https://github.com/NX-AI/xlstm/blob/main/xlstm/xlstm_large/model.py)).\n\nIt allows switching between training and inference mode and 
automatically selects the respective kernels.\n\nFor example, the following code snippet configures the `mLSTMBackend` to use our TFLA mLSTMexp kernel:\n\n```python\nimport torch\nfrom mlstm_kernels.torch.backend_module import mLSTMBackend, mLSTMBackendConfig\n\n# we use the mLSTMexp TFLA kernel\n# we also configure the Triton step kernel for inference\nmlstm_backend_config = mLSTMBackendConfig(\n    chunkwise_kernel=\"chunkwise--triton_xl_chunk\",\n    sequence_kernel=\"native_sequence__triton\",\n    step_kernel=\"triton\",\n    chunk_size=256,\n    return_last_states=False,\n)\n\nmlstm_backend = mLSTMBackend(mlstm_backend_config)\n\n# run the backend\nDEVICE = torch.device(\"cuda\")\nDTYPE = torch.bfloat16\n# B: batch size, S: sequence length, NH: number of heads\n# DHQK: head dim of q and k, DHHV: head dim of v and h\nB = 2\nS = 512\nDHQK = 128\nDHHV = 256\nNH = 4\n\n# create input tensors\ntorch.manual_seed(1)\nmatQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\nmatK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\nmatV = torch.randn((B, NH, S, DHHV), dtype=DTYPE, device=DEVICE)\nvecI = torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)\nvecF = 3.0 + torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)\n\nmatH = mlstm_backend(q=matQ, k=matK, v=matV, i=vecI, f=vecF)\n```\n\n**Quickstart**: Have a look at the demo notebook `demo/integrate_mlstm_via_backend_module_option1.ipynb`.\n\n\n### Option 2: Direct import\n\nIf you want to use a specific kernel directly, you can import it from the respective module.\nThe following code snippet imports the TFLA mLSTMexp kernel and runs a forward pass.\n\n```python\nimport torch\n# directly import the mLSTMexp TFLA kernel\nfrom mlstm_kernels.torch.chunkwise.triton_xl_chunk import mlstm_chunkwise__xl_chunk\n\n# run the kernel\nDEVICE = torch.device(\"cuda\")\nDTYPE = torch.bfloat16\nB = 2\nS = 512\nDHQK = 128\nDHHV = 256\nNH = 4\n\ntorch.manual_seed(1)\nmatQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\nmatK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)\nmatV = torch.randn((B, NH, S, DHHV), dtype=DTYPE, device=DEVICE)\nvecI = torch.randn((B, NH, S), dtype=DTYPE, 
device=DEVICE)\nvecF = 3.0 + torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE)\n\nmatH1 = mlstm_chunkwise__xl_chunk(\n    q=matQ, k=matK, v=matV, i=vecI, f=vecF, return_last_states=False, chunk_size=256\n)\n```\n\n### Option 3: Select the kernel via the kernel specifier\n\nYou can also get a specific kernel function via its kernel specifier.\n\nFirst, display all available kernels via `get_available_mlstm_kernels()`.\nThis lists all kernels that can be used for training and have a compatible function signature, so that they can be used interchangeably.\n\n```python\n# display all available mLSTM chunkwise and parallel kernels\nfrom mlstm_kernels.torch import get_available_mlstm_kernels\n\nget_available_mlstm_kernels()\n```\n```\n['chunkwise--native_autograd',\n 'chunkwise--native_custbw',\n 'chunkwise--triton_limit_chunk',\n 'chunkwise--triton_xl_chunk',\n 'chunkwise--triton_xl_chunk_siging',\n 'parallel--native_autograd',\n 'parallel--native_custbw',\n 'parallel--native_stablef_autograd',\n 'parallel--native_stablef_custbw',\n 'parallel--triton_limit_headdim',\n 'parallel--native_siging_autograd',\n 'parallel--native_siging_custbw']\n```\n\nThen select a kernel via `get_mlstm_kernel()`:\n\n```python\n# select the kernel\nfrom mlstm_kernels.torch import get_mlstm_kernel\n\nmlstm_chunkwise_xl_chunk = get_mlstm_kernel(\"chunkwise--triton_xl_chunk\")\n\n# reuse the input tensors from option 2\nmatH2 = mlstm_chunkwise_xl_chunk(\n    q=matQ, k=matK, v=matV, i=vecI, f=vecF, return_last_states=False, chunk_size=256\n)\n\n# compare with the output matH1 of option 2\ntorch.allclose(matH1, matH2, atol=1e-3, rtol=1e-3) # True\n```\n\n**Quickstart for options 2 and 3**: Have a look at the demo notebook `demo/integrate_mlstm_via_direct_import_option2and3.ipynb`.\n\n\n### Using the JAX wrappers\n\nThe JAX module `mlstm_kernels.jax` mirrors the PyTorch module `mlstm_kernels.torch` and can be used in the same way as the PyTorch kernels in option 2.\n\n\u003c!-- We also aim to provide a backend module for Flax soon. 
--\u003e\n\n## Benchmark Library\n\nThe module `mlstm_kernels.utils.benchmark` contains a configurable library for benchmarking the runtime and GPU memory usage of kernels or models.\nWe use this library for all benchmarks in the TFLA paper and the xLSTM 7B paper.\n\n### Overview\n\n**Step 1:** To begin, have a look at `mlstm_kernels/utils/benchmark/benchmarks/interface.py`.\n\nAt the core of the benchmark library is the `BenchmarkInterface` dataclass, the abstract base class that every new benchmark should inherit from.\nThe `BenchmarkInterface` dataclass holds generic benchmark parameters, defines the `setup_benchmark` function that must be overridden for every specific benchmark, and defines `benchmark_fn`, the function that is benchmarked.\nTo run the benchmark, the `BenchmarkInterface` provides the method `run_benchmark`.\n\nThe `BenchmarkCreator` defines the benchmark collection, i.e., the collection of benchmarks that can be configured and run together via a single config.\nTo create a new benchmark collection with several benchmarks, one has to implement a new `BenchmarkCreator`.\nThis is a function that takes as input a `KernelSpec` dataclass (containing the specification for the benchmark class) and a parameter dict with overrides. It then creates and returns the specified benchmark.\n\n**Step 2:** Next, have a look at `mlstm_kernels/utils/benchmark/param_handling.py` to understand how the benchmarks are configured through a unified config.\n\nWe use the dataclass `KernelSpec` to provide a unified interface to our kernel benchmarks. The `kernel_name` must be a unique specifier within a benchmark collection. The `additional_params` field contains parameters that are overridden in the respective `BenchmarkInterface` class.\n\nOne level above is the `BenchmarkConfig` dataclass. 
This config class makes it possible to configure sweeps over multiple `KernelSpec` dataclasses.\n\n**Step 3:** Finally, have a look at `mlstm_kernels/utils/benchmark/run_benchmark.py` and a corresponding benchmark script, e.g., `scripts/run_training_kernel_benchmarks.py`.\n\nThe \"benchmark loops\" are implemented in `run_benchmark.py`. They take as input a `BenchmarkConfig` and a `BenchmarkCreator` and run every benchmark specified in the kernel specs with every parameter combination.\n\nThe `run_and_record_benchmarks()` function executes these loops and records the results to disk as .csv files and plots.\n\nFinally, we create scripts that collect several configured benchmarks, which can then be run with different arguments; see, e.g., `scripts/run_training_kernel_benchmarks.py`.\n\nYou should now understand the structure of our benchmark suites, i.e., collections of benchmarks that are run together.\nIn this repository, we create several benchmark suites, for example the kernel benchmarks for the TFLA paper and the model benchmarks for the xLSTM 7B paper.\nThese are implemented in `mlstm_kernels/utils/benchmark/benchmarks/training_kernel_benchmarks.py` and `mlstm_kernels/utils/benchmark/benchmarks/huggingface_model_benchmark.py`, respectively.\n\n**Quickstart:** For a quick start, have a look at the demo notebook `demo/kernel_speed_benchmark.ipynb`.\n\n### Running kernel benchmarks\n\nThe following command benchmarks the mLSTM kernels shown in the figure above.\nNote that you need a GPU with large memory to fit the long sequences and the large embedding dimension of 4096 used for a 7B model.\n\n```bash\nPYTHONPATH=. 
python scripts/run_training_kernel_benchmarks.py --consttoken_benchmark mlstm_triton --folder_suffix \"mlstm_bench\" --num_heads 16 --half_qkdim 1\n```\n\nThis creates a new subfolder in `outputs_kernel_benchmarks/` that contains the results.\n\n## Running the unit tests\n\nThe unit tests cross-check the different kernel implementations for numerical deviations across different dtypes.\nYou can run all of them with the following commands:\n\n```bash\npytest -s tests/torch\n# make sure you are in a JAX GPU environment\npytest -s tests/jax\n```\n\nThe `-s` flag disables pytest's output capturing, so you see the results directly on the command line.\nEach test logs its outputs to a new folder, named after the current timestamp, in the `test_outputs/` directory.\n\nNote: The JAX tests have only been run on NVIDIA H100 GPUs.\n\n## Citation\n\nPlease cite our papers if you use this codebase or otherwise find our work valuable:\n\n```\n@article{beck:25tfla,\n  title        = {{Tiled Flash Linear Attention}: More Efficient Linear RNN and xLSTM Kernels},\n  author       = {Maximilian Beck and Korbinian Pöppel and Phillip Lippe and Sepp Hochreiter},\n  year         = {2025},\n  volume       = {2503.14376},\n  journal      = {arXiv},\n  primaryclass = {cs.LG},\n  url          = {https://arxiv.org/abs/2503.14376}\n}\n\n@article{beck:25xlstm7b,\n  title        = {{xLSTM 7B}: A Recurrent LLM for Fast and Efficient Inference},\n  author       = {Maximilian Beck and Korbinian Pöppel and Phillip Lippe and Richard Kurle and Patrick M. 
Blies and Günter Klambauer and Sebastian Böck and Sepp Hochreiter},\n  year         = {2025},\n  volume       = {2503.13427},\n  journal      = {arXiv},\n  primaryclass = {cs.LG},\n  url          = {https://arxiv.org/abs/2503.13427}\n}\n\n@inproceedings{beck:24xlstm,\n  title        = {xLSTM: Extended Long Short-Term Memory},\n  author       = {Maximilian Beck and Korbinian Pöppel and Markus Spanring and Andreas Auer and Oleksandra Prudnikova and Michael Kopp and Günter Klambauer and Johannes Brandstetter and Sepp Hochreiter},\n  booktitle    = {Thirty-eighth Conference on Neural Information Processing Systems},\n  year         = {2024},\n  url          = {https://arxiv.org/abs/2405.04517}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fnx-ai%2Fmlstm_kernels","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fnx-ai%2Fmlstm_kernels","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fnx-ai%2Fmlstm_kernels/lists"}