# yet-another-retnet

A simple but robust PyTorch implementation of RetNet from [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf).

> Also see Microsoft's original implementation: [RetNet @ microsoft/torchscale](https://github.com/microsoft/torchscale/blob/main/torchscale/architecture/retnet.py)
>
> Ultimately, their implementation is the ground truth. I have tried to make my implementation as readable and well-documented as possible, while still being consistent with the original.
> My version also includes full type annotations and a robust set of unit tests.
>
> I'm obviously biased, but I find the untyped, config-driven approach in [microsoft/torchscale](https://github.com/microsoft/torchscale/tree/main) clunky and difficult to adapt to other use cases.

<img src="doc/retnet-scaling.jpeg" alt="compare-attention-mechanisms" width="600"/>

### TODO

- [x] Equivalent **parallel** and **recurrent** retention methods. See: [retention.py](yet_another_retnet/retention.py)
- [x] Recurrent position embedding implementation.
- [x] `MultiScaleRetention` module. See: [retention.py](yet_another_retnet/retention.py)
- [x] Make relative position embeddings for `MultiScaleRetention` **optional**.
    - The retention layer explicitly includes a position embedding update, which is based on [xPos](https://arxiv.org/pdf/2212.10554.pdf). It does not necessarily translate well to other domains (e.g. computer vision, heterogeneous graphs), so I have made it optional.
    - I'm not 100% sure why the authors did this. It seems overly specific to the language modeling use case, and it's not clear to me that it was necessary.
- [x] End-to-end `RetNet` module. See: [retnet.py](yet_another_retnet/retnet.py)
    - [x] `RetNetDecoderLayer`
    - [x] `RetNetDecoder`
- [x] Preconfigured 1.3B, 2.7B, and 6.7B models (untrained). See: [retnet.py](yet_another_retnet/retnet.py)
- [x] Reproduce inference memory and throughput benchmarks from the paper. See: [Inference Benchmarks](#inference-benchmarks), [benchmark_inference.py](scripts/benchmark_inference.py)
- [x] Release stable version on PyPI.
    - [x] Prerelease
    - [x] Stable
- [x] Equivalent **chunkwise** retention method.
- [x] Basic training example for language modeling.
  See: [train_project_gutenberg.py](./scripts/train_project_gutenberg.py)

## Install

PyPI:
```bash
pip install yet-another-retnet
```

> **NOTE**: To run the [example training script](./scripts/train_project_gutenberg.py), you will need to include the `[train]` extra:
> ```bash
> pip install "yet-another-retnet[train]"
> ```

From source:
```bash
pip install "yet-another-retnet @ git+ssh://git@github.com/fkodom/yet-another-retnet.git"
```

For contributors:
```bash
# Clone/fork the repository
gh repo clone fkodom/yet-another-retnet
cd yet-another-retnet
# Install all dev dependencies (tests etc.) in editable mode
pip install -e ".[test]"
# Set up pre-commit hooks
pre-commit install
```

## About

RetNet is a transformer-like architecture that has equivalent **parallel** and **recurrent** formulations.

<img src="doc/retention-dual-forms.jpeg" alt="retention-dual-forms" width="600"/>

The benefits of this dual formulation are:
- Accuracy comparable to Transformer-based models
- **Parallel**: high training throughput
- **Recurrent**: high inference throughput

Achieving all three at once is the "impossible triangle" of language model design, as described by the authors:

<img src="doc/impossible-triangle.jpeg" alt="impossible-triangle" width="350"/>

## Usage

### RetNet

Use one of the configurations described in the paper:
- `retnet_1_3b`
- `retnet_2_7b`
- `retnet_6_7b`

```python
from yet_another_retnet.retnet import retnet_1_3b

retnet = retnet_1_3b(num_tokens=10000, device="cuda")
```

or create your own `RetNet` model directly:

```python
from yet_another_retnet.retnet import RetNet

# a very small RetNet model :D
retnet = RetNet(
    num_tokens=1000,  # vocab size, usually taken from tokenizer
    d_model=64,
    nhead=4,
    num_layers=2,
    device="cuda",
).eval()  # Important for reproducibility!
```

Equivalent parallel, recurrent, and
chunkwise usage:

```python
import torch

# Set deterministic CUDA ops
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# input shape: (batch_size, seq_len)
# integer range: [0, num_tokens)
x = torch.randint(0, 1000, (1, 16), device="cuda")

# Parallel usage
y_parallel = retnet.forward_parallel(x)

# Recurrent usage
outputs = []  # container for collecting step-wise outputs
prev_states = []  # cache layer states after each step
for idx in range(16):  # seq_len
    out, prev_states = retnet.forward_recurrent(x[:, idx], idx, prev_states)
    outputs.append(out)
y_recurrent = torch.stack(outputs, dim=1)

# Chunkwise usage
outputs = []  # container for collecting chunk-wise outputs
prev_states = []  # cache layer states after each chunk
chunk_size = 4  # number of tokens in each chunk
for idx in range(0, 16, chunk_size):
    out, prev_states = retnet.forward_chunkwise(
        x[:, idx : idx + chunk_size], idx, prev_states
    )
    outputs.append(out)
y_chunkwise = torch.cat(outputs, dim=1)

# Check that outputs are equal (within floating-point tolerance)
torch.testing.assert_close(y_parallel, y_recurrent)
torch.testing.assert_close(y_parallel, y_chunkwise)
```

**NOTE**: There is some floating-point error accumulation in the recurrent formulation, which I believe is less pronounced in the parallel formulation. Especially for untrained models (when activations are very large), the outputs may not match *exactly*. The absolute difference should still be very small, on the order of 1e-5 or less.

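
Why do the formulations agree at all? A minimal single-head sketch of the retention math in NumPy makes it concrete. This is not the library's implementation. It assumes a single scalar decay `gamma` and leaves out projections, scaling, normalization, gating, and position embeddings, and the helper names (`decay_mask`, `parallel_form`, `recurrent_form`) are hypothetical. The parallel form multiplies `Q K^T` by a causal decay mask `D[n, m] = gamma^(n - m)`. The recurrent form folds the same weighted sum into a running state `S_n = gamma * S_{n-1} + k_n^T v_n`.

```python
import numpy as np


def decay_mask(seq_len: int, gamma: float) -> np.ndarray:
    """Causal decay mask: D[n, m] = gamma ** (n - m) if n >= m, else 0."""
    diff = np.arange(seq_len)[:, None] - np.arange(seq_len)[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)


def parallel_form(q: np.ndarray, k: np.ndarray, v: np.ndarray, gamma: float) -> np.ndarray:
    """All positions at once: (Q K^T * D) V.  q, k: (seq_len, d_k); v: (seq_len, d_v)."""
    return (q @ k.T * decay_mask(q.shape[0], gamma)) @ v


def recurrent_form(
    q: np.ndarray, k: np.ndarray, v: np.ndarray, gamma: float
) -> tuple[np.ndarray, np.ndarray]:
    """One position at a time, carrying a (d_k, d_v) state instead of all past keys/values."""
    state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for q_n, k_n, v_n in zip(q, k, v):
        state = gamma * state + np.outer(k_n, v_n)  # decay old info, add k_n^T v_n
        outputs.append(q_n @ state)
    return np.stack(outputs), state


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
y_parallel = parallel_form(q, k, v, gamma=0.9)
y_recurrent, final_state = recurrent_form(q, k, v, gamma=0.9)
print(np.abs(y_parallel - y_recurrent).max())  # tiny: only floating-point rounding differs
```

Both functions compute `o_n = sum over m <= n of gamma^(n - m) (q_n . k_m) v_m`. The recurrent version only ever holds `state`, so its per-step memory does not grow with sequence length. The summation order differs between the two forms, which is one source of the small mismatches mentioned in the note above.
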

### MultiScaleRetention

Equivalent parallel, recurrent, and chunkwise usage:

```python
import torch

from yet_another_retnet.retention import MultiScaleRetention

mhr = MultiScaleRetention(embed_dim=32, num_heads=4, device="cuda").eval()

# input shape: (batch_size, seq_len, embed_dim)
q = k = v = torch.randn(1, 16, 32, device="cuda")

# Parallel retention
y_parallel, _ = mhr.forward_parallel(q, k, v)

# Recurrent retention
outputs = []
prev_state = None
for idx in range(16):
    out, prev_state = mhr.forward_recurrent(
        q[:, idx], k[:, idx], v[:, idx], idx, prev_state
    )
    outputs.append(out)
y_recurrent = torch.stack(outputs, dim=1)

# Chunkwise retention
outputs = []
prev_state = None
chunk_size = 4
for idx in range(0, 16, chunk_size):
    out, prev_state = mhr.forward_chunkwise(
        q[:, idx : idx + chunk_size],
        k[:, idx : idx + chunk_size],
        v[:, idx : idx + chunk_size],
        idx,
        prev_state,
    )
    outputs.append(out)
y_chunkwise = torch.cat(outputs, dim=1)

# Check that outputs are equal
torch.testing.assert_close(y_parallel, y_recurrent)
torch.testing.assert_close(y_parallel, y_chunkwise)
```

**NOTE**: The `MultiScaleRetention` described in the paper includes an explicit position embedding (based on xPos) as part of the retention layer. This does not translate perfectly to other domains (e.g. computer vision, heterogeneous graphs), so I have made it optional.

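
A relative position embedding makes the query-key score depend only on the *offset* between two positions, not on where they sit in the sequence. As we understand it, xPos builds on the rotary embedding (RoPE) idea of rotating feature pairs by a position-dependent angle and adds a position-dependent scaling on top. The NumPy sketch below shows only that rotary part, as a generic illustration. It is not this library's (or the paper's) implementation. The function name `rotate` and the `base=10000.0` frequency schedule are assumptions borrowed from the common RoPE recipe.

```python
import numpy as np


def rotate(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotate each (even, odd) feature pair of `x` by the angle pos * theta_i.

    x has shape (d,) with d even; theta_i = base ** (-2i / d) gives one frequency per pair.
    """
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Same offset (3) at two different absolute positions gives the same score.
score_near = rotate(q, 5) @ rotate(k, 2)
score_far = rotate(q, 105) @ rotate(k, 102)
print(np.isclose(score_near, score_far))  # True
```

The score only depends on the angle difference `(m - n) * theta_i` for each pair, so absolute position cancels out. With `relative_position=False`, nothing of this kind is applied inside the layer, and any positional signal has to come from the inputs.
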

Set `relative_position=False` to disable the position embedding. Instead, you will be responsible for adding positional information to the inputs (if needed).

```python
# Disable relative position embedding
mhr = MultiScaleRetention(
    embed_dim=32, num_heads=4, relative_position=False, device="cuda"
)
# Everything else works the same as above.
# Just add your own positional embeddings to the inputs.
```

### Retention forward pass

Similar to the example above, but using the functional retention ops directly. The head projections and positional updates that `MultiScaleRetention` applies internally are not included:

```python
import torch

from yet_another_retnet.retention import (
    retention_chunkwise,
    retention_parallel,
    retention_recurrent,
)

# input shape: (batch_size, num_heads, seq_len, head_dim)
q = k = v = torch.randn(1, 4, 32, 8, device="cuda")

# Parallel retention
y_parallel, _ = retention_parallel(q, k, v)

# Recurrent retention
outputs = []
prev_state = None
for i in range(32):
    out, prev_state = retention_recurrent(
        q[:, :, i], k[:, :, i], v[:, :, i], prev_state
    )
    outputs.append(out)
y_recurrent = torch.stack(outputs, dim=2)

# Chunkwise retention
outputs = []
prev_state = None
chunk_size = 4
for i in range(0, 32, chunk_size):
    out, prev_state = retention_chunkwise(
        q[:, :, i : i + chunk_size],
        k[:, :, i : i + chunk_size],
        v[:, :, i : i + chunk_size],
        prev_state,
    )
    outputs.append(out)
y_chunkwise = torch.cat(outputs, dim=2)

# Check that outputs are equal
torch.testing.assert_close(y_parallel, y_recurrent)
torch.testing.assert_close(y_parallel, y_chunkwise)
```

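
The chunkwise form combines the other two. Inside a chunk it uses the parallel form, and across chunks information flows only through the same `(d_k, d_v)` state the recurrent form carries. Below is a minimal single-head NumPy sketch of that idea, modeled on the chunkwise equations in the RetNet paper with one scalar decay `gamma`. It is not the library's `retention_chunkwise`. The names `causal_decay` and `chunkwise_form` are hypothetical, and projections, scaling, normalization, and position embeddings are omitted.

```python
import numpy as np


def causal_decay(size: int, gamma: float) -> np.ndarray:
    """D[n, m] = gamma ** (n - m) if n >= m, else 0."""
    diff = np.arange(size)[:, None] - np.arange(size)[None, :]
    return np.where(diff >= 0, gamma ** np.maximum(diff, 0), 0.0)


def chunkwise_form(q, k, v, gamma, chunk_size, state=None):
    """Single-head chunkwise retention. Returns (outputs, state after the last token).

    q, k: (seq_len, d_k); v: (seq_len, d_v); state: (d_k, d_v) or None.
    """
    if state is None:
        state = np.zeros((k.shape[1], v.shape[1]))
    outputs = []
    for start in range(0, q.shape[0], chunk_size):
        qc, kc, vc = (x[start : start + chunk_size] for x in (q, k, v))
        b = qc.shape[0]  # the last chunk may be shorter than chunk_size
        j = np.arange(b)  # position inside the chunk
        inner = (qc @ kc.T * causal_decay(b, gamma)) @ vc  # parallel form within the chunk
        cross = (qc @ state) * (gamma ** (j + 1))[:, None]  # contribution of earlier chunks
        outputs.append(inner + cross)
        # Decay the old state over the whole chunk, then add this chunk's keys/values,
        # each decayed by its distance to the chunk's last position.
        state = gamma**b * state + kc.T @ (vc * (gamma ** (b - 1 - j))[:, None])
    return np.concatenate(outputs), state


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
gamma = 0.9

y_parallel = (q @ k.T * causal_decay(16, gamma)) @ v
y_chunkwise, state = chunkwise_form(q, k, v, gamma, chunk_size=4)
print(np.allclose(y_parallel, y_chunkwise))  # True

# Encode a 12-token "prompt" chunkwise, then continue one token at a time recurrently.
_, s = chunkwise_form(q[:12], k[:12], v[:12], gamma, chunk_size=4)
tail = []
for q_n, k_n, v_n in zip(q[12:], k[12:], v[12:]):
    s = gamma * s + np.outer(k_n, v_n)
    tail.append(q_n @ s)
print(np.allclose(np.stack(tail), y_parallel[12:]))  # True
```

The state returned after a chunk equals the recurrent state after that chunk's last token. That is why a long prompt can be encoded chunkwise and then extended recurrently without recomputing anything.
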

## Inference Benchmarks

> **NOTE**: The benchmarks aren't exact one-to-one comparisons, because I have a much lower-end GPU. The authors benchmark RetNet 6.7B using an A100 80GB. I have a 2080 Ti (11GB), so I chose to benchmark RetNet 1.3B instead. I expect the batch size is also smaller in my benchmarks, although the authors don't specify what their batch size was.
>
> My benchmarks clearly show the same scaling trends, which is still helpful as a rough validation.

From the paper:

<img src="doc/benchmarks.png" alt="paper-inference-benchmarks" width="600"/>

From this repo:

<p float="left">
    <img src="doc/inference-memory.png" alt="inference-memory" width="300"/>
    <img src="doc/inference-throughput.png" alt="inference-throughput" width="300"/>
</p>

## Parallel vs. Recurrent vs. Chunkwise

When should you choose one formulation over the others? Here is a general rule of thumb:

* Parallel -> model training
* Recurrent -> incremental token prediction
* Chunkwise ->
    1. model training: if training inputs are very long and the parallel formulation is too memory-intensive
    2. encoding long prompts: if inference *prompts* are very long, chunkwise encoding is more efficient than the recurrent formulation. The chunkwise formulation returns the same state as the recurrent formulation, so once the prompt is chunkwise encoded, use the recurrent formulation to generate new tokens.

## Citations

```bibtex
@misc{sun2023retentive,
      title={Retentive Network: A Successor to Transformer for Large Language Models},
      author={Yutao Sun and Li Dong and Shaohan Huang and Shuming Ma and Yuqing Xia and Jilong Xue and Jianyong Wang and Furu Wei},
      year={2023},
      eprint={2307.08621},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```