{"id":25509191,"url":"https://github.com/tonyduan/mixture-density-network","last_synced_at":"2025-04-10T14:04:54.190Z","repository":{"id":43291059,"uuid":"174052059","full_name":"tonyduan/mixture-density-network","owner":"tonyduan","description":"Mixture density network implemented in PyTorch.","archived":false,"fork":false,"pushed_at":"2023-05-18T06:29:12.000Z","size":57,"stargazers_count":141,"open_issues_count":2,"forks_count":22,"subscribers_count":2,"default_branch":"master","last_synced_at":"2025-03-24T12:47:34.502Z","etag":null,"topics":["machine-learning","mixture-density-networks","pytorch"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/tonyduan.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2019-03-06T01:59:39.000Z","updated_at":"2025-03-21T21:21:21.000Z","dependencies_parsed_at":"2023-02-17T03:15:29.284Z","dependency_job_id":null,"html_url":"https://github.com/tonyduan/mixture-density-network","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/tonyduan%2Fmixture-density-network","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/tonyduan%2Fmixture-density-network/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/tonyduan%2Fmixture-density-network/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/tonyduan%2Fmixture-density-network/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/tonyduan","download_url":"https://codel
oad.github.com/tonyduan/mixture-density-network/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248231602,"owners_count":21069360,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["machine-learning","mixture-density-networks","pytorch"],"created_at":"2025-02-19T08:54:13.930Z","updated_at":"2025-04-10T14:04:54.154Z","avatar_url":"https://github.com/tonyduan.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"### Mixture Density Network\n\nLast update: December 2022.\n\n---\n\nLightweight implementation of a mixture density network [1] in PyTorch.\n\n#### Setup\n\nSuppose we want to regress response $\\mathbf{y} \\in \\mathbb{R}^{d}$ using covariates $\\mathbf{x} \\in \\mathbb{R}^n$.\n\nWe model the conditional distribution as a mixture of Gaussians\n```math\np_\\theta(\\mathbf{y}|\\mathbf{x}) = \\sum_{k=1}^K \\pi_k N(\\boldsymbol\\mu^{(k)}, {\\boldsymbol\\Sigma}^{(k)}),\n```\nwhere the mixture distribution parameters are output by a neural network dependent on $\\mathbf{x}$.\n```math\n\\begin{align*}\n( \\boldsymbol\\pi \u0026 \\in\\Delta^{K-1} \u0026 \\boldsymbol\\mu^{(k)}\u0026\\in\\mathbb{R}^{d} \u0026\\boldsymbol\\Sigma^{(k)}\u0026\\in \\mathrm{S}_+^d) = f_\\theta(\\mathbf{x})\n\\end{align*}\n```\nThe training objective is to maximize log-likelihood. 
The objective is clearly non-convex.\n```math\n\\begin{align*}\n\\log p_\\theta(\\mathbf{y}|\\mathbf{x})\n\u0026 \\propto\\log \\sum_{k}\\left(\\pi_k\\exp\\left(-\\frac{1}{2}\\left(\\mathbf{y}-\\boldsymbol\\mu^{(k)}\\right)^\\top {\\boldsymbol\\Sigma^{(k)}}^{-1}\\left(\\mathbf{y}-\\boldsymbol\\mu^{(k)}\\right) -\\frac{1}{2}\\log\\det \\boldsymbol\\Sigma^{(k)}\\right)\\right)\\\\\n\u0026 = \\mathrm{logsumexp}_k\\left(\\log\\pi_k - \\frac{1}{2}\\left(\\mathbf{y}-\\boldsymbol\\mu^{(k)}\\right)^\\top {\\boldsymbol\\Sigma^{(k)}}^{-1}\\left(\\mathbf{y}-\\boldsymbol\\mu^{(k)}\\right) -\\frac{1}{2}\\log\\det \\boldsymbol\\Sigma^{(k)}\\right)\\\\\n\\end{align*}\n```\nImportantly, we need to use `torch.log_softmax(...)` to compute the log mixture weights $\\log \\boldsymbol\\pi$ from the network's logits for numerical stability.\n\n#### Noise Model\n\nThere are several choices we can make to constrain the noise model $\\boldsymbol\\Sigma^{(k)}$.\n\n1. No assumptions, $\\boldsymbol\\Sigma^{(k)} \\in \\mathrm{S}_+^d$.\n2. Fully factored, let $\\boldsymbol\\Sigma^{(k)} = \\mathrm{diag}({\\boldsymbol\\sigma^{(k)}}^{2}), {\\boldsymbol\\sigma^{(k)}}^{2}\\in\\mathbb{R}_+^d$ where the noise level for each dimension is predicted separately.\n3. Isotropic, let $\\boldsymbol\\Sigma^{(k)} = {\\sigma^{(k)}}^{2}\\mathbf{I}, {\\sigma^{(k)}}^{2}\\in\\mathbb{R}_+$ which assumes the same noise level for each dimension over $d$.\n4. Isotropic across clusters, let $\\boldsymbol\\Sigma^{(k)} = \\sigma^2\\mathbf{I}, \\sigma^2\\in\\mathbb{R}_+$ which assumes the same noise level for each dimension over $d$ *and* cluster.\n5. 
Fixed isotropic, same as above but do not learn $\\sigma^2$.\n\nThese correspond to the following objectives.\n```math\n\\begin{align*}\n\\log p_\\theta(\\mathbf{y}|\\mathbf{x}) \u0026 = \\mathrm{logsumexp}_k\\left(\\log\\pi_k - \\frac{1}{2}\\left(\\mathbf{y}-\\boldsymbol\\mu^{(k)}\\right)^\\top {\\boldsymbol\\Sigma^{(k)}}^{-1}\\left(\\mathbf{y}-\\boldsymbol\\mu^{(k)}\\right) -\\frac{1}{2}\\log\\det \\boldsymbol\\Sigma^{(k)}\\right)  \\tag{1}\\\\\n\u0026 = \\mathrm{logsumexp}_k \\left(\\log\\pi_k - \\frac{1}{2}\\left\\|\\frac{\\mathbf{y}-\\boldsymbol\\mu^{(k)}}{\\boldsymbol\\sigma^{(k)}}\\right\\|^2-\\sum_{i=1}^{d}\\log\\sigma_i^{(k)}\\right) \\tag{2}\\\\\n\u0026 = \\mathrm{logsumexp}_k \\left(\\log\\pi_k - \\frac{1}{2}\\left\\|\\frac{\\mathbf{y}-\\boldsymbol\\mu^{(k)}}{\\sigma^{(k)}}\\right\\|^2-d\\log(\\sigma^{(k)})\\right) \\tag{3}\\\\\n\u0026 = \\mathrm{logsumexp}_k \\left(\\log\\pi_k - \\frac{1}{2}\\left\\|\\frac{\\mathbf{y}-\\boldsymbol\\mu^{(k)}}{\\sigma}\\right\\|^2-d\\log(\\sigma)\\right) \\tag{4}\\\\\n\u0026 = \\mathrm{logsumexp}_k \\left(\\log\\pi_k - \\frac{1}{2}\\left\\|\\frac{\\mathbf{y}-\\boldsymbol\\mu^{(k)}}{\\sigma}\\right\\|^2\\right) \\tag{5}\n\\end{align*}\n```\nIn this repository we implement options (2, 3, 4, 5).\n\n#### Miscellaneous\n\nRecall that the objective is clearly non-convex. For example, one local minimum is to ignore all modes except one and place a single diffuse Gaussian distribution on the marginal outcome (i.e. 
high ${\\sigma}^{(k)}$).\n\nFor this reason it's often preferable to over-parameterize the model and specify `n_components` higher than the true hypothesized number of modes.\n\n#### Usage\n\n```python\nimport torch\nfrom src.blocks import MixtureDensityNetwork\n\nx = torch.randn(5, 1)\ny = torch.randn(5, 1)\n\n# 1D input, 1D output, 3 mixture components\nmodel = MixtureDensityNetwork(1, 1, n_components=3, hidden_dim=50)\npred_parameters = model(x)\n\n# use this to backprop\nloss = model.loss(x, y)\n\n# use this to sample a trained model\nsamples = model.sample(x)\n```\n\nFor further details see the `examples/` folder. Below is a model fit with 3 components in `ex_1d.py`.\n\n![ex_model](examples/ex_1d.png \"Example model output\")\n\n#### References\n\n[1] Bishop, C. M. Mixture density networks. (1994).\n\n[2] Ha, D. \u0026 Schmidhuber, J. Recurrent World Models Facilitate Policy Evolution. in *Advances in Neural Information Processing Systems 31* (eds. Bengio, S. et al.) 2450–2462 (Curran Associates, Inc., 2018).\n\n#### License\n\nThis code is available under the MIT License.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ftonyduan%2Fmixture-density-network","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Ftonyduan%2Fmixture-density-network","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ftonyduan%2Fmixture-density-network/lists"}