{"id":15361526,"url":"https://github.com/muhd-umer/pvt-flax","last_synced_at":"2025-11-07T11:30:53.397Z","repository":{"id":154884917,"uuid":"532872269","full_name":"muhd-umer/pvt-flax","owner":"muhd-umer","description":"Unofficial JAX/Flax implementation of Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions.","archived":false,"fork":false,"pushed_at":"2023-12-15T12:43:19.000Z","size":153,"stargazers_count":2,"open_issues_count":0,"forks_count":0,"subscribers_count":2,"default_branch":"main","last_synced_at":"2024-12-27T21:29:35.543Z","etag":null,"topics":["deep-learning","flax","image-classification","implementation","jax"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2102.12122","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/muhd-umer.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-09-05T11:34:16.000Z","updated_at":"2023-08-12T11:25:09.000Z","dependencies_parsed_at":"2023-12-15T13:46:21.969Z","dependency_job_id":"c3ec48fb-9849-4ab5-807f-f02113d1c9ef","html_url":"https://github.com/muhd-umer/pvt-flax","commit_stats":{"total_commits":96,"total_committers":2,"mean_commits":48.0,"dds":"0.20833333333333337","last_synced_commit":"3a72d9210ed00c204eb1c79fb21e6ee6615504e4"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/muhd-umer%2Fpvt-flax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/muhd-umer%2Fpvt-flax/tags","releases_url":
"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/muhd-umer%2Fpvt-flax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/muhd-umer%2Fpvt-flax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/muhd-umer","download_url":"https://codeload.github.com/muhd-umer/pvt-flax/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":239529206,"owners_count":19654066,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","flax","image-classification","implementation","jax"],"created_at":"2024-10-01T12:55:30.835Z","updated_at":"2025-02-18T18:41:47.961Z","avatar_url":"https://github.com/muhd-umer.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Pyramid Vision Transformer\n[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT)\n[![JAX](https://img.shields.io/badge/JAX-0.3.16-orange.svg)](https://github.com/google/jax)\n\u003ca href=\"https://arxiv.org/abs/2106.13797\"\u003e\u003cimg src=\"https://img.shields.io/badge/arXiv-Paper-\u003cCOLOR\u003e.svg\" \u003e\u003c/a\u003e\u003c/h1\u003e \n\nThis repo contains the **unofficial** JAX/Flax implementation of \u003ca href=\"https://arxiv.org/abs/2106.13797\"\u003ePVT v2: Improved Baselines with Pyramid Vision Transformer\u003c/a\u003e. 
\u003cbr/\u003e\nAll credit goes to the authors **Wenhai Wang, Enze Xie, Xiang Li, Deng-Ping Fan, Kaitao Song, Ding Liang, Tong Lu, Ping Luo, Ling Shao** for their wonderful work.\n\n## Dependencies\n*It is recommended to create a new virtual environment so that updates/downgrades of packages do not break other projects.*\u003cbr/\u003e\n- Environment characteristics:\u003cbr/\u003e`python = 3.9.12` `cuda = 11.3` `jax = 0.3.16` `flax = 0.6.0`\n- Follow the instructions in the [official JAX/Flax documentation](https://flax.readthedocs.io/en/latest/installation.html) to install their packages.\n\n  ```\n  pip install -r requirements.txt\n  ```\n\n*Note: Flax does not depend on TensorFlow itself; however, we use methods that rely on `tf.io.gfile`, so we only install `tensorflow-cpu`. The same applies to PyTorch: we only install it in order to use its `torch.utils.data.DataLoader`.*\n\n## Run\nTo get started, clone this repo and install the required dependencies.\u003cbr/\u003e\n\n### Datasets\n- **TensorFlow Datasets** - Refer to the [TensorFlow Datasets Image Classification Catalog](https://www.tensorflow.org/datasets/catalog/overview#image_classification) and modify the following keys in `config/default.py` accordingly.\n\n  ```python\n  config.dataset_name = \"imagenette\"\n  config.data_shape = [224, 224]\n  config.num_classes = 10\n  config.split_keys = [\"train\", \"validation\"]\n  ```\n\n- **PyTorch DataLoader** - To load datasets in PyTorch style, use `NumpyLoader`, the wrapper around `torch.utils.data.DataLoader` in `data/numpyloader.py`, along with a custom collate function.\n- **Custom Dataset** - Currently, this repo does not provide out-of-the-box support for custom image classification datasets. However, you can adapt `NumpyLoader` to accomplish this.\n\n### Training\n- Configure the **key: value** pairs in the config file at `config/default.py`.\u003cbr/\u003e\n- Execute `train.py` with the model name and the working directory. 
Example usage:\n\n  ```bash\n  python train.py --model-name \"PVT_V2_B0\" --work-dir \"output/\"\n  ```\n\n### Evaluation\n- Execute `train.py` with the `--eval-only` argument and the path to the checkpoint directory. Example usage:\n\n  ```bash\n  python train.py --model-name \"PVT_V2_B0\" \\\n                  --eval-only \\\n                  --checkpoint_dir \"output/\"\n  ```\n\n## To do\n- [ ] Convert ImageNet pretrained PyTorch weights (.pth) to Flax weights\n\n*Note: Since my undergrad studies are resuming after the summer break, I may or may not be able to find time to complete the tasks listed above. \nIf you want to implement them, I'll be more than glad to merge your pull request. ❤️*\n\n## Acknowledgements\nWe acknowledge the excellent implementations of PVT in [MMDetection](https://github.com/open-mmlab/mmdetection), [PyTorch Image Models](https://github.com/rwightman/pytorch-image-models) and the [official implementation](https://github.com/whai362/PVT). I used these implementations as references.\n\n## Citing PVT\n- **PVT v1**\n\n  ```\n  @inproceedings{wang2021pyramid,\n    title={Pyramid vision transformer: A versatile backbone for dense prediction without convolutions},\n    author={Wang, Wenhai and Xie, Enze and Li, Xiang and Fan, Deng-Ping and Song, Kaitao and Liang, Ding and Lu, Tong and Luo, Ping and Shao, Ling},\n    booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},\n    pages={568--578},\n    year={2021}\n  }\n  ```\n\n- **PVT v2**\n\n  ```\n  @article{wang2021pvtv2,\n    title={Pvtv2: Improved baselines with pyramid vision transformer},\n    author={Wang, Wenhai and Xie, Enze and Li, Xiang and Fan, Deng-Ping and Song, Kaitao and Liang, Ding and Lu, Tong and Luo, Ping and Shao, Ling},\n    journal={Computational Visual Media},\n    volume={8},\n    number={3},\n    pages={1--10},\n    year={2022},\n    publisher={Springer}\n  }\n  
```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmuhd-umer%2Fpvt-flax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmuhd-umer%2Fpvt-flax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmuhd-umer%2Fpvt-flax/lists"}