{"id":13563917,"url":"https://github.com/lucidrains/RETRO-pytorch","last_synced_at":"2025-04-03T20:32:09.952Z","repository":{"id":40449808,"uuid":"448398954","full_name":"lucidrains/RETRO-pytorch","owner":"lucidrains","description":"Implementation of RETRO, Deepmind's Retrieval based Attention net, in Pytorch","archived":false,"fork":false,"pushed_at":"2023-10-30T17:10:58.000Z","size":190,"stargazers_count":859,"open_issues_count":17,"forks_count":106,"subscribers_count":26,"default_branch":"main","last_synced_at":"2025-04-01T00:32:09.585Z","etag":null,"topics":["artificial-intelligence","attention-mechanism","deep-learning","retrieval","transformers"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null}},"created_at":"2022-01-15T21:51:34.000Z","updated_at":"2025-03-13T19:06:25.000Z","dependencies_parsed_at":"2024-01-14T03:50:46.261Z","dependency_job_id":"cd7fcf36-2035-49f5-8a74-864742daa852","html_url":"https://github.com/lucidrains/RETRO-pytorch","commit_stats":{"total_commits":128,"total_committers":7,"mean_commits":"18.285714285714285","dds":0.1015625,"last_synced_commit":"ab3c4a6f66341409b2b9105661682376355b4673"},"previous_names":[],"tags_count":70,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2FRETRO-pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2FRETRO-pytorch/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositori
es/lucidrains%2FRETRO-pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2FRETRO-pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lucidrains","download_url":"https://codeload.github.com/lucidrains/RETRO-pytorch/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247075182,"owners_count":20879400,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","attention-mechanism","deep-learning","retrieval","transformers"],"created_at":"2024-08-01T13:01:24.528Z","updated_at":"2025-04-03T20:32:09.644Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"\u003cimg src=\"./RETRO.png\" width=\"500px\"\u003e\u003c/img\u003e\n\n## RETRO - Pytorch\n\nImplementation of \u003ca href=\"https://arxiv.org/abs/2112.04426\"\u003eRETRO\u003c/a\u003e, Deepmind's Retrieval based Attention net, in Pytorch. 
This will deviate from the paper slightly, using rotary embeddings for relative positional encoding, as well as the Faiss library instead of ScaNN.\n\nThis library leverages \u003ca href=\"https://github.com/criteo/autofaiss\"\u003eautofaiss\u003c/a\u003e for building the index and calculating the k-nearest neighbors for all chunks.\n\n\u003ca href=\"http://jalammar.github.io/illustrated-retrieval-transformer/\"\u003eJay Alammar's explanatory blogpost\u003c/a\u003e\n\nThe selling point of this retriever approach is reaching GPT-3 performance with 10x fewer parameters. More research is \u003ca href=\"https://arxiv.org/abs/2009.06857\"\u003edefinitely deserved\u003c/a\u003e in this area.\n\nI have also included the features necessary to scale the retrieval transformer to 1000 layers, if the claims of the \u003ca href=\"https://arxiv.org/abs/2203.00555\"\u003eDeepNet paper\u003c/a\u003e are to be believed.\n\nUpdate: Someone on Reddit has gifted me a \u003ca href=\"https://old.reddit.com/r/MachineLearning/comments/s4f1p8/d_is_there_an_opensource_implementation_of_the/hstia5r/\"\u003eGold Award\u003c/a\u003e. Not sure what it is, but thank you! 🙏\n\nUpdate: DeepNorm has been validated at scale in a \u003ca href=\"https://keg.cs.tsinghua.edu.cn/glm-130b/\"\u003e130B model out of Tsinghua\u003c/a\u003e. 
It is now recommended that you train with `use_deepnet` set to `True`\n\n## Install\n\n```bash\n$ pip install retro-pytorch\n````\n\n## Usage\n\n```python\nimport torch\nfrom retro_pytorch import RETRO\n\nretro = RETRO(\n    chunk_size = 64,                         # the chunk size that is indexed and retrieved (needed for proper relative positions as well as causal chunked cross attention)\n    max_seq_len = 2048,                      # max sequence length\n    enc_dim = 896,                           # encoder model dim\n    enc_depth = 2,                           # encoder depth\n    dec_dim = 796,                           # decoder model dim\n    dec_depth = 12,                          # decoder depth\n    dec_cross_attn_layers = (3, 6, 9, 12),   # decoder cross attention layers (with causal chunk cross attention)\n    heads = 8,                               # attention heads\n    dim_head = 64,                           # dimension per head\n    dec_attn_dropout = 0.25,                 # decoder attention dropout\n    dec_ff_dropout = 0.25,                   # decoder feedforward dropout\n    use_deepnet = True                       # turn on post-normalization with DeepNet residual scaling and initialization, for scaling to 1000 layers\n)\n\nseq = torch.randint(0, 20000, (2, 2048 + 1))      # plus one since it is split into input and labels for training\nretrieved = torch.randint(0, 20000, (2, 32, 2, 128)) # retrieved tokens - (batch, num chunks, num retrieved neighbors, retrieved chunk with continuation)\n\nloss = retro(seq, retrieved, return_loss = True)\nloss.backward()\n\n# do above for many steps\n```\n\n\n## RETRO Training Wrapper\n\nThe aim of the `TrainingWrapper` is to process a folder of text documents into the necessary memmapped numpy arrays to begin training `RETRO`.\n\n```python\nimport torch\nfrom retro_pytorch import RETRO, TrainingWrapper\n\n# instantiate RETRO, fit it into the TrainingWrapper with correct settings\n\nretro = RETRO(\n    
max_seq_len = 2048,                      # max sequence length\n    enc_dim = 896,                           # encoder model dimension\n    enc_depth = 3,                           # encoder depth\n    dec_dim = 768,                           # decoder model dimensions\n    dec_depth = 12,                          # decoder depth\n    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)\n    heads = 8,                               # attention heads\n    dim_head = 64,                           # dimension per head\n    dec_attn_dropout = 0.25,                 # decoder attention dropout\n    dec_ff_dropout = 0.25                    # decoder feedforward dropout\n).cuda()\n\nwrapper = TrainingWrapper(\n    retro = retro,                                 # path to retro instance\n    knn = 2,                                       # knn (2 in paper was sufficient)\n    chunk_size = 64,                               # chunk size (64 in paper)\n    documents_path = './text_folder',              # path to folder of text\n    glob = '**/*.txt',                             # text glob\n    chunks_memmap_path = './train.chunks.dat',     # path to chunks\n    seqs_memmap_path = './train.seq.dat',          # path to sequence data\n    doc_ids_memmap_path = './train.doc_ids.dat',   # path to document ids per chunk (used for filtering neighbors belonging to same document)\n    max_chunks = 1_000_000,                        # maximum cap to chunks\n    max_seqs = 100_000,                            # maximum seqs\n    knn_extra_neighbors = 100,                     # num extra neighbors to fetch\n    max_index_memory_usage = '100m',\n    current_memory_available = '1G'\n)\n\n# get the dataloader and optimizer (AdamW with all the correct settings)\n\ntrain_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))\noptim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)\n\n# now do your training\n# ex. 
one gradient step\n\nseq, retrieved = map(lambda t: t.cuda(), next(train_dl))\n\n# seq       - (2, 2049)         - 1 extra token since split by seq[:, :-1], seq[:, 1:]\n# retrieved - (2, 32, 2, 128)   - 128 since chunk + continuation, each 64 tokens\n\nloss = retro(\n    seq,\n    retrieved,\n    return_loss = True\n)\n\n# one gradient step\n\nloss.backward()\noptim.step()\noptim.zero_grad()\n\n# do above for many steps, then ...\n\n# topk sampling with retrieval at chunk boundaries\n\nsampled = wrapper.generate(filter_thres = 0.9, temperature = 1.0) # (1, \u003c2049) terminates early if all \u003ceos\u003e\n\n# or you can generate with a prompt, knn retrieval for initial chunks all taken care of\n\nprompt = torch.randint(0, 1000, (1, 128))  # start with two chunks worth of sequence\nsampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0) # (1, \u003c2049) terminates early if all \u003ceos\u003e\n\n```\n\nIf you wish to force a reprocess of the training data, simply run your script with the `REPROCESS=1` environment variable set, like so\n\n```bash\n$ REPROCESS=1 python train.py\n```\n\n## RETRO Datasets\n\nThe `RETRODataset` class accepts paths to a number of memmapped numpy arrays containing the chunks, the index of the first chunk in the sequence to be trained on (in the RETRO decoder), and the pre-calculated indices of the k-nearest neighbors per chunk.\n\nYou can use this to easily assemble the data for `RETRO` training, if you do not wish to use the `TrainingWrapper` from above.\n\nFurthermore, all the functions needed to create the necessary memmapped data are covered in the sections that follow.\n\n\n```python\nimport torch\nfrom torch.utils.data import DataLoader\nfrom retro_pytorch import RETRO, RETRODataset\n\n# mock data constants\n\nimport numpy as np\n\nNUM_CHUNKS = 1000\nCHUNK_SIZE = 64\nNUM_SEQS = 100\nNUM_NEIGHBORS = 2\n\ndef save_memmap(path, tensor):\n    f = np.memmap(path, dtype = tensor.dtype, mode = 'w+', shape = tensor.shape)\n    f[:] = tensor\n    
del f\n\n# generate mock chunk data\n\nsave_memmap(\n    './train.chunks.dat',\n    np.int32(np.random.randint(0, 8192, size = (NUM_CHUNKS, CHUNK_SIZE + 1)))\n)\n\n# generate nearest neighbors for each chunk\n\nsave_memmap(\n    './train.chunks.knn.dat',\n    np.int32(np.random.randint(0, 1000, size = (NUM_CHUNKS, NUM_NEIGHBORS)))\n)\n\n# generate seq data\n\nsave_memmap(\n    './train.seq.dat',\n    np.int32(np.random.randint(0, 128, size = (NUM_SEQS,)))\n)\n\n# instantiate dataset class\n# which constructs the sequence and neighbors from memmapped chunk and neighbor information\n\ntrain_ds = RETRODataset(\n    num_sequences = NUM_SEQS,\n    num_chunks = NUM_CHUNKS,\n    num_neighbors = NUM_NEIGHBORS,\n    chunk_size = CHUNK_SIZE,\n    seq_len = 2048,\n    chunk_memmap_path = './train.chunks.dat',\n    chunk_nn_memmap_path = './train.chunks.knn.dat',\n    seq_memmap_path = './train.seq.dat'\n)\n\ntrain_dl = iter(DataLoader(train_ds, batch_size = 2))\n\n# one forwards and backwards\n\nretro = RETRO(\n    max_seq_len = 2048,                      # max sequence length\n    enc_dim = 896,                           # encoder model dimension\n    enc_depth = 3,                           # encoder depth\n    dec_dim = 768,                           # decoder model dimensions\n    dec_depth = 12,                          # decoder depth\n    dec_cross_attn_layers = (1, 3, 6, 9),    # decoder cross attention layers (with causal chunk cross attention)\n    heads = 8,                               # attention heads\n    dim_head = 64,                           # dimension per head\n    dec_attn_dropout = 0.25,                 # decoder attention dropout\n    dec_ff_dropout = 0.25                    # decoder feedforward dropout\n).cuda()\n\nseq, retrieved = map(lambda t: t.cuda(), next(train_dl))\n\n# seq       - (2, 2049)         - 1 extra token since split by seq[:, :-1], seq[:, 1:]\n# retrieved - (2, 32, 2, 128)   - 128 since chunk + continuation, each 64 tokens\n\nloss = 
retro(\n    seq,\n    retrieved,\n    return_loss = True\n)\n\nloss.backward()\n\n```\n\n## Retrieval related tools\n\nThis repository uses the default tokenizer (WordPiece) for the cased version of BERT. Embeddings are fetched from vanilla BERT, and can be either the masked mean pooled representation or the CLS token representation.\n\nex. masked mean pooled representation\n\n```python\nfrom retro_pytorch.retrieval import bert_embed, tokenize\n\nids = tokenize([\n    'hello world',\n    'foo bar'\n])\n\nembeds = bert_embed(ids) # (2, 768) - 768 is hidden dimension of BERT\n```\n\nex. CLS token representation\n\n\n```python\nfrom retro_pytorch.retrieval import bert_embed, tokenize\n\nids = tokenize([\n    'hello world',\n    'foo bar'\n])\n\nembeds = bert_embed(ids, return_cls_repr = True) # (2, 768)\n```\n\nCreate your chunks and chunk start indices (for calculating sequence ranges for autoregressive training) using `text_folder_to_chunks_`\n\n```python\nfrom retro_pytorch.retrieval import text_folder_to_chunks_\n\nstats = text_folder_to_chunks_(\n    folder = './text_folder',\n    glob = '**/*.txt',\n    chunks_memmap_path = './train.chunks.dat',\n    seqs_memmap_path = './train.seq.dat',\n    doc_ids_memmap_path = './train.doc_ids.dat',  # document ids are needed for filtering out neighbors belonging to same document appropriately during computation of nearest neighbors\n    chunk_size = 64,\n    seq_len = 2048,\n    max_chunks = 1_000_000,\n    max_seqs = 100_000\n)\n\n# {'chunks': \u003cnumber of chunks\u003e, 'docs': \u003cnumber of documents\u003e, 'seqs': \u003cnumber of sequences\u003e}\n```\n\n## Fetching Nearest Neighbors\n\nYou can turn your memmapped chunks numpy array into embeddings and a faiss index with a single function call\n\n```python\nfrom retro_pytorch.retrieval import chunks_to_index_and_embed\n\nindex, embeddings = chunks_to_index_and_embed(\n    num_chunks = 1000,\n    chunk_size = 64,\n    chunk_memmap_path = './train.chunks.dat'\n)\n\nquery_vector = 
embeddings[:1]                   # use first embedding as query\n_, indices = index.search(query_vector, k = 2)  # fetch 2 neighbors, first indices should be self\n\nneighbor_embeddings = embeddings[indices]       # (1, 2, 768)\n\n```\n\nYou can also directly calculate the nearest neighbor file necessary for training, with `chunks_to_precalculated_knn_` command\n\n```python\nfrom retro_pytorch.retrieval import chunks_to_precalculated_knn_\n\nchunks_to_precalculated_knn_(\n    num_chunks = 1000,\n    chunk_size = 64,\n    chunk_memmap_path = './train.chunks.dat',    # path to main chunks dataset\n    doc_ids_memmap_path = './train.doc_ids.dat', # path to document ids created by text_folder_to_chunks_, used for filtering out neighbors that belong to the same document\n    num_nearest_neighbors = 2,                   # number of nearest neighbors you'd like to use\n    num_extra_neighbors = 10                     # fetch 10 extra neighbors, in the case that fetched neighbors are frequently from same document (filtered out)\n)\n\n# nearest neighbor info saved to ./train.chunks.knn.dat\n\n```\n\n## Citations\n\n```bibtex\n@misc{borgeaud2022improving,\n    title   = {Improving language models by retrieving from trillions of tokens}, \n    author  = {Sebastian Borgeaud and Arthur Mensch and Jordan Hoffmann and Trevor Cai and Eliza Rutherford and Katie Millican and George van den Driessche and Jean-Baptiste Lespiau and Bogdan Damoc and Aidan Clark and Diego de Las Casas and Aurelia Guy and Jacob Menick and Roman Ring and Tom Hennigan and Saffron Huang and Loren Maggiore and Chris Jones and Albin Cassirer and Andy Brock and Michela Paganini and Geoffrey Irving and Oriol Vinyals and Simon Osindero and Karen Simonyan and Jack W. 
Rae and Erich Elsen and Laurent Sifre},\n    year  = {2022},\n    eprint = {2112.04426},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CL}\n}\n```\n\n```bibtex\n@misc{su2021roformer,\n    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},\n    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},\n    year    = {2021},\n    eprint  = {2104.09864},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CL}\n}\n```\n\n```bibtex\n@article{Wang2022DeepNetST,\n    title   = {DeepNet: Scaling Transformers to 1, 000 Layers},\n    author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},\n    journal = {ArXiv},\n    year    = {2022},\n    volume  = {abs/2203.00555}\n}\n```\n\n```bibtex\n@misc{zhang2021sparse,\n    title   = {Sparse Attention with Linear Units},\n    author  = {Biao Zhang and Ivan Titov and Rico Sennrich},\n    year    = {2021},\n    eprint  = {2104.07012},\n    archivePrefix = {arXiv},\n    primaryClass = {cs.CL}\n}\n```\n\n*I consider always the adult life to be the continuous retrieval of childhood.* - Umberto Eco\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2FRETRO-pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2FRETRO-pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2FRETRO-pytorch/lists"}