{"id":44433006,"url":"https://github.com/scikit-learn-contrib/bde","last_synced_at":"2026-02-12T13:41:59.454Z","repository":{"id":327284468,"uuid":"1041441656","full_name":"scikit-learn-contrib/bde","owner":"scikit-learn-contrib","description":"Bayesian Deep Ensembles via MILE: easy to use, scikit-learn compatible and fast (JAX powered)","archived":false,"fork":false,"pushed_at":"2026-02-02T16:30:50.000Z","size":18578,"stargazers_count":38,"open_issues_count":0,"forks_count":0,"subscribers_count":0,"default_branch":"main","last_synced_at":"2026-02-02T21:17:15.356Z","etag":null,"topics":["jax","machine-learning","mcmc","sampling-methods","scikit-learn","uncertainty-quantification"],"latest_commit_sha":null,"homepage":"https://contrib.scikit-learn.org/bde/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"bsd-3-clause","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/scikit-learn-contrib.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2025-08-20T13:46:01.000Z","updated_at":"2026-02-02T19:27:54.000Z","dependencies_parsed_at":null,"dependency_job_id":null,"html_url":"https://github.com/scikit-learn-contrib/bde","commit_stats":null,"previous_names":["vyron-arvanitis/bde"],"tags_count":2,"template":false,"template_full_name":null,"purl":"pkg:github/scikit-learn-contrib/bde","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/scikit-learn-contrib%2Fbde","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/scikit-learn-contrib%2Fbde/
tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/scikit-learn-contrib%2Fbde/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/scikit-learn-contrib%2Fbde/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/scikit-learn-contrib","download_url":"https://codeload.github.com/scikit-learn-contrib/bde/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/scikit-learn-contrib%2Fbde/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":29367576,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-02-12T08:51:36.827Z","status":"ssl_error","status_checked_at":"2026-02-12T08:51:26.849Z","response_time":55,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["jax","machine-learning","mcmc","sampling-methods","scikit-learn","uncertainty-quantification"],"created_at":"2026-02-12T13:41:58.747Z","updated_at":"2026-02-12T13:41:59.449Z","avatar_url":"https://github.com/scikit-learn-contrib.png","language":"Python","readme":"﻿# Bayesian Deep Ensembles for scikit-learn \u003ca href=\"https://github.com/scikit-learn-contrib/bde\"\u003e\u003cimg src=\"doc/_static/img/logo.svg\" align=\"right\" height=\"150\" /\u003e\n\n[![Docs 
Status](https://github.com/scikit-learn-contrib/bde/actions/workflows/deploy-gh-pages.yml/badge.svg)](https://github.com/scikit-learn-contrib/bde/actions/workflows/deploy-gh-pages.yml)\n[![Tests](https://github.com/scikit-learn-contrib/bde/actions/workflows/python-app.yml/badge.svg)](https://github.com/scikit-learn-contrib/bde/actions/workflows/python-app.yml)\n![Lifecycle: stable](https://img.shields.io/badge/lifecycle-stable-brightgreen)\n[![License](https://img.shields.io/github/license/scikit-learn-contrib/bde)](LICENSE)\n\n\n👉 **[Start Here: Complete Online Documentation](https://scikit-learn-contrib.github.io/bde/)**\n\n\nIntroduction\n------------\n\n**bde** is a user-friendly, scikit-learn-compatible implementation of Bayesian Deep\nEnsembles with a particular focus on tabular data. It exposes estimators that plug\ninto scikit-learn pipelines while leveraging JAX for accelerator-backed training,\nsampling, and uncertainty quantification.\n\nIn particular, **bde** implements **Microcanonical Langevin Ensembles (MILE)** as\nintroduced in [*Microcanonical Langevin Ensembles: Advancing the Sampling of Bayesian Neural Networks* (ICLR 2025)](https://arxiv.org/abs/2502.06335).\nA conceptual overview of MILE is shown below:\n\n\u003cdiv style=\"width: 60%; margin: auto;\"\u003e\n    \u003cimg src=\"doc/_static/img/flowchart.png\" alt=\"MILE Overview\" style=\"width: 100%;\"\u003e\n\u003c/div\u003e\n\n\n**Scope:** The package currently supports full-batch MILE for fully connected\nfeedforward networks, covering classification and regression on tabular data.\nThe method can, however, also be applied to other architectures and data\nmodalities; these are not yet within the scope of this implementation.\n\nInstallation\n------------\n\nTo install the latest release from PyPI, run:\n\n```\npip install sklearn-contrib-bde\n```\n\nTo install the latest development version from GitHub, run:\n\n```\npip install git+https://github.com/scikit-learn-contrib/bde.git\n```\n\n
Developer environment\n---------------------\n\nWe recommend using [pixi](https://pixi.prefix.dev/latest/) to create a\ndeterministic development environment:\n\n```\npixi install\n\n# Then you can directly run examples like so:\npixi run python -m examples.example\n```\n\nPixi ensures the correct JAX, CUDA (when needed), and scikit-learn versions are\nselected automatically. See `pixi.lock` for channel and platform details.\n\n\n\nExample Usage\n-------------\n\nMinimal runnable scripts live in `examples/`, and the snippets below highlight the\nmost common regression and classification workflows. When running outside those\nscripts, remember to set the XLA device count so that JAX allocates enough host\ndevices (this must be done before importing JAX):\n\n```\nexport XLA_FLAGS=\"--xla_force_host_platform_device_count=\u003cn_devices\u003e\"\n```\n\nAdjust the value to match the number of CPU (or GPU) devices you plan to use.\n\n### Regression Example\n\n```python\nimport os\n\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nimport jax.numpy as jnp\nfrom sklearn.datasets import fetch_openml\nfrom sklearn.metrics import root_mean_squared_error\nfrom sklearn.model_selection import train_test_split\n\nfrom bde import BdeRegressor\nfrom bde.loss import GaussianNLL\n\ndata = fetch_openml(name=\"airfoil_self_noise\", as_frame=True) # requires pandas\n\nX = data.data.values\ny = data.target.values.reshape(-1, 1)\n\nX_train, X_test, y_train, y_test = train_test_split(\n    X, y, test_size=0.2, random_state=42\n)\n\nXmu, Xstd = jnp.mean(X_train, 0), jnp.std(X_train, 0) + 1e-8\nYmu, Ystd = jnp.mean(y_train, 0), jnp.std(y_train, 0) + 1e-8\n\nXtr = (X_train - Xmu) / Xstd\nXte = (X_test - Xmu) / Xstd\nytr = (y_train - Ymu) / Ystd\nyte = (y_test - Ymu) / Ystd\n\n# Build the regressor\nregressor = BdeRegressor(\n    hidden_layers=[16, 16],\n    n_members=8,\n    seed=0,\n    
loss=GaussianNLL(),\n    epochs=200,\n    validation_split=0.15,\n    lr=1e-3,\n    weight_decay=1e-4,\n    warmup_steps=5000,\n    n_samples=2000,\n    n_thinning=2,\n    patience=10,\n)\n\n# Fit the regressor\nregressor.fit(x=Xtr, y=ytr)\n\n# Get results from regressor\nmeans, sigmas = regressor.predict(Xte, mean_and_std=True)\nmean, intervals = regressor.predict(Xte, credible_intervals=[0.1, 0.9])\nraw = regressor.predict(Xte, raw=True) # (ensemble members, n_samples/n_thinning, n_test_data, (mu,sigma))\n```\n\n\n### Classification Example\n\n```python\nimport os\n\nos.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n\nfrom sklearn.datasets import load_iris\nfrom sklearn.model_selection import train_test_split\n\nfrom bde import BdeClassifier\nfrom bde.loss import CategoricalCrossEntropy\n\niris = load_iris()\nX = iris.data.astype(\"float32\")\ny = iris.target.astype(\"int32\").ravel()\n\nX_train, X_test, y_train, y_test = train_test_split(\n    X, y, test_size=0.2, random_state=42\n)\n\n# Build the classifier\nclassifier = BdeClassifier(\n    n_members=4,\n    hidden_layers=[16, 16],\n    seed=0,\n    loss=CategoricalCrossEntropy(),\n    activation=\"relu\",\n    epochs=1000,\n    validation_split=0.15,\n    lr=1e-3,\n    warmup_steps=400,  # very few steps required for this simple dataset\n    n_samples=100,\n    n_thinning=1,\n    patience=10,\n)\n\n# Fit the classifier\nclassifier.fit(x=X_train, y=y_train)\n\n# Get results from classifier\npreds = classifier.predict(X_test)\nprobs = classifier.predict_proba(X_test)\nscore = classifier.score(X_train, y_train)\nraw = classifier.predict(X_test, raw=True) # (ensemble members, n_samples/n_thinning, n_test_data, n_classes)\n```\n\nWorkflow\n--------\n\nThe high-level estimators follow this flow during `fit` and evaluation:\n\n- `BdeRegressor` / `BdeClassifier` (`bde/bde.py`) delegate to the shared `Bde` base class.\n- `Bde.fit` validates data, resolves defaults, and calls `_build_bde()` to 
instantiate `BdeBuilder`.\n- `BdeBuilder.fit_members` (`bde/bde_builder.py`) trains each network, handles device padding, and applies early stopping.\n- `_build_log_post` constructs the ensemble log-posterior, then `warmup_bde` (`bde/sampler/warmup.py`) adapts step sizes before sampling.\n- Sampler utilities (`bde/sampler/*`) draw posterior samples and cache them for downstream prediction.\n- User-facing `predict` / `predict_proba` call the private `_evaluate` / `_make_predictor` (`bde/bde_evaluator.py`) to aggregate samples into means, intervals, probabilities, or raw outputs.\n\n```mermaid\nflowchart TD\n    subgraph User\n        FitCall[\"Call BdeRegressor/BdeClassifier.fit(X, y)\"]\n        PredCall[\"Call predict(...)/predict_proba(...)\"]\n    end\n\n    subgraph Bde\n        Validate[\"validate_fit_data / _prepare_targets\"]\n        Build[\"_build_bde()\"]\n        Builder[\"BdeBuilder\"]\n        Train[\"fit_members(X, y, optimizer, loss)\"]\n        LogPost[\"_build_log_post(X, y)\"]\n        WarmSampler[\"_warmup_sampler(logpost)\"]\n        Keys[\"_generate_rng_keys + _normalize_tuned_parameters\"]\n        Draw[\"_draw_samples(...) via MileWrapper.sample_batched\"]\n        Cache[\"positions_eT_ stored in estimator\"]\n        Eval[\"_evaluate(... 
flags ...)\"]\n        MakePred[\"_make_predictor(Xte)\"]\n    end\n\n    subgraph Warmup\n        Warm[\"warmup_bde()\"]\n        Adapter[\"custom_mclmc_warmup adapter\"]\n        Adapt[\"per-member adaptation (pmap/vmap)\"]\n        Results[\"AdaptationResults: states_e, tuned params\"]\n    end\n\n    subgraph Sampling\n        Wrapper[\"MileWrapper\"]\n        Batch[\"sample_batched(...)\"]\n        Posterior[\"Posterior samples (E x T x ...)\"]\n    end\n\n    subgraph Evaluation\n        Predictor[\"BdePredictor\"]\n        Outputs[\"Predictions (mean, std, intervals, probs, raw)\"]\n    end\n\n    FitCall --\u003e Validate --\u003e Build --\u003e Builder\n    Builder --\u003e Train --\u003e LogPost --\u003e WarmSampler --\u003e Keys --\u003e Draw --\u003e Cache\n    WarmSampler --\u003e Warm --\u003e Adapter --\u003e Adapt --\u003e Results\n    Draw --\u003e Wrapper --\u003e Batch --\u003e Posterior\n    Cache --\u003e PredCall --\u003e Eval --\u003e MakePred --\u003e Predictor --\u003e Outputs\n    Posterior --\u003e Predictor\n```\n\n\n### Datasets included in the package for testing purposes\n\n| Dataset | Source | Task |\n|---------|---------|------|\n| **Airfoil** | UCI Machine Learning Repository (Dua \u0026 Graff, 2017) | Regression |\n| **Concrete** | UCI Machine Learning Repository (Yeh, 2006) | Regression |\n| **Iris** | Fisher (1936); canonical modern version distributed via scikit-learn | Multiclass classification (setosa, versicolor, virginica) |\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fscikit-learn-contrib%2Fbde","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fscikit-learn-contrib%2Fbde","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fscikit-learn-contrib%2Fbde/lists"}