{"id":27984397,"url":"https://github.com/nightdessert/Retrieval_Head","last_synced_at":"2025-05-08T05:01:56.593Z","repository":{"id":234220627,"uuid":"788459169","full_name":"nightdessert/Retrieval_Head","owner":"nightdessert","description":"open-source code for paper: Retrieval Head Mechanistically Explains Long-Context Factuality","archived":false,"fork":false,"pushed_at":"2024-08-02T08:25:57.000Z","size":1046,"stargazers_count":130,"open_issues_count":1,"forks_count":11,"subscribers_count":2,"default_branch":"main","last_synced_at":"2024-08-02T09:53:46.325Z","etag":null,"topics":["large-language-models","long-context"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2404.15574","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/nightdessert.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-04-18T13:07:23.000Z","updated_at":"2024-08-02T08:26:00.000Z","dependencies_parsed_at":"2024-05-06T15:06:47.138Z","dependency_job_id":"2b09dab3-a0ca-4a06-805e-544f2485cfc1","html_url":"https://github.com/nightdessert/Retrieval_Head","commit_stats":null,"previous_names":["nightdessert/retrieval-head"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/nightdessert%2FRetrieval_Head","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/nightdessert%2FRetrieval_Head/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/nightdessert%2FRetrieval_Head/releases","manifests_url":"https://repos.ecosyste
.ms/api/v1/hosts/GitHub/repositories/nightdessert%2FRetrieval_Head/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/nightdessert","download_url":"https://codeload.github.com/nightdessert/Retrieval_Head/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253002856,"owners_count":21838640,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["large-language-models","long-context"],"created_at":"2025-05-08T05:01:48.829Z","updated_at":"2025-05-08T05:01:56.579Z","avatar_url":"https://github.com/nightdessert.png","language":"Python","funding_links":[],"categories":["A01_文本生成_文本对话"],"sub_categories":["大语言对话模型及数据"],"readme":"# Retrieval Head\nThis is the open-source code for the paper:\n*[Retrieval Head Mechanistically Explains Long-Context Factuality](https://arxiv.org/abs/2404.15574)*. \n\nThis code is built on *[Needle In a HayStack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main)*.\n\n【Update】 Phi3 is now supported, thanks to the contribution made by @Wangmerlyn.\n## Retrieval Head Detection\nAn algorithm that statistically calculates the retrieval score of each attention head in a transformer model.\nBecause FlashAttention cannot return the attention matrix, this algorithm first builds the cache with FlashAttention and then applies normal attention for decoding. 
\n### Environment\n**Core**: pytorch=2.0.1, transformers=4.37.2, flash-attn=2.5.6 (my environment)\n\n**Other**: rouge_score\n\nA single 80G GPU is enough to run detection on context lengths up to 50K.\n### Usage:\n```bash\npython retrieval_head_detection.py  --model_path $path_to_model --s 0 --e 50000\n```\nWe find that only a few samples are enough to stably detect some of the strongest retrieval heads. If you are in a hurry or have no large GPUs available, you can set '--e' to a lower value, e.g.\n```bash\npython retrieval_head_detection.py  --model_path $path_to_model --s 0 --e 5000\n```\nRetrieval scores will be written to './head_score/$model_name.json'.\n\n**Currently Implemented Model Families**: \nLlama ([Llama-2-7B-80K](https://huggingface.co/yaofu/llama-2-7b-80k)), Yi, Qwen, Mistral\n\n### Results:\nAll detection results are saved in \"./head_score/*.json\", where each head is saved in the format of \n```python\n{layer-head_id: [list of retrieval scores across detections]}\n```\n**Directly load a result file for analysis**\n```python\n## load head score file, llama-2-7b-80k for example\nimport json\nimport numpy as np\nwith open('./head_score/llama-2-7b-80k.json') as file:\n    head_list = json.load(file)\n## average the retrieval scores per head, then rank heads by that average\nhead_score_list = [([int(ll) for ll in l[0].split(\"-\")],np.mean(l[1])) for l in head_list.items()]\nhead_score_list = sorted(head_score_list, key=lambda x: x[1], reverse=True) \ntop_retrieval_heads = [[l[0],  round(np.mean(l[1]), 2)] for l in head_score_list][:10]\nprint(top_retrieval_heads)\n'''\nHead:[16, 19],   Retrieval Score: 0.94      Head:[11, 15],   Retrieval Score: 0.92      \nHead:[8, 26],    Retrieval Score: 0.8       Head:[6, 9],     Retrieval Score: 0.62        \nHead:[7, 12],    Retrieval Score: 0.61      Head:[17, 22],   Retrieval Score: 0.56\nHead:[11, 2],    Retrieval Score: 0.46      Head:[6, 16],    Retrieval Score: 0.44\nHead:[19, 15],   Retrieval Score: 0.42      Head:[21, 30],   
Retrieval Score: 0.4\n'''\n```\n## Influence on Needle-in-a-Haystack\nThis code is implemented by masking the given head in the attention matrix or masking the query in FlashAttention.\n### Usage:\nSet --mask_top to K \u003e 0 to mask out the top K retrieval heads, K \u003c 0 to mask out |K| random heads, or K = 0 for no masking.\n\nA single 80G GPU can test up to ~70K length; two 80G GPUs can test up to 100K length.\n\nMasking top 30 retrieval heads vs 30 random heads:\n```bash\npython needle_in_haystack_with_mask.py --mask_top 30 --s 1000 --e 100000  --model_path $path_to_model  # Results will be written to './results/graph/llama-2-7b-80k_block_top30'\npython needle_in_haystack_with_mask.py --mask_top -30 --s 1000 --e 100000  --model_path $path_to_model  # Results will be written to './results/graph/llama-2-7b-80k_block_random30'\n```\n### Results and Visualization:\nReplace 'model_name' in './viz/CreateVizFromLLMTesting.ipynb' with the folder name of the Needle-in-a-Haystack results.\n\n**Mask top 30 Retrieval Heads for Llama-2-7b-80K**:\n![alt text](viz/top30.png)\n**Mask random 30 non-Retrieval Heads for Llama-2-7b-80K**:\n![alt text](viz/random.png)\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fnightdessert%2FRetrieval_Head","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fnightdessert%2FRetrieval_Head","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fnightdessert%2FRetrieval_Head/lists"}