{"id":17245508,"url":"https://github.com/yjlolo/pytorch-deep-markov-model","last_synced_at":"2025-04-14T04:09:15.229Z","repository":{"id":55015195,"uuid":"288708170","full_name":"yjlolo/pytorch-deep-markov-model","owner":"yjlolo","description":"PyTorch re-implementation of [Structured Inference Networks for Nonlinear State Space Models, AAAI 17]","archived":false,"fork":false,"pushed_at":"2021-03-06T15:55:18.000Z","size":228,"stargazers_count":24,"open_issues_count":1,"forks_count":1,"subscribers_count":2,"default_branch":"master","last_synced_at":"2025-04-14T04:09:10.486Z","etag":null,"topics":["aaai","markov-model","pytorch-implementation","reimplementation","sequential-data","variational-autoencoders","variational-inference"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/yjlolo.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2020-08-19T11:00:47.000Z","updated_at":"2025-01-03T04:28:46.000Z","dependencies_parsed_at":"2022-08-14T09:10:19.675Z","dependency_job_id":null,"html_url":"https://github.com/yjlolo/pytorch-deep-markov-model","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/yjlolo%2Fpytorch-deep-markov-model","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/yjlolo%2Fpytorch-deep-markov-model/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/yjlolo%2Fpytorch-deep-markov-model/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/yjlolo%2Fpytorch-deep-
markov-model/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/yjlolo","download_url":"https://codeload.github.com/yjlolo/pytorch-deep-markov-model/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248819405,"owners_count":21166477,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["aaai","markov-model","pytorch-implementation","reimplementation","sequential-data","variational-autoencoders","variational-inference"],"created_at":"2024-10-15T06:29:45.501Z","updated_at":"2025-04-14T04:09:15.201Z","avatar_url":"https://github.com/yjlolo.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# pytorch-deep-markov-model\nPyTorch re-implementation of the Deep Markov Model (https://arxiv.org/abs/1609.09869)\n```\n@inproceedings{10.5555/3298483.3298543,\n    author = {Krishnan, Rahul G. and Shalit, Uri and Sontag, David},\n    title = {Structured Inference Networks for Nonlinear State Space Models},\n    year = {2017},\n    publisher = {AAAI Press},\n    booktitle = {Proceedings of the Thirty-First AAAI Conference on Artificial Intelligence},\n    pages = {2101–2109},\n    numpages = {9},\n    location = {San Francisco, California, USA},\n    series = {AAAI'17}\n}\n```\n**Note:** \n1. The calculated metrics in `model/metrics.py` do not match those reported in the paper, most likely because of differences in parameter settings and metric calculations.\n2. 
The current implementation only supports the JSB polyphonic music dataset.\n\n## Under-development\nRefer to the branch `factorial-dmm` for a model described as [Factorial DMM](https://groups.csail.mit.edu/sls/publications/2019/SameerKhurana_ICASSP-2019.pdf).\nThe other branch, `refractor`, aims to improve readability and offer more model options (documentation not updated yet).\n\n## Usage\nTrain the model with the default `config.json`:\n    \n    python train.py -c config.json\n\n\nAdd the `-i` flag to name the experiment, which is saved under `saved/`.\n\n## `config.json`\nThis file specifies parameters and configurations.\nSome key parameters are explained below.\n\n**Careful fine-tuning of the parameters seems necessary to match the reported performance.**\n```javascript\n{\n    \"arch\": {\n        \"type\": \"DeepMarkovModel\",\n        \"args\": {\n            \"input_dim\": 88,\n            \"z_dim\": 100,\n            \"emission_dim\": 100,\n            \"transition_dim\": 200,\n            \"rnn_dim\": 600,\n            \"rnn_type\": \"lstm\",\n            \"rnn_layers\": 1,\n            \"rnn_bidirection\": false,   // condition z_t on both directions of inputs,\n                                        // manually turn off `reverse_rnn_input` if true\n                                        // (this is minor and could be quickly fixed)\n            \"use_embedding\": true,      // use extra linear layer before RNN\n            \"orthogonal_init\": true,    // orthogonal initialization for RNN\n            \"gated_transition\": true,   // use linear/non-linear gated transition\n            \"train_init\": false,        // make z0 trainable\n            \"mean_field\": false,        // use mean-field posterior q(z_t | x)\n            \"reverse_rnn_input\": true,  // condition z_t on future inputs\n            \"sample\": true              // sample during reparameterization\n        }\n    },\n    \"optimizer\": {\n        \"type\": \"Adam\",\n        \"args\":{\n            \"lr\": 0.0008,          
     // default value from the authors' source code\n            \"weight_decay\": 0.0,        // debugging indicated that 1.0 prevents training\n            \"amsgrad\": true,\n            \"betas\": [0.9, 0.999]\n        }\n    },\n    \"trainer\": {\n        \"epochs\": 3000,\n        \"overfit_single_batch\": false,  // overfit a single batch for debugging\n\n        \"save_dir\": \"saved/\",\n        \"save_period\": 500,\n        \"verbosity\": 2,\n        \n        \"monitor\": \"min val_loss\",\n        \"early_stop\": 100,\n\n        \"tensorboard\": true,\n\n        \"min_anneal_factor\": 0.0,\n        \"anneal_update\": 5000\n    }\n}\n```\n\n## References\n0. Project template taken from [pytorch-template](https://github.com/victoresque/pytorch-template)\n1. The original [source code](https://github.com/clinicalml/structuredinference/tree/master/expt-polyphonic-fast) in Theano\n2. PyTorch implementation in the [Pyro](https://github.com/pyro-ppl/pyro/tree/dev/examples/dmm) framework\n3. Another PyTorch implementation by [@guxd](https://github.com/guxd/deepHMM)\n\n## To-Do\n- [ ] fine-tune to match the performance reported in the paper\n- [ ] correct (if any) errors in metric calculation, `model/metric.py`\n- [ ] optimize importance sampling\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fyjlolo%2Fpytorch-deep-markov-model","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fyjlolo%2Fpytorch-deep-markov-model","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fyjlolo%2Fpytorch-deep-markov-model/lists"}