{"id":15601048,"url":"https://github.com/lucidrains/marge-pytorch","last_synced_at":"2025-06-22T20:35:34.135Z","repository":{"id":55591325,"uuid":"290005274","full_name":"lucidrains/marge-pytorch","owner":"lucidrains","description":"Implementation of Marge, Pre-training via Paraphrasing, in Pytorch","archived":false,"fork":false,"pushed_at":"2021-01-14T22:31:47.000Z","size":170,"stargazers_count":76,"open_issues_count":5,"forks_count":11,"subscribers_count":11,"default_branch":"master","last_synced_at":"2025-05-14T20:58:16.985Z","etag":null,"topics":["artificial-intelligence","deep-learning","pre-training","retrieval","transformers"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2020-08-24T18:17:57.000Z","updated_at":"2025-05-14T08:15:14.000Z","dependencies_parsed_at":"2022-08-15T03:50:45.152Z","dependency_job_id":null,"html_url":"https://github.com/lucidrains/marge-pytorch","commit_stats":null,"previous_names":[],"tags_count":22,"template":false,"template_full_name":null,"purl":"pkg:github/lucidrains/marge-pytorch","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmarge-pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmarge-pytorch/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmarge-pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmarge-pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owne
rs/lucidrains","download_url":"https://codeload.github.com/lucidrains/marge-pytorch/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fmarge-pytorch/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":261126364,"owners_count":23113303,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","deep-learning","pre-training","retrieval","transformers"],"created_at":"2024-10-03T02:13:00.144Z","updated_at":"2025-06-22T20:35:29.116Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003cimg src=\"./marge.png\" width=\"600px\"\u003e\u003c/img\u003e\n\n## Marge - Pre-training via Paraphrasing\n\nImplementation of \u003ca href=\"https://arxiv.org/abs/2006.15020\"\u003eMarge\u003c/a\u003e, Pre-training via Paraphrasing, in Pytorch. 
It is an alternative to masked language modeling pre-training, where an encoder-decoder attention network learns to reconstruct a target document from a collection of evidence documents.\n\nUpdate: Three researchers have independently reported that the repository works for them.\n\n## Install\n\n```bash\n$ pip install marge-pytorch\n```\n\n## Usage\n\n```python\nimport torch\nimport numpy as np\nfrom torch.utils.data import DataLoader\n\nfrom marge_pytorch import Marge, TrainingWrapper\n\n# your documents must be tokenized and stored as a memmap of shape (num documents, seq length)\n\n# constants\nNUM_DOCS = 10000\nSEQ_LEN = 1024\nSHAPE = (NUM_DOCS, SEQ_LEN)\n\n# generate mock training data\nf = np.memmap('./train.dat', dtype=np.int32, mode='w+', shape=SHAPE)\nf[:] = np.random.randint(0, 20000, size=SHAPE)\ndel f\n\n# generate mock masking data (the builtin bool is used because the np.bool alias is unavailable in some NumPy releases)\nf = np.memmap('./train.mask.dat', dtype=bool, mode='w+', shape=SHAPE)\nf[:] = np.full(SHAPE, True)\ndel f\n\n# instantiate model\n\nmodel = Marge(\n    dim = 512,\n    num_tokens = 20000,\n    max_seq_len = SEQ_LEN,\n    enc_depth = 12,\n    enc_retrieval_depth = 4,                # defaults to 4, as in the paper (take the CLS token after the 4th layer of the encoder)\n    enc_heads = 8,\n    enc_ff_mult = 4,\n    dec_depth = 12,\n    dec_heads = 8,\n    dec_ff_mult = 16,                       # the paper notes that the decoder needs a much larger feedforward size\n    distill_attn = False,                   # (experimental) adds, on top of the decoder loss, an auxiliary distillation loss as defined in https://arxiv.org/abs/2012.04584\n    distill_loss_coef = 1.                  # weight of the auxiliary distillation loss\n)\n\n# wrap your model and your documents\n\ntrainer = TrainingWrapper(\n    model,\n    num_documents = NUM_DOCS,\n    doc_seq_len = SEQ_LEN,\n    num_evidence = 4,                         # number of evidence documents to fetch per target document to construct\n    reindex_batch_size = 32,                  # batch size to use when reindexing\n    documents_memmap_path = './train.dat',    # path to the mem-mapped documents\n    masks_memmap_path = './train.mask.dat',   # if None is supplied, will assume all tokens are visible\n    use_faiss_ann = True                      # set this to False if you have few documents and approximate nearest neighbor search is not needed\n)\n\n# instantiate dataloader\n\ndl = DataLoader(trainer.dataset, batch_size=16)\n\n# instantiate optimizer (Adam and this learning rate are placeholder choices, not settings prescribed by this repository)\n\nopt = torch.optim.Adam(model.parameters(), lr=3e-4)\n\n# now you can train, and use the reindex method on the training wrapper at appropriate intervals\n\nfor ind, data in enumerate(dl):\n    loss = trainer(data)\n    loss.backward()\n    opt.step()\n    opt.zero_grad()\n\n    # reindex and precompute knn every 10000 steps, as in the paper\n    if ind \u003e 0 and ind % 10000 == 0:\n        trainer.reindex()\n```\n\nSave your model after training\n\n```python\ntorch.save(model, './trained-model.pt')\n```\n\n## Advanced\n\nIf you would like the target and evidence documents to come from different sets, pass in up to four additional keyword arguments, as shown below.\n\n```python\ntrainer = TrainingWrapper(\n    model,\n    num_documents = NUM_DOCS,\n    doc_seq_len = SEQ_LEN,\n    num_evidence = 4,\n    reindex_batch_size = 32,\n    documents_memmap_path = './evidence.dat',\n    masks_memmap_path = './evidence.mask.dat',\n    num_targets = NUM_TARGETS,                       # 1. number of target documents, with sequence length the same as the document (evidence)\n    target_seq_len = SEQ_LEN,                        # 2. sequence length of target documents\n    target_memmap_path = './target.dat',             # 3. path to target memmap, same as documents (evidence)\n    target_masks_memmap_path = './target.mask.dat',  # 4. path to target mask memmap, same as document masks (evidence)\n    use_faiss_ann = True\n)\n```\n\n## Sampling\n\nYou can sample from the decoder as follows\n\n```python\n# some random evidence from the dataset\n# or provide your own with shape (b x num_evidences x seq_len)\n*_, evidence, mask = trainer.dataset[0:1]\n\n# assume 1 is the start token\nprime = torch.tensor([[1]], dtype=torch.long).cuda()\n\n# supply your own document similarities array (b x num_evidences)\n# if not supplied, defaults to 1. for all evidence\ndoc_similarities = torch.ones(evidence.shape[:2]).float().cuda()\n\n# generate sample of length 1024\nsamples = model.generate(prime, 1024, evidence, mask = mask, similarities = doc_similarities)\n```\n\n## Citations\n\n```bibtex\n@misc{lewis2020pretraining,\n    title={Pre-training via Paraphrasing},\n    author={Mike Lewis and Marjan Ghazvininejad and Gargi Ghosh and Armen Aghajanyan and Sida Wang and Luke Zettlemoyer},\n    year={2020},\n    eprint={2006.15020},\n    archivePrefix={arXiv},\n    primaryClass={cs.CL}\n}\n```\n\n```bibtex\n@misc{komatsuzaki2020current,\n    title={Current Limitations of Language Models: What You Need is Retrieval},\n    author={Aran Komatsuzaki},\n    year={2020},\n    eprint={2009.06857},\n    archivePrefix={arXiv},\n    primaryClass={cs.CL}\n}\n```\n\n```bibtex\n@misc{izacard2020distilling,\n    title={Distilling Knowledge from Reader to Retriever for Question Answering},\n    author={Gautier Izacard and Edouard Grave},\n    year={2020},\n    eprint={2012.04584},\n    archivePrefix={arXiv},\n    
primaryClass={cs.CL}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fmarge-pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2Fmarge-pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fmarge-pytorch/lists"}