{"id":16930457,"url":"https://github.com/ayaka14732/bart-base-jax","last_synced_at":"2025-03-22T11:31:18.954Z","repository":{"id":39590009,"uuid":"474519598","full_name":"ayaka14732/bart-base-jax","owner":"ayaka14732","description":"JAX implementation of the bart-base model","archived":false,"fork":false,"pushed_at":"2023-04-11T15:05:59.000Z","size":3301,"stargazers_count":30,"open_issues_count":0,"forks_count":4,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-03-18T11:04:01.285Z","etag":null,"topics":["bart","jax","natural-language-processing","nlp","nlp-model"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/1910.13461","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/ayaka14732.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-03-27T02:55:21.000Z","updated_at":"2025-01-19T20:17:46.000Z","dependencies_parsed_at":"2024-10-13T20:41:51.092Z","dependency_job_id":"2d7fd7c9-7cce-44e9-bb63-90e1830d4094","html_url":"https://github.com/ayaka14732/bart-base-jax","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ayaka14732%2Fbart-base-jax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ayaka14732%2Fbart-base-jax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ayaka14732%2Fbart-base-jax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ayaka14732%2Fbart-base-jax/manifests","owner_
url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/ayaka14732","download_url":"https://codeload.github.com/ayaka14732/bart-base-jax/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":244951418,"owners_count":20537384,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["bart","jax","natural-language-processing","nlp","nlp-model"],"created_at":"2024-10-13T20:41:45.740Z","updated_at":"2025-03-22T11:31:18.409Z","avatar_url":"https://github.com/ayaka14732.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# JAX Implementation of bart-base\n\nThis project is a JAX implementation of the [bart-base](https://arxiv.org/abs/1910.13461) model. 
The aim of this project is to provide a versatile codebase for research on Transformer-based LLM architectures and to demonstrate how Transformer-based language models can be implemented using JAX and trained on Google Cloud TPUs.\n\nThis project is supported by Cloud TPUs from Google's [TPU Research Cloud](https://sites.research.google/trc/about/) (TRC).\n\nThis project is inspired by [hyunwoongko/transformer](https://github.com/hyunwoongko/transformer), although the code in this project is written entirely by me.\n\n* [News](#news)\n* [Environment Setup](#environment-setup)\n* [Model Architecture](#model-architecture)\n* [Model Parameters](#model-parameters)\n* [Dataset](#dataset)\n    * [English Wikipedia](#english-wikipedia)\n* [Data Preprocessing](#data-preprocessing)\n* [Data Loader](#data-loader)\n* [Training](#training)\n* [Evaluation](#evaluation)\n* [Generation](#generation)\n* [Analysis](#analysis)\n\n## News\n\n**2022-11-07:** [WIP] I am working on [ayaka14732/TransCan](https://github.com/ayaka14732/TransCan).\n\n**2022-10-27:** I published the [Cantonese BART](https://huggingface.co/Ayaka/bart-base-cantonese) model. It was obtained by second-stage pre-training of the [fnlp/bart-base-chinese](https://huggingface.co/fnlp/bart-base-chinese) model on the [LIHKG dataset](https://github.com/ayaka14732/lihkg-scraper). See [ayaka14732/bart-base-cantonese](https://github.com/ayaka14732/bart-base-cantonese) for details. [[Twitter]](https://twitter.com/ayaka14732/status/1585561115345375233)\n\n**2022-10-08:** [WIP] I am fine-tuning [fnlp/bart-base-chinese](https://huggingface.co/fnlp/bart-base-chinese) to develop a Mandarin-Taiwanese Hokkien translation model. See the [`twblg`](https://github.com/ayaka14732/bart-base-jax/tree/twblg) branch for details.\n\n**2022-09-27:** [Nixie](https://github.com/ztjhz) and I implemented [TrAVis](https://github.com/ayaka14732/TrAVis), a BERT attention visualiser based on this codebase that runs entirely in the browser. 
[[Twitter]](https://twitter.com/ayaka14732/status/1574627912162349056)\n\n**2022-03-27:** In addition to the regular implementation, I also implemented the model in a single line of Python code, thanks to JAX's functional-style API. [[Twitter]](https://twitter.com/ayaka14732/status/1507955631109869574)\n\n## Environment Setup\n\nSet up the TPU environment as described in [ayaka14732/tpu-starter](https://github.com/ayaka14732/tpu-starter). Then run the following commands:\n\n```sh\npython3.10 -m venv ./venv\n. ./venv/bin/activate\npip install -U pip\npip install -U wheel\npip install \"jax[tpu]==0.3.23\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\npip install -r requirements.txt\n```\n\n## Model Architecture\n\nTODO: See `lib/model`.\n\n## Model Parameters\n\nParameter-related operations are implemented in the `lib/param_utils` directory. Notably, three functions, `flax2jax`, `pt2jax` and `jax2flax`, are implemented. Together with `save_pretrained`, they allow conversion between the PyTorch, Flax and JAX implementations in any direction:\n\n| from\\to | PyTorch | Flax | JAX |\n| :- | :-: | :-: | :-: |\n| PyTorch | - | `save_pretrained` | `pt2jax` |\n| Flax | `save_pretrained` | - | `flax2jax` |\n| JAX | `jax2flax` + `save_pretrained` | `jax2flax` | - |\n\n`save_pretrained` is a method provided by the Hugging Face Transformers library that lets users save a model in one framework and reload it in another. For instance, the following code saves a Flax model and reloads it as a PyTorch model:\n\n```python\nimport tempfile\nfrom transformers import BartForConditionalGeneration\n# model_flax is assumed to be an existing FlaxBartForConditionalGeneration instance\nwith tempfile.TemporaryDirectory() as tmpdirname:\n    model_flax.save_pretrained(tmpdirname)\n    model_pt = BartForConditionalGeneration.from_pretrained(tmpdirname, from_flax=True)\n```\n\nFor the format of the JAX parameters, see [param_format.txt](param_format.txt).\n\n## Dataset\n\n### English Wikipedia\n\nEnglish Wikipedia is split into sentences in four steps:\n\n1. Download the English Wikipedia data\n1. Extract the data with WikiExtractor\n1. Split the articles into sentences with Bling Fire\n1. 
Save the sentences to files (one sentence per line)\n\n```sh\npython prepare_dataset.py\n```\n\nOn Cloud TPU v3-8, the processing takes 15m18s; on Cloud TPU v4-8, it takes 4m19s. The resulting directory is about 15 GiB in size.\n\n## Data Preprocessing\n\nThe `[EOS]` token (`tokenizer.eos_token_id`) should be prepended to every sentence in `dst`. As the example below shows, sequences generated by the pretrained Hugging Face model also begin with `[EOS]` (ID 2) followed by `[BOS]` (ID 0).\n\nExample:\n\n```\nInput: ['\u003cs\u003eA flower.\u003c/s\u003e\u003cpad\u003e', '\u003cs\u003eSome good sentences.\u003c/s\u003e']\nOutput: ['\u003c/s\u003e\u003cs\u003eA flower.\u003c/s\u003e\u003cpad\u003e', '\u003c/s\u003e\u003cs\u003eSome good sentences.\u003c/s\u003e']\nInput IDs: [[0, 250, 14214, 4, 2, 1], [0, 6323, 205, 11305, 4, 2]]\nOutput IDs: [[2, 0, 250, 14214, 4, 2, 1], [2, 0, 6323, 205, 11305, 4, 2]]\n```\n\nFor a training example, `dst` is therefore `label` shifted right by one position, with `[EOS]` inserted at the front:\n\n- **src**: `[BOS]`, A, `[MSK]`, flower, `[EOS]`, `[PAD]`, `[PAD]`\n- **dst**: `[EOS]`, `[BOS]`, A, beautiful, flower, `[EOS]`, `[PAD]`\n- **label**: `[BOS]`, A, beautiful, flower, `[EOS]`, `[PAD]`, `[PAD]`\n\n\u003cdetails\u003e\n\u003csummary\u003eCode for the example above\u003c/summary\u003e\n\n```python\nfrom transformers import BartTokenizer, BartForConditionalGeneration\n\nmodel_name = 'facebook/bart-base'\ntokenizer = BartTokenizer.from_pretrained(model_name)\nmodel = BartForConditionalGeneration.from_pretrained(model_name)\n\nsentences = ('A flower.', 'Some good sentences.')\n\ninputs = tokenizer(sentences, return_tensors='pt', max_length=6, padding='max_length', truncation=True)\noutput = model.generate(inputs.input_ids)\n\nprint('Input:', tokenizer.batch_decode(inputs.input_ids))\nprint('Output:', tokenizer.batch_decode(output))\n\nprint('Input IDs:', inputs.input_ids.tolist())\nprint('Output IDs:', output.tolist())\n```\n\n\u003c/details\u003e\n\n## Data Loader\n\nOn-demand data loader\n\n## Training\n\n## Evaluation\n\n## Generation\n\nTODO\n\nA typical generation process for the BART model takes the input sequences and their attention masks as input. 
The model generates the output autoregressively.\n\nWhile greedy decoding is the simplest generation algorithm for autoregressive language models, other algorithms such as beam search and sampling can improve the quality of the generated sentences. In this project, we do not implement these generation algorithms ourselves and instead leave them to the Hugging Face Transformers library.\n\nHowever, the generation functions in the Hugging Face Transformers library are tightly coupled to the library's own model implementations, which makes them unusable with custom models. To work around this, we convert our model into a regular Hugging Face Transformers model.\n\n## Analysis\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fayaka14732%2Fbart-base-jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fayaka14732%2Fbart-base-jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fayaka14732%2Fbart-base-jax/lists"}