Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/erfanzar/fjdiffusion
implementation of in Jax/Flax !
flax generative-adversarial-network generative-model jax stable-diffusion unet
Last synced: about 1 month ago
- Host: GitHub
- URL: https://github.com/erfanzar/fjdiffusion
- Owner: erfanzar
- License: apache-2.0
- Created: 2023-06-08T13:04:58.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2023-07-23T13:05:13.000Z (over 1 year ago)
- Last Synced: 2024-11-07T16:17:17.438Z (3 months ago)
- Topics: flax, generative-adversarial-network, generative-model, jax, stable-diffusion, unet
- Language: Python
- Homepage:
- Size: 136 KB
- Stars: 4
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# FJDiffusion
Implementation and Pretrained Diffusion Models in JAX
## Description
This project provides an implementation of diffusion models in JAX, along with pretrained models. Diffusion models are a
class of generative models that can be used for tasks such as image generation, inpainting, denoising, and
super-resolution. This project aims to make it easy for researchers and developers to use diffusion models in their own
projects.
## Features
- Implementation of diffusion models in JAX
- Pretrained models for various tasks
- Easy-to-use API for training and inference
- Support for distributed training on multiple GPUs or TPUs
- Extensive documentation and examples
## Installation
To install the project, follow these steps:
1. Clone the repository:
```bash
git clone https://github.com/erfanzar/FJDiffusion.git
```
2. Change into the project directory:
```bash
cd FJDiffusion
```
3. Install the required dependencies:
```bash
pip install -r requirements.txt
```
## Usage
To use the project, follow these steps:
1. Import the necessary modules:
##### TODO
2. Load a pretrained model:
##### TODO
3. Generate samples from the model:
##### TODO
4. Perform inference on an input:
##### TODO
For more detailed usage instructions and examples, please refer to the documentation.
## Getting PartitionRules
Here is how to get the partition rules for each model so that you can use them with pjit and FSDP.
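As a rough illustration of what such rules are typically used for, the sketch below matches a list of `(regex, spec)` pairs against flattened parameter names and picks the first rule that fits. Everything in it is an assumption made for illustration: the rule format, the helper `match_partition_rules`, the mesh-axis names `fsdp` and `tp`, and the parameter names are hypothetical and are not taken from FJDiffusion, whose `get_partition_rules` output may be structured differently.

```python
import re

# Hypothetical rules in a common (regex, spec) form. The real rules come from
# e.g. Unet2DConfig.get_partition_rules(...) and may look different.
# Each spec is a tuple of mesh-axis names (or None), one entry per array axis;
# with JAX you would wrap it as jax.sharding.PartitionSpec(*spec).
example_rules = (
    (r"attn/(q|k|v)/kernel", ("fsdp", "tp")),
    (r"attn/out/kernel", ("tp", "fsdp")),
    (r".*", ()),  # fallback: an empty spec means "replicate"
)


def match_partition_rules(rules, param_names):
    """Return {param_name: spec} using the first rule whose regex matches."""
    matched = {}
    for name in param_names:
        for pattern, spec in rules:
            if re.search(pattern, name):
                matched[name] = spec
                break
        else:
            # No rule matched this parameter at all.
            raise ValueError(f"no partition rule matches {name!r}")
    return matched


# Hypothetical flattened parameter names, joined with "/".
param_names = [
    "down_blocks_0/attn/q/kernel",
    "down_blocks_0/attn/out/kernel",
    "conv_in/bias",
]
specs = match_partition_rules(example_rules, param_names)
for name, spec in specs.items():
    print(name, "->", spec)
```

Rule order matters in this sketch: the catch-all `.*` pattern sits last so that more specific patterns win, and a parameter that no rule covers raises an error instead of being silently replicated.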
#### Unet2D
```python
from FJDiffusion.moonwalker.configs import Unet2DConfig

partition_rules = Unet2DConfig.get_partition_rules(fully_fsdp=True)
```
#### VAE
```python
from FJDiffusion.moonwalker.configs import AutoencoderKlConfig

partition_rules = AutoencoderKlConfig.get_partition_rules(fully_fsdp=True)
```
#### CLIPTextModel
```python
from FJDiffusion.moonwalker.configs import get_clip_partition_rules

partition_rules = get_clip_partition_rules(fully_fsdp=True)
```
## Contributing
Contributions to this project are welcome! If you would like to contribute, please follow these steps:
1. Fork the repository.
2. Create a new branch for your feature or bug fix.
3. Make your changes and commit them.
4. Push your changes to your fork.
5. Submit a pull request.

Please make sure to follow the code style and conventions used in the project.
## License
This project is licensed under the Apache License 2.0. See
the [LICENSE](https://github.com/erfanzar/FJDiffusion/blob/main/LICENSE) file for more information.
## Acknowledgements
This project is the work of a single researcher and developer. If you find any problems in the
open-source implementations or pretrained models after the final releases, please let me know <3.
## Contact
If you have any questions or suggestions regarding this project, please feel free to contact me at
[email protected]