https://github.com/ermongroup/bcd-nets
Code for `BCD Nets: Scalable Variational Approaches for Bayesian Causal Discovery`, NeurIPS 2021
- Host: GitHub
- URL: https://github.com/ermongroup/bcd-nets
- Owner: ermongroup
- Created: 2021-10-26T04:54:49.000Z (almost 4 years ago)
- Default Branch: main
- Last Pushed: 2021-10-26T05:00:22.000Z (almost 4 years ago)
- Last Synced: 2025-05-05T14:21:56.818Z (5 months ago)
- Language: Python
- Size: 292 KB
- Stars: 26
- Watchers: 7
- Forks: 6
- Open Issues: 1
README
This folder contains the code for 'Scalable Variational Approaches for Bayesian Causal Discovery'.
# Installation
To install, use conda with
`conda env create -f environment.yml`.
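After the environment is created, a short import check can show which of the key packages listed in this README are missing. This is a minimal sketch, not part of the repository: `KEY_MODULES` and `missing_modules` are hypothetical names, and the list maps the README's package names to their import names (several differ, e.g. `dm-haiku` imports as `haiku`).

```python
import importlib.util

# Import names for the key packages listed in this README. Several differ from
# their pip/conda names: dm-haiku -> haiku, ott-jax -> ott, cython -> Cython,
# python-Levenshtein -> Levenshtein (sklearn is provided by scikit-learn).
KEY_MODULES = [
    "jax", "jaxlib", "ott", "cdt", "sklearn", "matplotlib", "optax",
    "haiku", "tensorflow_probability", "torch", "wandb", "Cython",
    "fuzzywuzzy", "Levenshtein", "sumu", "lingam",
]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    missing = []
    for name in names:
        # For a top-level name, find_spec returns None when the module is not
        # installed, and it locates the module without executing it.
        if importlib.util.find_spec(name) is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    gone = missing_modules(KEY_MODULES)
    print("all key packages found" if not gone else "missing: " + ", ".join(gone))
```

Using `find_spec` rather than actually importing each package keeps the check fast, since heavy libraries such as `torch` or `tensorflow_probability` are not loaded.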
If this fails for some reason, the key packages are
`jax jaxlib ott-jax cdt sklearn matplotlib optax dm-haiku tensorflow_probability torch wandb cython fuzzywuzzy python-Levenshtein sumu lingam`

You may have to recompile the cython module for the Hungarian algorithm by running
`cython -3 mine.pyx` and then
`g++ -shared -pthread -fPIC -fwrapv -O3 -Wall -fno-strict-aliasing $(python3-config --includes) -o mine.so mine.c`
in the `c_modules` directory.

# Running Experiments
Run with the `--use_wandb` flag to write results to a new Weights & Biases project; otherwise, the results are printed to stdout.

In `utils.py` you may need to uncomment line 11 and set it to the path of your `Rscript` binary.
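If you are unsure where `Rscript` is installed, the sketch below looks it up on your `PATH` so you can copy the result into `utils.py`. The helper name `find_rscript` is hypothetical, and what the commented-out line in `utils.py` actually assigns is not shown in this README; setting the Causal Discovery Toolbox option `cdt.SETTINGS.rpath` is only an assumed example.

```python
import shutil
from typing import Optional


def find_rscript() -> Optional[str]:
    """Return the path of the first `Rscript` found on PATH, or None."""
    # shutil.which performs the same PATH lookup a shell would.
    return shutil.which("Rscript")


if __name__ == "__main__":
    path = find_rscript()
    if path is None:
        print("Rscript not found on PATH; install R or locate the binary manually")
    else:
        # Copy this path into utils.py. Assumption: the commented-out line there
        # configures the R binary used by cdt, e.g. cdt.SETTINGS.rpath = path
        print(path)
```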
To run the BCD Nets and GOLEM experiments in Figure 1 for a single random seed, use arguments such as
`python main.py -s 0 --n_data 100 --dim 32 --degree 1 --num_steps 30000 --do_ev_noise --sem_type linear-gauss --batch_size 256 --print_golem_solution`

To run the baselines, run
`python main.py --eval_eid --run_baselines --n_data 100 --dim 32 --sem_type linear-gauss --only_baselines --degree 2 --do_ev_noise --n_baseline_seeds 3`

To run GOLEM, run
`python main.py --eval_eid --print_golem_solution --n_data 100 --dim 32 --sem_type linear-gauss --degree 2 --do_ev_noise --num_steps 10`

To run on the Sachs dataset, include the `--use_sachs` flag.
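The BCD Nets command in this README runs one seed at a time (`-s 0`). To repeat it over several seeds, one option is a small driver script like the sketch below, run from the repository root. It only reuses flags that appear in this README; `BCD_ARGS`, `build_commands`, and `run_all` are hypothetical names, and the script simply launches the documented command sequentially.

```python
import subprocess
import sys

# Flags copied from the single-seed BCD Nets command in this README.
BCD_ARGS = [
    "--n_data", "100", "--dim", "32", "--degree", "1",
    "--num_steps", "30000", "--do_ev_noise", "--sem_type", "linear-gauss",
    "--batch_size", "256", "--print_golem_solution",
]


def build_commands(seeds, extra_args=()):
    """Return one `main.py` argv list per random seed."""
    # sys.executable reuses the current interpreter, e.g. the active conda env.
    return [
        [sys.executable, "main.py", "-s", str(seed), *BCD_ARGS, *extra_args]
        for seed in seeds
    ]


def run_all(commands):
    """Run each command in turn; check=True stops at the first failing run."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Seeds 0, 1, 2; pass extra_args=["--use_wandb"] to log to Weights & Biases.
    cmds = build_commands(range(3))
    for cmd in cmds:
        print(" ".join(cmd))
    # run_all(cmds)  # uncomment to actually launch the runs
```

By default the script only prints the commands, so you can inspect them before launching anything.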