{"id":15400730,"url":"https://github.com/dbraun/dac-jax","last_synced_at":"2025-04-15T22:30:37.650Z","repository":{"id":240536329,"uuid":"754879299","full_name":"DBraun/DAC-JAX","owner":"DBraun","description":"JAX Implementations of Descript Audio Codec and EnCodec","archived":false,"fork":false,"pushed_at":"2025-03-05T03:47:37.000Z","size":270,"stargazers_count":24,"open_issues_count":0,"forks_count":2,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-03-29T03:12:13.319Z","etag":null,"topics":["audio","audio-codec","audio-compression","jax","machine-learning"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/DBraun.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-02-08T23:46:27.000Z","updated_at":"2025-03-25T21:06:40.000Z","dependencies_parsed_at":"2024-05-19T16:30:15.190Z","dependency_job_id":"b7bd33f8-5633-439a-ab35-589824e42ae8","html_url":"https://github.com/DBraun/DAC-JAX","commit_stats":{"total_commits":27,"total_committers":1,"mean_commits":27.0,"dds":0.0,"last_synced_commit":"1d0c8b035ba124f9e64da5dc54906fe7ceb47852"},"previous_names":["dbraun/dac-jax"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DBraun%2FDAC-JAX","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DBraun%2FDAC-JAX/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/DBraun%2FDAC-JAX/releases","manifests_url":"https://repos.ecosyste.ms/api/v
1/hosts/GitHub/repositories/DBraun%2FDAC-JAX/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/DBraun","download_url":"https://codeload.github.com/DBraun/DAC-JAX/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":249165833,"owners_count":21223330,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["audio","audio-codec","audio-compression","jax","machine-learning"],"created_at":"2024-10-01T15:54:48.069Z","updated_at":"2025-04-15T22:30:37.304Z","avatar_url":"https://github.com/DBraun.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# DAC-JAX and EnCodec-JAX\n\nThis repository holds **unofficial** JAX implementations of Descript's DAC and Meta's EnCodec.\nWe are not affiliated with Descript or Meta.\n\nYou can read the DAC-JAX paper [here](https://arxiv.org/abs/2405.11554).\n\n## Background\n\nIn 2022, Meta published \"[High Fidelity Neural Audio Compression](https://arxiv.org/abs/2210.13438)\".\nThey eventually open-sourced the code inside [AudioCraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/ENCODEC.md).\n\nIn 2023, Descript published a related work \"[High-Fidelity Audio Compression with Improved RVQGAN](https://arxiv.org/abs/2306.06546)\"\nand released their code under the name [DAC](https://github.com/descriptinc/descript-audio-codec/) (Descript Audio Codec).\n\nBoth EnCodec and DAC are neural audio codecs which use residual vector quantization inside a fully convolutional\nencoder-decoder architecture.\n\n## Usage\n\n### 
Installation\n\n1. Upgrade `pip` and `setuptools`:\n    ```bash\n    pip install --upgrade pip setuptools\n    ```\n\n2. Install the **CPU** version of [PyTorch](https://pytorch.org/).\n   We strongly suggest the CPU version because trying to install a GPU version can conflict with JAX's CUDA-related installation.\n   PyTorch is required because it's used to load pretrained model weights.\n\n3. Install [JAX](https://jax.readthedocs.io/en/latest/installation.html) (with GPU support).\n\n4. Install DAC-JAX with one of the following:\n\n    \u003c!-- ```\n    python -m pip install dac-jax\n    ```\n    OR --\u003e\n    \n    ```bash\n    pip install git+https://github.com/DBraun/DAC-JAX\n    ```\n    \n    Or, from a local clone of the repository:\n    \n    ```bash\n    python -m pip install .\n    ```\n    \n    Or, if you intend to contribute, clone and do an [editable install](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs):\n    ```bash\n    python -m pip install -e \".[dev]\"\n    ```\n\n### Weights\nThe original Descript repository releases model weights under the MIT license. These weights are for models that natively support 16 kHz, 24 kHz, and 44.1 kHz sampling rates. Our scripts download these PyTorch weights and load them into JAX.\nWeights are automatically downloaded when you first run an `encode` or `decode` command. You can download them in advance with one of the following commands:\n```bash\npython -m dac_jax download_model # downloads the default 44 kHz variant\npython -m dac_jax download_model --model_type 44khz --model_bitrate 16kbps # downloads the 44 kHz 16 kbps variant\npython -m dac_jax download_model --model_type 44khz # downloads the 44 kHz variant\npython -m dac_jax download_model --model_type 24khz # downloads the 24 kHz variant\npython -m dac_jax download_model --model_type 16khz # downloads the 16 kHz variant\n```\n\nEnCodec weights can be downloaded similarly. 
This will download the 32 kHz EnCodec used in [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md).\n```bash\npython -m dac_jax download_encodec\n```\n\nFor both DAC and EnCodec, the default download location is `~/.cache/dac_jax`. You can change the location by setting the environment variable `DAC_JAX_CACHE` to an **absolute path**. For example, on macOS/Linux:\n```bash\nexport DAC_JAX_CACHE=/Users/admin/my-project/dac_jax_models\n```\n\nIf you do this, make sure `DAC_JAX_CACHE` is still set whenever you call the `load_model` function.\n\n### Compress audio\n```\npython -m dac_jax encode /path/to/input --output /path/to/output/codes\n```\n\nThis command will create `.dac` files with the same name as the input files.\nIt will also preserve the directory structure relative to the input root and\nre-create it in the output directory. Please use `python -m dac_jax encode --help`\nfor more options.\n\n### Reconstruct audio from compressed codes\n```\npython -m dac_jax decode /path/to/output/codes --output /path/to/reconstructed_input\n```\n\nThis command will create `.wav` files with the same name as the input files.\nIt will also preserve the directory structure relative to the input root and\nre-create it in the output directory. 
Please use `python -m dac_jax decode --help`\nfor more options.\n\n### Programmatic usage (DAC and EnCodec)\n\nHere we use `jax.jit` for optimized encoding and decoding.\nThis does not do sample-rate conversion or volume normalization in the encoder or decoder.\n\n```python\nfrom functools import partial\n\nimport jax\nfrom jax import numpy as jnp\nimport librosa\n\nimport dac_jax\n\nmodel, variables = dac_jax.load_model(model_type=\"44khz\")\n\n# If you want to use pretrained 32 kHz EnCodec from Meta's MusicGen, use this:\n# model, variables = dac_jax.load_encodec_model()\n\n@jax.jit\ndef encode_to_codes(x: jnp.ndarray):\n    codes, scale = model.apply(\n        variables,\n        x,\n        method=\"encode\",\n    )\n    return codes, scale\n\n@partial(jax.jit, static_argnums=(1, 2))\ndef decode_from_codes(codes: jnp.ndarray, scale, length: int = None):\n    recons = model.apply(\n        variables,\n        codes,\n        scale,\n        length,\n        method=\"decode\",\n    )\n    return recons\n\n# Load a mono audio file with the correct sample rate\nsignal, sample_rate = librosa.load('input.wav', sr=model.sample_rate, mono=True, duration=.5)\n\nsignal = jnp.array(signal, dtype=jnp.float32)\nwhile signal.ndim \u003c 3:\n    signal = jnp.expand_dims(signal, axis=0)\n\noriginal_length = signal.shape[-1]\n\ncodes, scale = encode_to_codes(signal)\nassert codes.shape[1] == model.num_codebooks\n\nrecons = decode_from_codes(codes, scale, original_length)\n```\n\n### DAC with Binding\n\nHere we use DAC-JAX as a \"[bound](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html#bind)\" module, freeing us from repeatedly passing variables as an argument and using `.apply`. 
Note that bound modules are not meant to be used in fine-tuning.\n\n```python\nimport dac_jax\nfrom dac_jax import DACFile\n\nfrom jax import numpy as jnp\nimport librosa\n\n# Download a model and bind variables to it.\nmodel, variables = dac_jax.load_model(model_type=\"44khz\")\nmodel = model.bind(variables)\n\n# Load a mono audio file\nsignal, sample_rate = librosa.load('input.wav', sr=44100, mono=True, duration=.5)\n\nsignal = jnp.array(signal, dtype=jnp.float32)\nwhile signal.ndim \u003c 3:\n    signal = jnp.expand_dims(signal, axis=0)\n\n# Encode audio signal as one long file (may run out of GPU memory on long files).\n# This performs resampling to the codec's sample rate and volume normalization.\ndac_file = model.encode_to_dac(signal, sample_rate)\n\n# Save to a file\ndac_file.save(\"dac_file_001.dac\")\n\n# Load a file\ndac_file = DACFile.load(\"dac_file_001.dac\")\n\n# Decode audio signal. Since we're passing a dac_file, this undoes the \n# previous sample rate conversion and volume normalization.\ny = model.decode(dac_file)\n\n# Calculate mean-square error of reconstruction in time-domain\nmse = jnp.square(y-signal).mean()\n```\n\n### DAC compression with constant GPU memory regardless of input length:\n\n```python\nimport dac_jax\n\nimport jax\nimport jax.numpy as jnp\nimport librosa\n\n# Download a model and set padding to False because we will use the chunk functions.\nmodel, variables = dac_jax.load_model(model_type=\"44khz\", padding=False)\n\n# Load a mono audio file at any sample rate\nsignal, sample_rate = librosa.load('input.wav', sr=None, mono=True)\n\nsignal = jnp.array(signal, dtype=jnp.float32)\nwhile signal.ndim \u003c 3:\n    # signal will eventually be shaped [B, C, T]\n    signal = jnp.expand_dims(signal, axis=0)\n\n# Jit-compile these functions because they're used inside a loop over chunks.\n@jax.jit\ndef compress_chunk(x):\n    return model.apply(variables, x, method='compress_chunk')\n\n@jax.jit\ndef decompress_chunk(c):\n    return 
model.apply(variables, c, method='decompress_chunk')\n\nwin_duration = 0.5  # Adjust based on your GPU's memory size\ndac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)\n\n# Save and load to and from disk\ndac_file.save(\"compressed.dac\")\ndac_file = dac_jax.DACFile.load(\"compressed.dac\")\n\n# Decompress it back to audio\ny = model.decompress(decompress_chunk, dac_file)\n```\n\n## DAC Training\nThe baseline model configuration can be trained using the following command.\n\n```bash\npython scripts/train.py --args.load conf/final/44khz.yml --train.ckpt_dir=\"/tmp/dac_jax_runs\"\n```\n\nFrom the root directory, monitor with TensorBoard (`runs` will appear next to `scripts`):\n```bash\ntensorboard --logdir=\"/tmp/dac_jax_runs\"\n```\n\n## Testing\n\n```\npython -m pytest tests\n```\n\n## Limitations\n\nPull requests, especially ones which address any of the limitations below, are welcome.\n\n* We implement the \"chunked\" `compress`/`decompress` methods from the PyTorch repository, although this technique has some problems outlined [here](https://github.com/descriptinc/descript-audio-codec/issues/39).\n* We have not run all evaluation scripts in the `scripts` directory. For some of them, it makes sense to just keep using PyTorch instead of JAX.\n* The model architecture code (`model/dac.py`) has many static methods to help with finding DAC's `delay` and `output_length`. Please help us refactor this so that the code is less duplicated and less prone to typos.\n* In `audio_utils.py` we use [DM_AUX's](https://github.com/google-deepmind/dm_aux) STFT function instead of `jax.scipy.signal.stft`. 
We believe this is faster but requires more memory.\n* The source code of DAC-JAX has some `todo:` markings which indicate (mostly minor) improvements we'd like to have.\n* We don't have a Docker image yet like the original [DAC repository](https://github.com/descriptinc/descript-audio-codec) does.\n* Please check the limitations of [argbind](https://github.com/pseeth/argbind?tab=readme-ov-file#limitations-and-known-issues).\n* We don't provide a training script for EnCodec.\n\n## Citation\n\nIf you use this repository in your work, please cite EnCodec:\n```\n@article{defossez2022high,\n  title={High fidelity neural audio compression},\n  author={D{\\'e}fossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},\n  journal={arXiv preprint arXiv:2210.13438},\n  year={2022}\n}\n```\n\nDAC:\n\n```\n@article{kumar2024high,\n  title={High-fidelity audio compression with improved rvqgan},\n  author={Kumar, Rithesh and Seetharaman, Prem and Luebs, Alejandro and Kumar, Ishaan and Kumar, Kundan},\n  journal={Advances in Neural Information Processing Systems},\n  volume={36},\n  year={2024}\n}\n```\n\nand DAC-JAX:\n\n```\n@misc{braun2024dacjax,\n  title={{DAC-JAX}: A {JAX} Implementation of the Descript Audio Codec},\n  author={David Braun},\n  year={2024},\n  eprint={2405.11554},\n  archivePrefix={arXiv},\n  primaryClass={cs.SD}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdbraun%2Fdac-jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdbraun%2Fdac-jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdbraun%2Fdac-jax/lists"}