{"id":15601033,"url":"https://github.com/lucidrains/routing-transformer","last_synced_at":"2025-04-05T14:04:07.828Z","repository":{"id":39666971,"uuid":"266185474","full_name":"lucidrains/routing-transformer","owner":"lucidrains","description":"Fully featured implementation of Routing Transformer","archived":false,"fork":false,"pushed_at":"2021-11-06T23:11:38.000Z","size":35876,"stargazers_count":291,"open_issues_count":10,"forks_count":30,"subscribers_count":12,"default_branch":"master","last_synced_at":"2025-03-29T13:09:57.141Z","etag":null,"topics":["artificial-intelligence","attention-mechanism","deep-learning","pytorch","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2020-05-22T18:54:10.000Z","updated_at":"2025-03-25T03:26:44.000Z","dependencies_parsed_at":"2022-09-20T07:10:18.316Z","dependency_job_id":null,"html_url":"https://github.com/lucidrains/routing-transformer","commit_stats":null,"previous_names":[],"tags_count":38,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Frouting-transformer","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Frouting-transformer/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Frouting-transformer/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Frouting-transformer/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lucidr
ains","download_url":"https://codeload.github.com/lucidrains/routing-transformer/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247345850,"owners_count":20924102,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","attention-mechanism","deep-learning","pytorch","transformer"],"created_at":"2024-10-03T02:12:40.866Z","updated_at":"2025-04-05T14:04:07.808Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"## Routing Transformer\n\n\u003cimg src=\"./routing_attention.png\" width=\"500px\"\u003e\u003c/img\u003e\n\n[![PyPI version](https://badge.fury.io/py/routing-transformer.svg)](https://badge.fury.io/py/routing-transformer)\n\nA fully featured implementation of \u003ca href=\"https://arxiv.org/pdf/2003.05997.pdf\"\u003eRouting Transformer\u003c/a\u003e. 
The paper proposes using k-means to route similar queries / keys into the same cluster for attention.\n\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1sw1Hjx3EllhKZh4nhJ3TIZ978HjKVUFQ?usp=sharing) 131k tokens\n\n### Install\n\n```bash\n$ pip install routing_transformer\n```\n\n### Usage\n\nA simple language model\n\n```python\nimport torch\nfrom routing_transformer import RoutingTransformerLM\n\nmodel = RoutingTransformerLM(\n    num_tokens = 20000,\n    dim = 512,\n    heads = 8,\n    depth = 12,\n    max_seq_len = 8192,\n    causal = True,           # auto-regressive or not\n    emb_dim = 128,           # embedding factorization, from Albert\n    weight_tie = False,      # weight tie layers, from Albert\n    tie_embedding = False,   # multiply final embeddings with token weights for logits\n    dim_head = 64,           # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads\n    attn_dropout = 0.1,      # dropout after attention\n    attn_layer_dropout = 0., # dropout after self attention layer\n    ff_dropout = 0.1,        # feedforward dropout\n    layer_dropout = 0.,      # layer dropout\n    window_size = 128,       # target window size of each cluster\n    n_local_attn_heads = 4,  # number of local attention heads\n    reversible = True,       # reversible networks for memory savings, from Reformer paper\n    ff_chunks = 10,          # feed forward chunking, from Reformer paper\n    ff_glu = True,           # use GLU variant in feedforward\n    pkm_layers = (4, 7),     # specify layers to use product key memory. 
paper shows 1 or 2 modules near the middle of the transformer is best\n    pkm_num_keys = 128,      # defaults to 128, but can be increased to 256 or 512 as memory allows\n    moe_layers = (3, 6),     # specify which layers to use mixture of experts\n    moe_num_experts = 4,     # number of experts in the mixture of experts layer, defaults to 4. increase for adding more parameters to model\n    moe_loss_coef = 1e-2,    # the weight for the auxiliary loss in mixture of experts to keep expert usage balanced\n    num_mem_kv = 8,          # number of memory key/values to append to each cluster of each head, from the 'All-Attention' paper. defaults to 1 in the causal case for unshared QK to work\n    use_scale_norm = False,  # use scale norm, simplified normalization from 'Transformers without Tears' paper\n    use_rezero = False,      # use Rezero with no normalization\n    shift_tokens = True      # shift tokens by one along sequence dimension, for a slight improvement in convergence\n).cuda()\n\nx = torch.randint(0, 20000, (1, 8192)).long().cuda()\ninput_mask = torch.ones_like(x).bool().cuda()\n\ny, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 20000)\naux_loss.backward() # add auxiliary loss to main loss before backprop\n```\n\nA simple transformer\n\n```python\nimport torch\nfrom routing_transformer import RoutingTransformer\n\nmodel = RoutingTransformer(\n    dim = 512,\n    heads = 8,\n    depth = 12,\n    max_seq_len = 8192,\n    window_size = 128,\n    n_local_attn_heads = 4\n).cuda()\n\nx = torch.randn(1, 8192, 512).cuda()\ninput_mask = torch.ones(1, 8192).bool().cuda()\n\ny, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 512)\naux_loss.backward() # add auxiliary loss to main loss before backprop\n```\n\n## Encoder Decoder\n\nTo use a full encoder, decoder, simply import the `RoutingTransformerEncDec` class. 
Except for the `dim` keyword, every keyword is prefixed with either `enc_` or `dec_` and passed to the encoder or decoder `RoutingTransformerLM` class, respectively.\n\n```python\nimport torch\nfrom routing_transformer import RoutingTransformerEncDec\n\nmodel = RoutingTransformerEncDec(\n    dim=512,\n    enc_num_tokens = 20000,\n    enc_depth = 4,\n    enc_heads = 8,\n    enc_max_seq_len = 4096,\n    enc_window_size = 128,\n    dec_num_tokens = 20000,\n    dec_depth = 4,\n    dec_heads = 8,\n    dec_max_seq_len = 4096,\n    dec_window_size = 128,\n    dec_reversible = True\n).cuda()\n\nsrc = torch.randint(0, 20000, (1, 4096)).cuda()\ntgt = torch.randint(0, 20000, (1, 4096)).cuda()\nsrc_mask = torch.ones_like(src).bool().cuda()\ntgt_mask = torch.ones_like(tgt).bool().cuda()\n\nloss, aux_loss = model(src, tgt, enc_input_mask = src_mask, dec_input_mask = tgt_mask, return_loss = True, randomly_truncate_sequence = True)\nloss.backward()\naux_loss.backward()\n\n# do your training, then to sample up to 2048 tokens based on the source sequence\nsrc = torch.randint(0, 20000, (1, 4096)).cuda()\nstart_tokens = torch.ones(1, 1).long().cuda() # assume starting token is 1\n\nsample = model.generate(src, start_tokens, seq_len = 2048, eos_token = 2) # (1, \u003c= 2048) sampled token ids\n```\n\n## Product Key Memory\n\nTo see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (`1e-2` is recommended).\n\nYou can follow the instructions here to set it correctly: https://github.com/lucidrains/product-key-memory#learning-rates\n\n## Kmeans Hyperparameters\n\n1. `kmeans_ema_decay = {defaults to 0.999}`\n\nThis is the exponential moving average decay for updating the k-means. The lower this is, the faster the means will adjust, but at the cost of stability.\n\n2. 
`commitment_factor = {defaults to 1e-4}`\n\nThe weight of the auxiliary loss that encourages tokens to get closer (commit) to the k-mean centroids that were chosen for them.\n\n## Updating kmeans manually\n\nThe following instructions will allow you to update the kmeans manually. By default the kmeans are updated automatically on every backward pass.\n\n```python\nimport torch\nfrom routing_transformer import RoutingTransformerLM, AutoregressiveWrapper\n\nmodel = RoutingTransformerLM(\n    num_tokens = 20000,\n    dim = 1024,\n    heads = 8,\n    depth = 6,\n    window_size = 256,\n    max_seq_len = 8192,\n    causal = True,\n    _register_kmeans_update = False # set to False to disable auto-updating\n)\n\nmodel = AutoregressiveWrapper(model)\n\nx = torch.randint(0, 20000, (1, 8192))\nloss = model(x, return_loss = True)\nloss.backward()\n\n# update kmeans with this call\nmodel.update_kmeans()\n```\n\n## Issues\n\nThis architecture has trouble generalizing to shorter sequence lengths when decoding tokens from 1 -\u003e maximum sequence length. The simplest and surest solution is to randomly truncate the sequence during training. 
This helps the network and the kmeans generalize to a variable number of tokens, at the cost of prolonged training.\n\nIf you are priming the network with the full sequence length at the start, then you will not face this problem, and you can skip this training procedure.\n\n\n```python\nimport torch\nfrom routing_transformer import RoutingTransformerLM, AutoregressiveWrapper\n\nmodel = RoutingTransformerLM(\n    num_tokens = 20000,\n    dim = 1024,\n    heads = 8,\n    depth = 12,\n    window_size = 256,\n    max_seq_len = 8192,\n    causal = True\n)\n\nmodel = AutoregressiveWrapper(model)\n\nx = torch.randint(0, 20000, (1, 8192))\nloss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar loss\n```\n\n## Appreciation\n\nSpecial thanks to \u003ca href=\"https://github.com/AranKomat\"\u003eAran Komatsuzaki\u003c/a\u003e for bootstrapping the initial implementation in PyTorch that evolved into this library.\n\n## Citation\n\n```bibtex\n@misc{roy*2020efficient,\n    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},\n    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},\n    year    = {2020},\n    url     = {https://arxiv.org/pdf/2003.05997.pdf}\n}\n```\n\n```bibtex\n@misc{shazeer2020glu,\n    title   = {GLU Variants Improve Transformer},\n    author  = {Noam Shazeer},\n    year    = {2020},\n    url     = {https://arxiv.org/abs/2002.05202}    \n}\n```\n\n```bibtex\n@inproceedings{kitaev2020reformer,\n    title       = {Reformer: The Efficient Transformer},\n    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},\n    booktitle   = {International Conference on Learning Representations},\n    year        = {2020},\n    url         = {https://openreview.net/forum?id=rkgNKkHtvB}\n}\n```\n\n```bibtex\n@inproceedings{fan2020reducing,\n    title     ={Reducing Transformer Depth on Demand with Structured Dropout},\n    author    ={Angela Fan and Edouard Grave and Armand 
Joulin},\n    booktitle ={International Conference on Learning Representations},\n    year      ={2020},\n    url       ={https://openreview.net/forum?id=SylO2yStDr}\n}\n```\n\n```bibtex\n@misc{lan2019albert,\n    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},\n    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},\n    year        = {2019},\n    url         = {https://arxiv.org/abs/1909.11942}\n}\n```\n\n```bibtex\n@misc{lample2019large,\n    title   = {Large Memory Layers with Product Keys},\n    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},\n    year    = {2019},\n    eprint  = {1907.05242},\n    archivePrefix = {arXiv}\n}\n```\n\n```bibtex\n@article{DBLP:journals/corr/abs-1907-01470,\n    author    = {Sainbayar Sukhbaatar and\n               Edouard Grave and\n               Guillaume Lample and\n               Herv{\\'{e}} J{\\'{e}}gou and\n               Armand Joulin},\n    title     = {Augmenting Self-attention with Persistent Memory},\n    journal   = {CoRR},\n    volume    = {abs/1907.01470},\n    year      = {2019},\n    url       = {http://arxiv.org/abs/1907.01470}\n}\n```\n\n```bibtex\n@misc{bhojanapalli2020lowrank,\n    title   = {Low-Rank Bottleneck in Multi-head Attention Models},\n    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},\n    year    = {2020},\n    eprint  = {2002.07028}\n}\n```\n\n```bibtex\n@article{1910.05895,\n    author  = {Toan Q. 
Nguyen and Julian Salazar},\n    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},\n    year    = {2019},\n    eprint  = {arXiv:1910.05895},\n    doi     = {10.5281/zenodo.3525484},\n}\n```\n\n```bibtex\n@misc{bachlechner2020rezero,\n    title   = {ReZero is All You Need: Fast Convergence at Large Depth},\n    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},\n    year    = {2020},\n    url     = {https://arxiv.org/abs/2003.04887}\n}\n```\n\n```bibtex\n@misc{vaswani2017attention,\n    title   = {Attention Is All You Need},\n    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},\n    year    = {2017},\n    eprint  = {1706.03762},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CL}\n}\n```\n\n```bibtex\n@software{peng_bo_2021_5196578,\n    author       = {PENG Bo},\n    title        = {BlinkDL/RWKV-LM: 0.01},\n    month        = {aug},\n    year         = {2021},\n    publisher    = {Zenodo},\n    version      = {0.01},\n    doi          = {10.5281/zenodo.5196578},\n    url          = {https://doi.org/10.5281/zenodo.5196578}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Frouting-transformer","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2Frouting-transformer","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Frouting-transformer/lists"}