{"id":18373880,"url":"https://github.com/erksch/fnet-pytorch","last_synced_at":"2025-04-06T19:32:27.932Z","repository":{"id":48390368,"uuid":"383177095","full_name":"erksch/fnet-pytorch","owner":"erksch","description":"Unofficial PyTorch implementation of Google's FNet: Mixing Tokens with Fourier Transforms. With checkpoints.","archived":false,"fork":false,"pushed_at":"2022-09-13T21:13:02.000Z","size":35,"stargazers_count":73,"open_issues_count":5,"forks_count":8,"subscribers_count":2,"default_branch":"master","last_synced_at":"2025-03-22T06:12:24.681Z","etag":null,"topics":["fnet","language-model","pytorch","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/erksch.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2021-07-05T14:56:43.000Z","updated_at":"2025-02-25T13:53:13.000Z","dependencies_parsed_at":"2022-08-25T20:11:40.618Z","dependency_job_id":null,"html_url":"https://github.com/erksch/fnet-pytorch","commit_stats":null,"previous_names":[],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/erksch%2Ffnet-pytorch","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/erksch%2Ffnet-pytorch/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/erksch%2Ffnet-pytorch/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/erksch%2Ffnet-pytorch/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/erksch","download_url":"https://codeload.github.com/erksch/fnet-pytorch/tar.gz/refs
/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247539374,"owners_count":20955302,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["fnet","language-model","pytorch","transformer"],"created_at":"2024-11-06T00:12:34.630Z","updated_at":"2025-04-06T19:32:25.956Z","avatar_url":"https://github.com/erksch.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# FNet PyTorch\n\n[![PyPI](https://img.shields.io/pypi/v/fnet-pytorch?logo=PyPI\u0026color=blue)](https://pypi.org/project/fnet-pytorch)\n\nA PyTorch implementation of FNet from the paper _FNet: Mixing Tokens with Fourier Transforms_ by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, and Santiago Ontanon ([arXiv](https://arxiv.org/abs/2105.03824)).\n\nIn addition to the architecture implementation, this repository offers a script for converting a checkpoint from the [official FNet implementation](https://github.com/google-research/google-research/tree/master/f_net) (written in Jax) into a PyTorch checkpoint (state dict and/or model export).\n\n## Using a pre-trained model\n\nWe offer the following converted checkpoints and pre-trained models:\n\n| Model               | Jax checkpoint                                                                                                                             | PyTorch checkpoint                                                                                                                                     | Arch Info                          | Dataset      | Train Info                          
                                                                             |\n| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------- | ------------ | ---------------------------------------------------------------------------------------------------------------- |\n| FNet Large          | [checkpoint (official)](https://storage.googleapis.com/gresearch/f_net/checkpoints/large/f_net_checkpoint)                                 | [checkpoint (converted)](https://voize-checkpoints-public.s3.eu-central-1.amazonaws.com/fnet/pytorch_checkpoints/fnet_large_pt_checkpoint.zip)         | E 1024, D 1024, FF 4096, 24 layers | C4           | see paper / official project                                                                                     |\n| FNet Base           | [checkpoint (official)](https://storage.googleapis.com/gresearch/f_net/checkpoints/base/f_net_checkpoint)                                  | [checkpoint (converted)](https://voize-checkpoints-public.s3.eu-central-1.amazonaws.com/fnet/pytorch_checkpoints/fnet_base_pt_checkpoint.zip)          | E 768, D 768, FF 3072, 12 layers   | C4           | see paper / official project                                                                                     |\n| FNet Small          | [checkpoint (ours)](https://voize-checkpoints-public.s3.eu-central-1.amazonaws.com/fnet/jax_checkpoints/fnet_small_jax_checkpoint)         | [checkpoint (converted)](https://voize-checkpoints-public.s3.eu-central-1.amazonaws.com/fnet/pytorch_checkpoints/fnet_small_pt_checkpoint.zip)         | E 768, D 312, FF 3072, 4 layers    | Wikipedia EN | trained with official training code. 
1M steps, BS 64, LR 1e-4                                                    |\n| FNet Small (German) | [checkpoint (ours)](https://voize-checkpoints-public.s3.eu-central-1.amazonaws.com/fnet/jax_checkpoints/fnet_small_de_250k_jax_checkpoint) | [checkpoint (converted)](https://voize-checkpoints-public.s3.eu-central-1.amazonaws.com/fnet/pytorch_checkpoints/fnet_small_de_250k_pt_checkpoint.zip) | E 312, D 312, FF 3072, 4 layers    | Wikipedia DE | trained with official training code, but with word piece tokenizer with custom vocab. 250k steps, BS 64, LR 1e-4 |\n\nIn the _Arch Info_ column, E is the embedding dimension, D the model (hidden) dimension, and FF the feed-forward dimension.\nThe PyTorch checkpoints marked _converted_ were converted from the Jax checkpoints using the procedure described below.\n\nYou can install this repository as a package by running:\n\n```bash\npip install fnet-pytorch\n```\n\nNow you can load a pre-trained model in PyTorch as follows.\nYou'll need the `config.json` and a `.statedict.pt` file.\n\n```python\nimport torch\nimport json\nfrom fnet import FNet, FNetForPretraining\n\nwith open('path/to/config.json', 'r') as f:\n    config = json.load(f)\n\n# if you just want the encoder\nfnet = FNet(config)\nfnet.load_state_dict(torch.load('path/to/fnet.statedict.pt'))\n\n# if you want FNet with pre-training head\nfnet = FNetForPretraining(config)\nfnet.load_state_dict(torch.load('path/to/fnet_pretraining.statedict.pt'))\n```\n\nYou can also derive the config from the state dict alone:\n\n```python\nimport torch\nfrom fnet import FNet, get_config_from_statedict\n\nstate_dict = torch.load('path/to/fnet.statedict.pt')\nconfig = get_config_from_statedict(state_dict)\nfnet = FNet(config)\nfnet.load_state_dict(state_dict)\n```\n\nHowever, some config values, such as the dropout rate, the Fourier layer type, and the padding token index, cannot be inferred from the state dict alone.\n`get_config_from_statedict` uses reasonable defaults for them. 
Check the implementation to see which parameters are not inferred and how this might affect your use case.\n\n## Jax checkpoint conversion\n\nDownload a pre-trained Jax checkpoint of FNet from the [official GitHub page](https://github.com/google-research/google-research/tree/master/f_net#base-models), or use any checkpoint that you trained with the official implementation.  \nYou also need the SentencePiece vocab model. For the official checkpoints, use the model linked [here](https://github.com/google-research/google-research/tree/master/f_net#how-to-pre-train-or-fine-tune-fnet). For custom checkpoints, use the corresponding vocab model.\n\nInstall the dependencies (ideally in a virtualenv):\n\n```bash\npip install -r requirements.txt\n```\n\nConvert a checkpoint to PyTorch:\n\n```bash\npython convert_jax_checkpoint.py \\\n    --checkpoint \u003cpath/to/checkpoint\u003e \\\n    --vocab \u003cpath/to/vocab\u003e \\\n    --outdir \u003coutdir\u003e\n```\n\nOutput files: `config.json`, `fnet.statedict.pt`, `fnet_pretraining.statedict.pt`\n\nThe checkpoints from the official Jax implementation are complete pre-training models, meaning they contain both the encoder and the pre-training head weights.\nThe script converts the Jax checkpoint into PyTorch state dicts for both this project's `FNet` module (`fnet.statedict.pt`) and its `FNetForPretraining` module (`fnet_pretraining.statedict.pt`).\nPick whichever model type fits your needs, depending on whether or not you want to continue pre-training.\n\n#### Disclaimer\n\nAlthough all model parameters are transferred correctly to the PyTorch model, inference results differ slightly between Jax and PyTorch because their LayerNorm and GELU implementations differ.\n\nFor a given inference input, all hidden states and logits of the official and converted models agree to at least the first decimal place. 
This is programmatically verified using the script described below.\n\n### Verify conversion results\n\nYou can use the `verify_conversion.py` script to compare the inference outputs of a Jax checkpoint with those of the converted PyTorch checkpoint.\nBecause this requires running the Jax FNet, it needs some setup and a few modifications to the official implementation.\n\n#### Verification Setup\n\n1. Clone the official implementation:\n\n```bash\nsvn export https://github.com/google-research/google-research/trunk/f_net\n# or\ngit clone git@github.com:google-research/google-research.git\ncd google-research/f_net\n```\n\n2. Edit the config in the official implementation to fit the checkpoint you want to run.\n\n3. Add the following entries to the dict returned by `_compute_pretraining_metrics` in `models.py`:\n\n```python\nreturn {\n    # ... existing metrics ...\n    \"masked_lm_logits\": masked_lm_logits,\n    \"next_sentence_logits\": next_sentence_logits,\n}\n```\n\n4. Create a `setup.py` file in the parent directory of the `f_net` directory with the following content:\n\n```python\nfrom setuptools import setup\n\nsetup(\n    name='fnet_jax',\n    version='0.1.0',\n    install_requires=[],\n    packages=['f_net']\n)\n```\n\n5. Install it as a dependency in your `fnet-pytorch` environment:\n\n```bash\npip install -e path/to/dir-of-\"setup.py\"\n```\n\n#### Run the verification script\n\n```bash\npython verify_conversion.py \\\n    --jax path/to/jax_checkpoint \\\n    --torch path/to/fnet_pretraining.statedict.pt \\\n    --config path/to/config.json \\\n    --vocab path/to/vocab\n```\n\nThis initializes both models from their checkpoints, runs inference on a sample text, and compares the output logits.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ferksch%2Ffnet-pytorch","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Ferksch%2Ffnet-pytorch","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ferksch%2Ffnet-pytorch/lists"}