https://github.com/scikit-learn-contrib/bde
Bayesian Deep Ensembles via MILE: easy to use, scikit-learn compatible and fast (JAX powered)
- Host: GitHub
- URL: https://github.com/scikit-learn-contrib/bde
- Owner: scikit-learn-contrib
- License: bsd-3-clause
- Created: 2025-08-20T13:46:01.000Z (6 months ago)
- Default Branch: main
- Last Pushed: 2026-02-02T16:30:50.000Z (23 days ago)
- Last Synced: 2026-02-02T21:17:15.356Z (22 days ago)
- Topics: jax, machine-learning, mcmc, sampling-methods, scikit-learn, uncertainty-quantification
- Language: Python
- Homepage: https://contrib.scikit-learn.org/bde/
- Size: 17.7 MB
- Stars: 38
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Bayesian Deep Ensembles for scikit-learn 
[deploy-gh-pages workflow](https://github.com/scikit-learn-contrib/bde/actions/workflows/deploy-gh-pages.yml) | [python-app workflow](https://github.com/scikit-learn-contrib/bde/actions/workflows/python-app.yml) | [License](LICENSE)
👉 **[Start Here: Complete Online Documentation](https://scikit-learn-contrib.github.io/bde/)**
Introduction
------------
**bde** is a user-friendly, scikit-learn-compatible implementation of Bayesian Deep
Ensembles with a particular focus on tabular data. Its estimators plug directly into
scikit-learn pipelines, while JAX provides accelerator-backed training, sampling, and
uncertainty quantification.
In particular, **bde** implements **Microcanonical Langevin Ensembles (MILE)** as
introduced in [*Microcanonical Langevin Ensembles: Advancing the Sampling of Bayesian Neural Networks* (ICLR 2025)](https://arxiv.org/abs/2502.06335).
*Figure: conceptual overview of MILE.*
**Scope:** The package currently supports full-batch MILE for fully connected
feedforward networks, covering classification and regression on tabular data.
The method itself also applies to other architectures and data modalities, but
these are not yet supported by this implementation.
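To make the idea concrete, here is a toy NumPy sketch of the two stages that the workflow below also follows: first optimize several ensemble members, then start a sampler from each optimized member and keep every `n_thinning`-th draw. It is not the **bde** implementation. Assumptions: the model is a Bayesian linear regression (so the full-batch log-posterior gradient is analytic) rather than a neural network, stage 1 uses plain gradient ascent rather than the package's optimizer, and stage 2 swaps in unadjusted Langevin dynamics (ULA) for the microcanonical Langevin Monte Carlo sampler that MILE actually uses. `toy_mile` and `log_post_grad` are hypothetical names.

```python
import numpy as np


def log_post_grad(w, X, y, noise_sd=0.5, prior_sd=1.0):
    """Full-batch gradient of the log posterior of a Bayesian linear model:
    y ~ N(X @ w, noise_sd**2), w ~ N(0, prior_sd**2 * I)."""
    resid = y - X @ w
    return X.T @ resid / noise_sd**2 - w / prior_sd**2


def toy_mile(X, y, n_members=4, opt_steps=500, lr=1e-3,
             n_samples=400, n_thinning=2, step=1e-3, seed=0):
    """Toy two-stage sampler: optimize each member, then sample from its optimum.

    Returns an array of shape (n_members, n_samples // n_thinning, n_features).
    """
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    samples = []
    for _member in range(n_members):
        # Stage 1: gradient ascent from a random init towards a posterior mode.
        w = rng.normal(size=d)
        for _ in range(opt_steps):
            w = w + lr * log_post_grad(w, X, y)
        # Stage 2: unadjusted Langevin dynamics started at the optimized member.
        # ULA is a stand-in here; MILE itself uses microcanonical Langevin MC.
        kept = []
        for t in range(n_samples):
            noise = np.sqrt(2 * step) * rng.normal(size=d)
            w = w + step * log_post_grad(w, X, y) + noise
            if t % n_thinning == 0:  # thinning: keep every n_thinning-th draw
                kept.append(w.copy())
        samples.append(kept)
    return np.asarray(samples)
```

Because the toy posterior is Gaussian, the pooled draws can be checked against the closed-form posterior mean; with a real network the members typically land in different modes, which is what makes the ensemble initialization useful.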
Installation
------------
To install the latest release from PyPI, run:
```
pip install sklearn-contrib-bde
```
To install the latest development version from GitHub, run:
```
pip install git+https://github.com/scikit-learn-contrib/bde.git
```
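To check the installation without importing the package (and thereby initializing JAX), you can query the installed distribution and confirm that the `bde` package can be located. The distribution name `sklearn-contrib-bde` comes from the `pip` command above and the import name `bde` from the examples below; `bde_install_status` is a hypothetical helper that uses only the standard library.

```python
from importlib import metadata, util


def bde_install_status(dist_name="sklearn-contrib-bde", module_name="bde"):
    """Return (installed version or None, whether the module can be located)."""
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    # find_spec locates a top-level package without executing its code.
    importable = util.find_spec(module_name) is not None
    return version, importable


print(bde_install_status())
```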
Developer environment
---------------------
We recommend using [pixi](https://pixi.prefix.dev/latest/) to create a
deterministic development environment:
```
pixi install
# Then you can directly run examples like so:
pixi run python -m examples.example
```
Pixi ensures the correct JAX, CUDA (when needed), and scikit-learn versions are
selected automatically. See `pixi.lock` for channel and platform details.
Example Usage
-------------
Minimal runnable scripts live in `examples/`; the snippets below cover the most
common regression and classification workflows. When running code outside those
scripts, set the XLA device count so that JAX allocates enough host devices. This
must happen before JAX is imported:
```
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
```
Adjust the value to the number of devices you plan to use. The flag splits the host
CPU into that many virtual devices; it does not change how many GPUs JAX sees.
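The same setting can be applied from Python, as the examples below do with `os.environ`. The sketch below derives the value from `os.cpu_count()` and replaces any device-count flag already present in `XLA_FLAGS` instead of appending a duplicate. Using one virtual device per CPU core is an assumption, not a documented recommendation, and `set_host_device_count` is a hypothetical helper.

```python
import os
import re


def set_host_device_count(n_devices):
    """Request n_devices virtual host devices from XLA.

    Must be called before `import jax`; later changes are ignored by JAX.
    """
    flag = f"--xla_force_host_platform_device_count={n_devices}"
    existing = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier device-count flag but keep all other XLA flags.
    kept = re.sub(r"--xla_force_host_platform_device_count=\S*", "", existing).split()
    os.environ["XLA_FLAGS"] = " ".join(kept + [flag])
    return os.environ["XLA_FLAGS"]


set_host_device_count(os.cpu_count() or 1)
```

On a CPU-only machine, `import jax; jax.device_count()` run afterwards should report the requested number of devices.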
### Regression Example
```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
import jax.numpy as jnp
from sklearn.datasets import fetch_openml
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import train_test_split
from bde import BdeRegressor
from bde.loss import GaussianNLL
data = fetch_openml(name="airfoil_self_noise", as_frame=True) # requires pandas
X = data.data.values
y = data.target.values.reshape(-1, 1)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# Standardize inputs and targets using training-set statistics only
Xmu, Xstd = jnp.mean(X_train, 0), jnp.std(X_train, 0) + 1e-8
Ymu, Ystd = jnp.mean(y_train, 0), jnp.std(y_train, 0) + 1e-8
Xtr = (X_train - Xmu) / Xstd
Xte = (X_test - Xmu) / Xstd
ytr = (y_train - Ymu) / Ystd
yte = (y_test - Ymu) / Ystd
# Build the regressor
regressor = BdeRegressor(
    hidden_layers=[16, 16],
    n_members=8,
    seed=0,
    loss=GaussianNLL(),
    epochs=200,
    validation_split=0.15,
    lr=1e-3,
    weight_decay=1e-4,
    warmup_steps=5000,
    n_samples=2000,
    n_thinning=2,
    patience=10,
)
# Fit the regressor
regressor.fit(x=Xtr, y=ytr)
# Get results from regressor
means, sigmas = regressor.predict(Xte, mean_and_std=True)
mean, intervals = regressor.predict(Xte, credible_intervals=[0.1, 0.9])
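raw = regressor.predict(Xte, raw=True)  # (ensemble members, n_samples/n_thinning, n_test_data, (mu, sigma))
# RMSE of the predictive mean, in standardized target units
rmse = root_mean_squared_error(yte, means)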
```
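The shape comment on `raw` indicates that each draw stores a Gaussian `(mu, sigma)`. To aggregate those draws yourself, one common choice is to treat the predictive distribution as an equally weighted mixture over members and kept samples: its mean is the average of `mu`, and its variance follows the law of total variance (the average of `sigma**2` plus the variance of `mu`). Mixture quantiles have no closed form, so the sketch estimates credible intervals from Monte Carlo draws. This is a minimal NumPy sketch assuming `raw` has shape `(E, T, N, 2)` with a positive scale `sigma` in the last slot; it is not necessarily how `mean_and_std=True` or `credible_intervals=` are computed inside **bde**, and both function names are hypothetical.

```python
import numpy as np


def aggregate_gaussian_samples(raw):
    """Mixture mean and std from draws of shape (E, T, N, 2) holding (mu, sigma)."""
    n_test = raw.shape[2]
    mu = raw[..., 0].reshape(-1, n_test)     # (E*T, N): one row per member and draw
    sigma = raw[..., 1].reshape(-1, n_test)  # (E*T, N)
    mean = mu.mean(axis=0)
    # Law of total variance: average noise variance + spread of the means.
    var = (sigma ** 2).mean(axis=0) + mu.var(axis=0)
    return mean, np.sqrt(var)


def credible_interval(raw, quantiles=(0.1, 0.9), n_draws=200, seed=0):
    """Monte Carlo quantiles of the mixture; returns shape (len(quantiles), N)."""
    rng = np.random.default_rng(seed)
    n_test = raw.shape[2]
    mu = raw[..., 0].reshape(-1, n_test)
    sigma = raw[..., 1].reshape(-1, n_test)
    # Draw n_draws predictive values from every (member, draw) Gaussian.
    draws = mu + sigma * rng.standard_normal((n_draws,) + mu.shape)
    return np.quantile(draws.reshape(-1, n_test), quantiles, axis=0)
```

For example, `mean, std = aggregate_gaussian_samples(np.asarray(raw))`. Because the targets were standardized, `mean * Ystd + Ymu` and `std * Ystd` map the results back to the original units.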
### Classification Example
```python
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from bde import BdeClassifier
from bde.loss import CategoricalCrossEntropy
iris = load_iris()
X = iris.data.astype("float32")
y = iris.target.astype("int32").ravel()
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
# Build the classifier
classifier = BdeClassifier(
    n_members=4,
    hidden_layers=[16, 16],
    seed=0,
    loss=CategoricalCrossEntropy(),
    activation="relu",
    epochs=1000,
    validation_split=0.15,
    lr=1e-3,
    warmup_steps=400,  # very few steps required for this simple dataset
    n_samples=100,
    n_thinning=1,
    patience=10,
)
# Fit the classifier
classifier.fit(x=X_train, y=y_train)
# Get results from classifier
preds = classifier.predict(X_test)
probs = classifier.predict_proba(X_test)
score = classifier.score(X_test, y_test)  # mean accuracy on the held-out test set
raw = classifier.predict(X_test, raw=True) # (ensemble members, n_samples/n_thinning, n_test_data, n_classes)
```
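The classifier's `raw` output has shape `(ensemble members, n_samples/n_thinning, n_test_data, n_classes)`. A standard way to turn such draws into one prediction is Bayesian model averaging: convert each draw to class probabilities, average over members and draws, and take the arg-max; the entropy of the averaged probabilities is one simple per-point uncertainty score. The sketch assumes the last axis holds unnormalized logits (pass `from_logits=False` if it already holds probabilities); that is an assumption about the output, not documented behavior, and `aggregate_class_samples` is a hypothetical helper.

```python
import numpy as np


def aggregate_class_samples(raw, from_logits=True):
    """Average draws of shape (E, T, N, C).

    Returns mean probabilities (N, C), predicted labels (N,), predictive entropy (N,).
    """
    x = raw.reshape(-1, raw.shape[2], raw.shape[3])  # (E*T, N, C)
    if from_logits:
        x = x - x.max(axis=-1, keepdims=True)  # numerically stable softmax
        x = np.exp(x)
        x = x / x.sum(axis=-1, keepdims=True)
    probs = x.mean(axis=0)  # model average over members and draws
    labels = probs.argmax(axis=-1)
    entropy = -(probs * np.log(np.clip(probs, 1e-12, None))).sum(axis=-1)
    return probs, labels, entropy
```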
Workflow
--------
The high-level estimators follow this flow during `fit` and evaluation:
- `BdeRegressor` / `BdeClassifier` (`bde/bde.py`) delegate to the shared `Bde` base class.
- `Bde.fit` validates data, resolves defaults, and calls `_build_bde()` to instantiate `BdeBuilder`.
- `BdeBuilder.fit_members` (`bde/bde_builder.py`) trains each network, handles device padding, and applies early stopping.
- `_build_log_post` constructs the ensemble log-posterior, then `warmup_bde` (`bde/sampler/warmup.py`) adapts step sizes before sampling.
- Sampler utilities (`bde/sampler/*`) draw posterior samples and cache them for downstream prediction.
- User-facing `predict` / `predict_proba` call the private `_evaluate` / `_make_predictor` (`bde/bde_evaluator.py`) to aggregate samples into means, intervals, probabilities, or raw outputs.
```mermaid
flowchart TD
subgraph User
FitCall["Call BdeRegressor/BdeClassifier.fit(X, y)"]
PredCall["Call predict(...)/predict_proba(...)"]
end
subgraph Bde
Validate["validate_fit_data / _prepare_targets"]
Build["_build_bde()"]
Builder["BdeBuilder"]
Train["fit_members(X, y, optimizer, loss)"]
LogPost["_build_log_post(X, y)"]
WarmSampler["_warmup_sampler(logpost)"]
Keys["_generate_rng_keys + _normalize_tuned_parameters"]
Draw["_draw_samples(...) via MileWrapper.sample_batched"]
Cache["positions_eT_ stored in estimator"]
Eval["_evaluate(... flags ...)"]
MakePred["_make_predictor(Xte)"]
end
subgraph Warmup
Warm["warmup_bde()"]
Adapter["custom_mclmc_warmup adapter"]
Adapt["per-member adaptation (pmap/vmap)"]
Results["AdaptationResults: states_e, tuned params"]
end
subgraph Sampling
Wrapper["MileWrapper"]
Batch["sample_batched(...)"]
Posterior["Posterior samples (E x T x ...)"]
end
subgraph Evaluation
Predictor["BdePredictor"]
Outputs["Predictions (mean, std, intervals, probs, raw)"]
end
FitCall --> Validate --> Build --> Builder
Builder --> Train --> LogPost --> WarmSampler --> Keys --> Draw --> Cache
WarmSampler --> Warm --> Adapter --> Adapt --> Results
Draw --> Wrapper --> Batch --> Posterior
Cache --> PredCall --> Eval --> MakePred --> Predictor --> Outputs
Posterior --> Predictor
```
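The workflow notes that `fit_members` handles device padding. One common reason for such padding, sketched here under that assumption, is that `jax.pmap` maps the leading axis onto devices: the stacked per-member parameters are padded to a multiple of the device count, reshaped to `(n_devices, members_per_device, ...)` so that each device can `vmap` over its share, and the padded slots are dropped afterwards. This NumPy sketch is hypothetical (`pad_members_for_devices` and `unpad_members` are illustrative names) and is not taken from `bde/bde_builder.py`.

```python
import numpy as np


def pad_members_for_devices(member_params, n_devices):
    """Pad stacked member arrays (E, ...) to (n_devices, per_device, ...).

    Returns the padded array and the original member count E.
    """
    n_members = member_params.shape[0]
    per_device = -(-n_members // n_devices)  # ceiling division
    n_pad = per_device * n_devices - n_members
    # Fill empty slots with copies of the last member; their results are discarded.
    padding = np.repeat(member_params[-1:], n_pad, axis=0)
    padded = np.concatenate([member_params, padding], axis=0)
    return padded.reshape((n_devices, per_device) + member_params.shape[1:]), n_members


def unpad_members(results, n_members):
    """Flatten (n_devices, per_device, ...) back to (E, ...) and drop padding."""
    flat = results.reshape((-1,) + results.shape[2:])
    return flat[:n_members]
```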
### Datasets included in the package for testing purposes
| Dataset | Source | Task |
|---------|---------|------|
| **Airfoil** | UCI Machine Learning Repository (Dua & Graff, 2017) | Regression |
| **Concrete** | UCI Machine Learning Repository (Yeh, 2006) | Regression |
| **Iris** | Fisher (1936); canonical modern version distributed via scikit-learn | Multiclass classification (setosa, versicolor, virginica) |