{"id":16956237,"url":"https://github.com/juliuskunze/cwvae-jax","last_synced_at":"2025-07-15T11:06:15.603Z","repository":{"id":73768676,"uuid":"386610300","full_name":"juliuskunze/cwvae-jax","owner":"juliuskunze","description":"Clockwork VAEs in JAX/Flax","archived":false,"fork":false,"pushed_at":"2021-07-16T11:50:25.000Z","size":15,"stargazers_count":32,"open_issues_count":2,"forks_count":3,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-04-04T17:22:02.338Z","etag":null,"topics":["deep-learning","jax","machine-learning","research","video-prediction","world-models"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/juliuskunze.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-07-16T11:16:53.000Z","updated_at":"2024-11-13T22:46:09.000Z","dependencies_parsed_at":"2023-04-02T10:59:49.820Z","dependency_job_id":null,"html_url":"https://github.com/juliuskunze/cwvae-jax","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/juliuskunze/cwvae-jax","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juliuskunze%2Fcwvae-jax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juliuskunze%2Fcwvae-jax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juliuskunze%2Fcwvae-jax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juliuskunze%2Fcwvae-jax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/juliuskunze","download_url":"https://codeload.github.com/juliuskunze/cwvae-jax/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/juliuskunze%2Fcwvae-jax/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":265430419,"owners_count":23764001,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","jax","machine-learning","research","video-prediction","world-models"],"created_at":"2024-10-13T22:14:30.740Z","updated_at":"2025-07-15T11:06:15.558Z","avatar_url":"https://github.com/juliuskunze.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Clockwork VAEs in JAX/Flax\n\nImplementation of experiments in the paper [Clockwork Variational Autoencoders](https://arxiv.org/pdf/2102.09532.pdf) ([project website](http://danijar.com/cwvae)) using [JAX](https://github.com/google/jax) and [Flax](https://github.com/google/flax), ported from the [official TensorFlow implementation](https://github.com/vaibhavsaxena11/cwvae).\n\nRunning on a single TPU v3, training is **10x faster** than reported in the paper (60h -\u003e 6h on `minerl`).\n\n## Method\n\n\u003cimg src=\"https://danijar.com/asset/cwvae/header.gif\"\u003e\n\nClockwork VAEs are deep generative models that learn long-term dependencies in video by leveraging hierarchies of representations that progress at different clock speeds. In contrast to prior video prediction methods, which typically focus on predicting sharp but short sequences into the future, Clockwork VAEs can accurately predict high-level content, such as object positions and identities, for 1000 frames.\n\nClockwork VAEs build upon the [Recurrent State Space Model (RSSM)](https://arxiv.org/pdf/1811.04551.pdf), so each state contains a deterministic component for long-term memory and a stochastic component for sampling diverse plausible futures. Clockwork VAEs are trained end-to-end to optimize the evidence lower bound (ELBO), which consists of a reconstruction term for each image and a KL regularizer for each stochastic variable in the model.\n\n## Instructions\n\nThis repository contains the code for training the Clockwork VAE model on the datasets `minerl`, `mazes`, and `mmnist`.\n\nThe datasets are downloaded automatically into the directory given by `--datadir`.\n\n```sh\npython3 train.py --logdir /path/to/logdir --datadir /path/to/datasets --config configs/\u003cdataset\u003e.yml\n```\n\nThe evaluation script writes open-loop video predictions in both PNG and NPZ format, as well as plots of PSNR and SSIM, to the data directory.\n\n```sh\npython3 eval.py --logdir /path/to/logdir\n```\n\n## Known differences from the original\n\n- Flax's default kernel initializer, layer precision and GRU implementation (avoiding redundant biases) are used.\n- For some configuration parameters, only the defaults are implemented.\n- Training metrics and videos are logged with `wandb`.\n- The base configuration is in `config.py`.\n\nAdded features:\n\n- This implementation runs on TPU out of the box.\n- Apart from the config file, configuration can be done via the command line and `wandb`.\n- Matching the `seed` of a previous run will exactly repeat it.\n\n## Things to watch out for\n\nReplication of paper results for the `mazes` dataset has not been confirmed yet.\n\nComputing evaluation metrics during training is a memory bottleneck because of the large `eval_seq_len`. If you run out of device memory, consider lowering it during training, for example to 100. Remember to pass the original value to `eval.py` to get unchanged results.\n\n## Acknowledgements\n\nThanks to [Vaibhav Saxena](https://github.com/vaibhavsaxena11) and [Danijar Hafner](https://danijar.com) for helpful discussions and to [Jamie Townsend](https://github.com/j-towns) for reviewing code.","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjuliuskunze%2Fcwvae-jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fjuliuskunze%2Fcwvae-jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjuliuskunze%2Fcwvae-jax/lists"}