{"id":20519408,"url":"https://github.com/apple/ml-cross-entropy","last_synced_at":"2025-05-14T18:04:58.503Z","repository":{"id":263001331,"uuid":"887718293","full_name":"apple/ml-cross-entropy","owner":"apple","description":null,"archived":false,"fork":false,"pushed_at":"2025-04-23T00:14:15.000Z","size":445,"stargazers_count":434,"open_issues_count":12,"forks_count":33,"subscribers_count":15,"default_branch":"main","last_synced_at":"2025-05-03T20:02:42.448Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/apple.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":"CODE_OF_CONDUCT.md","threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-11-13T06:53:37.000Z","updated_at":"2025-05-01T14:34:34.000Z","dependencies_parsed_at":"2024-12-24T22:00:23.545Z","dependency_job_id":"68a322f4-2785-412e-b078-78ff6466aa14","html_url":"https://github.com/apple/ml-cross-entropy","commit_stats":null,"previous_names":["apple/ml-cross-entropy"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-cross-entropy","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-cross-entropy/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-cross-entropy/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/apple%2Fml-cross-entropy/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/apple","download_url":"https://codeload.github.com/apple/ml-cross-entropy/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254198514,"owners_count":22030965,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-15T22:13:21.463Z","updated_at":"2025-05-14T18:04:53.491Z","avatar_url":"https://github.com/apple.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"## Cut Your Losses in Large-Vocabulary Language Models\n\nThis software project accompanies the research paper:\n**[Cut Your Losses in Large-Vocabulary Language Models](https://arxiv.org/abs/2411.09009)**,\n*Erik Wijmans, Brody Huval, Alexander Hertzberg, Vladlen Koltun, and Philipp Krähenbühl*.\n\n![](assets/cce_figure.png)\n\nAs language models grow ever larger, so do their vocabularies. This has shifted the memory footprint of LLMs during training disproportionately to one single layer: the cross-entropy in the loss computation. Cross-entropy builds up a logit matrix with entries for each pair of input tokens and vocabulary items and, for small models, consumes an order of magnitude more memory than the rest of the LLM combined. We propose Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss without materializing the logits for all tokens into global memory. Rather, CCE only computes the logit for the correct token and evaluates the log-sum-exp over all logits on the fly. We implement a custom kernel that performs the matrix multiplications and the log-sum-exp reduction over the vocabulary in flash memory, making global memory consumption for the cross-entropy computation negligible. This has a dramatic effect. Taking the Gemma 2 (2B) model as an example, CCE reduces the memory footprint of the loss computation from 24 GB to 1 MB, and the total training-time memory consumption of the classifier head from 28 GB to 1 GB. To improve the throughput of CCE, we leverage the inherent sparsity of softmax and propose to skip elements of the gradient computation that have a negligible (i.e., below numerical precision) contribution to the gradient. Experiments demonstrate that the dramatic reduction in memory consumption is accomplished without sacrificing training speed or convergence.\n\n## Getting started\n\n**Requirements**\n\n1. Python 3.10+\n2. PyTorch 2.4+\n3. Triton 3.0+\n4. Ampere (or newer) GPU\n\n\n**Note:**  For operating systems that are not supported by Triton (e.g., MacOS), we include a highly optimized version of\nlinear-cross-entropy using `torch.compile`. This implementation will be set to the default on MacOS.\n\n### Basic usage\n\n**Installation**\n```bash\npip install \"cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git\"\n```\n\n**Usage**\n\n```python\nfrom cut_cross_entropy import linear_cross_entropy\n\nembeddings = model.compute_embedding(inputs)\nclassifier = model.get_classifier_weights()\n\nloss = linear_cross_entropy(embeddings, classifier, labels)\n```\n\nIn causal language modeling, it is common that the model embeddings and labels need to be shifted\nsuch that the model predicts the next token.\n\n```python\nfrom cut_cross_entropy import linear_cross_entropy\n\nembeddings = model.compute_embedding(inputs)\nclassifier = model.get_classifier_weights()\n\nshift_embeddings = embeddings[..., :-1, :].flatten(0, -2)\nshift_labels = labels[..., 1:]\n\nmanual_shift_loss = linear_cross_entropy(shift_embeddings, classifier, shift_labels)\n```\n\nInstead, pass `shift=1` to perform this computation without allocating the shift_embeddings matrix.\n```python\nfrom cut_cross_entropy import linear_cross_entropy\n\nembeddings = model.compute_embedding(inputs)\nclassifier = model.get_classifier_weights()\n\n# This is the same as manual_shift_loss above\nauto_shift_loss = linear_cross_entropy(embeddings, classifier, labels, shift=1)\n```\n\nWe also provide a highly optimized implementation of linear-cross-entropy loss using `torch.compile`.\nThis is a good option\nfor scenarios where speed is the primary goal and the model has a relatively small vocabulary compared to its\nhidden dimension (when |V| \u003e\u003e D, `cce` will both save memory _and_ be faster).\nThis option also works on the CPU and older GPUs, making it useful for testing.\n\n```python\nfrom cut_cross_entropy import linear_cross_entropy\n\nembeddings = model.compute_embedding(inputs)\nclassifier = model.get_classifier_weights()\n\nloss = linear_cross_entropy(embeddings, classifier, labels, ..., impl=\"torch_compile\")\n```\n\n\nThere are several different\n\n\n### Computing Related Quantities\n\n`linear_cross_entropy` can be used as an efficient way to compute the negative log likelihood\nof a specified token. This can be used to compute various quantities.\n\n\n```python\nfrom cut_cross_entropy import linear_cross_entropy\n\n\n# linear_cross_entropy computes negative log likelihood for a target token\nnll = linear_cross_entropy(embeddings, classifier, target_token, reduction=\"none\")\n\n# Perplexity\nppl = torch.exp(nll.mean(-1))\n\n# DPO (beta and reference omitted)\ndpo_loss = -F.logsigmoid(nll[dispreferred].sum(-1) - nll[preferred].sum(-1))\n\n# PPO\nppo_loss = -torch.minimum(toch.exp(-nll - old_logp) * adv, adv + eps * adv.abs())\n```\n\n\n### Generalized Usage\n\nWhile we have discussed using CCE in the context of large language models, the only constraint\nto use CCE is that loss can be formulated using something that resembles following:\n\n```python\nlogits = X @ A.T + b  # (b is an optional bias)\nloss = F.cross_entropy(logits.float(), targets)\n```\n\nGiven that format, CCE can then be used as\n```python\nloss = linear_cross_entropy(X, A, target_token, bias=b)\n```\n\nThis is a very general and encompasses vision models, contrastive losses, e.g. CLIP, etc.\n\n\n### Transformers Integration\n\n**Installation**\n\nInstall cut-cross-entropy with transformers dependencies\n```bash\npip install \"cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git\"\n```\n\n**Usage**\n\nIf you are using transformers, you can patch transformers to use CCE directly. Note that\nlogits will no longer be returned (`None` will be returned instead).\n```python\nfrom cut_cross_entropy.transformers import cce_patch\n\ncce_patch(\"llama\")\n\n# or\n\nmodel = ...\nmodel = cce_patch(model)\n```\n\nWe currently support the Llama, Phi3, Mistral, and Gemma2 families of models.\n\n`cce_patch` takes two options. The first is the linear-cross-entropy implementation to use. Currently `\"cce\"` or `\"torch_compile\"`.\n\nThe second\nis the loss reduction. We support `\"mean\"`, `\"sum\"`, and `\"none\"`, that mirror their PyTorch counterpart.\n`\"mean\"` is the default and what the transformers trainer API expects.\nHowever,\n`\"none\"` in particular can enable for efficient computation of quantities based on the loss.\n\nFor example, the following efficiently computes the perplexity of a batch of sequences:\n```python\nimport transformers\n\nfrom cut_cross_entropy.transformers import cce_patch\n\n\nmodel = transformers.AutoModelForCausalLM.from_pretrained(...)\n\nmodel = cce_patch(model, reduction=\"none\")\n\nlabels = input_ids.clone()\nlabels[~attention_mask] = -100 # -100 is the ignore index for PyTorch and CCE.\n\noutputs = model(input_ids, attention_mask, labels=labels)\n\nloss = outputs[0] # A (B, T - 1) tensor because reduction=\"none\". T - 1 because the first input token has\n# no loss.\n\nppl = torch.exp(\n    # [:, 1:] because the first token has no loss\n    loss.sum(1) / (labels[:, 1:] != -100).count_nonzero(dim=1)\n).mean()  # Average perplexity over the batch\n```\n\n\n\n### Training and reproducing the benchmark results\n\nWe provide a training in `training/train.py`.\n\n**Installation**\n```bash\npip install \"cut-cross-entropy[all] @ git+https://github.com/apple/ml-cross-entropy.git\"\n```\n\n**Training**\n\nUse `scripts/train.sh` to train a full model.\n\n**Benchmarking**\n\nThe benchmark script can be run via `python -m benchmark`.\n\nExpected output with A100 SMX4, PyTorch 2.4.1, and CUDA 12.4.\n\n```\n          method        kind  runtime_ms  op_mem_mb test_data\n0            cce     loss-fw        46.4        1.1    gemma2\n1  torch_compile     loss-fw        49.9     4000.1    gemma2\n2       baseline     loss-fw        81.9    24000.0    gemma2\n3            cce     loss-bw        89.3     1163.0    gemma2\n4  torch_compile     loss-bw        92.3    12000.0    gemma2\n5       baseline     loss-bw       122.4    16000.0    gemma2\n6            cce  loss-fw-bw       134.8     1164.0    gemma2\n7  torch_compile  loss-fw-bw       144.0    16000.1    gemma2\n8       baseline  loss-fw-bw       208.8    28000.0    gemma2\n```\n\n### Development\n\nIf dependencies are installed locally, `cut-cross-entropy` will work without a pip install as long as `python` is executed in the root path of the github repo.\n\nTo install directly from the github repo, either use an (editable) install or manipulate PYTHONPATH, e.g.\n\n```bash\npip install -e \".[dev]\"\n\n# or\npip install \".[dev]\"\n\n# or\nexport PYTHONPATH=/path/to/ml-cross-entropy:${PYTHONPATH}\n```\n\n## Citation\n\n```\n@inproceedings{wijmans2025cut,\n  author       = {Erik Wijmans and\n                  Brody Huval and\n                  Alexander Hertzberg and\n                  Vladlen Koltun and\n                  Philipp Kr\\\"ahenb\\\"uhl},\n  title        = {Cut Your Losses in Large-Vocabulary Language Models},\n  booktitle    = {International Conference on Learning Representations},\n  year         = {2025},\n}\n```\n\n\n## License\nThis sample code is released under the [LICENSE](LICENSE) terms.\n\n## Acknowledgements\n\nOur codebase is built using multiple opensource contributions, please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details.\n\nPlease check the paper for a complete list of references and datasets used in this work.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fapple%2Fml-cross-entropy","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fapple%2Fml-cross-entropy","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fapple%2Fml-cross-entropy/lists"}