{"id":19156213,"url":"https://github.com/kyegomez/flashmha","last_synced_at":"2025-05-07T07:34:30.029Z","repository":{"id":180458750,"uuid":"665195802","full_name":"kyegomez/FlashMHA","owner":"kyegomez","description":"An simple pytorch implementation of Flash MultiHead Attention","archived":false,"fork":false,"pushed_at":"2024-02-05T03:27:22.000Z","size":87,"stargazers_count":20,"open_issues_count":1,"forks_count":2,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-03-27T04:46:52.405Z","etag":null,"topics":["artificial-intelligence","artificial-neural-networks","attention","attention-mechanisms","attentionisallyouneed","flash-attention","gpt4","transformer"],"latest_commit_sha":null,"homepage":"https://discord.gg/qUtxnK2NMf","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/kyegomez.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-07-11T16:44:12.000Z","updated_at":"2025-02-24T09:33:53.000Z","dependencies_parsed_at":null,"dependency_job_id":"3561814c-a514-42b2-bee9-eafa62e97437","html_url":"https://github.com/kyegomez/FlashMHA","commit_stats":{"total_commits":22,"total_committers":2,"mean_commits":11.0,"dds":0.09090909090909094,"last_synced_commit":"89aa879f21b2922c5c4e1f9ee192f6844f5fe319"},"previous_names":["kyegomez/flashmha"],"tags_count":6,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kyegomez%2FFlashMHA","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kyegomez%2FFlashMHA/tags","rel
eases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kyegomez%2FFlashMHA/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kyegomez%2FFlashMHA/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/kyegomez","download_url":"https://codeload.github.com/kyegomez/FlashMHA/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":249764837,"owners_count":21322309,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","artificial-neural-networks","attention","attention-mechanisms","attentionisallyouneed","flash-attention","gpt4","transformer"],"created_at":"2024-11-09T08:33:37.735Z","updated_at":"2025-04-19T18:31:20.395Z","avatar_url":"https://github.com/kyegomez.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"# FlashMHA\nFlashMHA is a PyTorch implementation of the Flash Multi-Head Attention mechanism. It is designed to be efficient and flexible, allowing for both causal and non-causal attention. 
The implementation can also use Flash Attention, an IO-aware exact attention algorithm that reduces GPU memory traffic by computing attention in tiles.\n\n## Installation\n\nInstall FlashMHA with pip:\n\n```shell\npip install FlashMHA\n```\n\n## Usage\n\nAfter installing FlashMHA, import the module you need:\n\n```python\nfrom FlashMHA import FlashAttention\n```\n\nor\n\n```python\nfrom FlashMHA import FlashMHA\n```\n\nThen create an instance of the class and call it on your query, key, and value tensors.\n\nExample usage:\n\n```python\nimport torch\nfrom FlashMHA import FlashAttention\n\n# Create an instance of FlashAttention\nflash_attention = FlashAttention(causal=False, dropout=0.0)\n\n# Self-attention example: inputs are already split into heads, (batch_size, num_heads, seq_len, head_dim)\nquery = key = value = torch.randn(2, 8, 128, 64)\noutput = flash_attention(query, key, value)\n```\n\n```python\nimport torch\nfrom FlashMHA import FlashMHA\n\n# embed_dim and num_heads are required arguments\nflash_mha_attention = FlashMHA(embed_dim=512, num_heads=8, causal=False, dropout=0.0)\n\n# Self-attention example: inputs are (batch_size, sequence_length, embed_dim) because batch_first=True by default\nquery = key = value = torch.randn(2, 128, 512)\noutput = flash_mha_attention(query, key, value)\n```\n\nReplace the random `query`, `key`, and `value` tensors with your own inputs. For `FlashMHA`, they have shape `(batch_size, sequence_length, embed_dim)`; the module applies multi-head attention to them and returns an output tensor of the same shape. For `FlashAttention`, the inputs are already split into heads, as described in the documentation below.\n\n## Documentation\n\n### `FlashAttention`\n\n`FlashAttention` is a PyTorch module that computes attention over query, key, and value tensors that are already split into heads. Setting `flash=True` enables the Flash Attention implementation.\n\n#### Parameters\n\n- `causal` (bool, optional): If set to True, applies causal masking to the sequence. 
Default: False.\n- `dropout` (float, optional): The dropout probability. Default: 0.0.\n- `flash` (bool, optional): If set to True, uses the Flash Attention implementation. Default: False.\n\n#### Inputs\n\n- `q` (Tensor): The query tensor of shape (batch_size, num_heads, query_length, head_dim).\n- `k` (Tensor): The key tensor of shape (batch_size, num_heads, key_length, head_dim).\n- `v` (Tensor): The value tensor of shape (batch_size, num_heads, key_length, head_dim); its sequence length must match that of `k`.\n- `mask` (Tensor, optional): An optional mask tensor of shape (batch_size, num_heads, query_length, key_length), used to mask out specific positions. Default: None.\n- `attn_bias` (Tensor, optional): An optional additive bias tensor of shape (batch_size, num_heads, query_length, key_length), added to the attention scores. Default: None.\n\n#### Outputs\n\n- `output` (Tensor): The output tensor of shape (batch_size, num_heads, query_length, head_dim).\n\n### `FlashMHA`\n\n`FlashMHA` is a PyTorch module that implements multi-head attention on top of `FlashAttention`: it projects the inputs into `num_heads` heads, applies attention, and projects the result back to `embed_dim`. Like `FlashAttention`, it supports both causal and non-causal attention.\n\n#### Parameters\n\n- `embed_dim` (int): The dimension of the input embedding.\n- `num_heads` (int): The number of attention heads.\n- `bias` (bool, optional): If set to False, the projection layers do not learn an additive bias. Default: True.\n- `batch_first` (bool, optional): If True, the input and output tensors are provided as (batch, seq, feature). Default: True.\n- `dropout` (float, optional): The dropout probability. Default: 0.0.\n- `causal` (bool, optional): If True, applies causal masking to the sequence. Default: False.\n- `device` (torch.device, optional): The device on which to create the module's parameters. Default: None.\n- `dtype` (torch.dtype, optional): The data type of the module's parameters. Default: None.\n\n#### Inputs\n\n- `query` (Tensor): The query tensor of shape (batch_size, sequence_length, embed_dim).\n- `key` (Tensor): The key tensor of shape (batch_size, sequence_length, embed_dim).\n- `value` (Tensor): The value tensor of shape (batch_size, sequence_length, embed_dim).\n\n#### Outputs\n\n- `output` (Tensor): The output tensor of shape (batch_size, sequence_length, embed_dim).\n\n## License\n\nFlashAttention and FlashMHA are open-source software licensed under the MIT license. For more details, see the [GitHub repository](https://github.com/kyegomez/FlashAttention).","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fkyegomez%2Fflashmha","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fkyegomez%2Fflashmha","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fkyegomez%2Fflashmha/lists"}