https://github.com/scbirlab/duvida
🧐 Calculating exact and approximate confidence and information metrics for deep learning on general purpose and chemistry tasks.
- Host: GitHub
- URL: https://github.com/scbirlab/duvida
- Owner: scbirlab
- License: mit
- Created: 2025-03-24T18:02:34.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2025-10-15T21:26:44.000Z (8 months ago)
- Last Synced: 2025-10-26T08:23:46.676Z (8 months ago)
- Topics: active-learning, ai, chemistry, confidence-estimation, confidence-intervals, hessian, hessian-vector-product, jax, torch
- Language: Python
- Size: 21.6 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# 🧐 duvida



**duvida** (Portuguese for _doubt_) is a suite of Python tools for calculating confidence and information metrics
for deep learning. It provides lower-level function transforms for exact and approximate Hessian diagonals
in JAX and PyTorch.
- [Installation](#installation)
- [Python API](#python-api)
- [Issues, problems, suggestions](#issues-problems-suggestions)
- [Documentation](#documentation)
## Installation
### The easy way
You can install the precompiled version directly using `pip`. You need to specify the machine learning framework
that you want to use:
```bash
$ pip install "duvida[jax]"               # JAX
# or
$ pip install "duvida[jax_cuda12]"        # JAX, installing CUDA 12 for GPU support
# or
$ pip install "duvida[jax_cuda12_local]"  # JAX, using a locally installed CUDA 12
# or
$ pip install "duvida[torch]"             # PyTorch
```
We have implemented JAX and PyTorch functional transformations for approximate and exact Hessian diagonals,
as well as for doubtscore and information sensitivity. These can be used with JAX- and PyTorch-based frameworks.
### From source
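Clone the repository, `cd` into it, and install it in editable mode (here with the PyTorch extra):
```bash
$ git clone https://github.com/scbirlab/duvida.git
$ cd duvida
$ pip install -e ".[torch]"
```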
## Python API
**duvida** provides functional transforms for JAX and PyTorch that calculate
either exact or approximate Hessian diagonals.
You can check which backend you're using:
```python
>>> from duvida.stateless.config import config
>>> config
Config(backend='jax', precision='double', fallback=True)
```
You can change it with `config.set_backend`:
```python
>>> config.set_backend("torch")
'torch'
>>> config
Config(backend='torch', precision='double', fallback=True)
```
You can then calculate exact Hessian diagonals without computing the
full Hessian matrix. The outputs shown in the examples below are those of the JAX backend (`Array(...)`):
```python
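>>> from duvida.stateless.utils import hessian
>>> from duvida.stateless.hessians import get_approximators
>>> import duvida.stateless.numpy as dnp
>>> exact_diagonal = get_approximators("exact_diagonal")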
>>> f = lambda x: dnp.sum(x ** 3. + x ** 2. + 4.)
>>> a = dnp.array([1., 2.])
>>> exact_diagonal(f)(a) == dnp.diag(hessian(f)(a))
Array([ True, True], dtype=bool)
```
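The idea of getting an exact diagonal without the full matrix can be illustrated with Hessian-vector products (HVPs): each entry satisfies $H_{ii} = e_i^\top H e_i = (H e_i)_i$, so $n$ HVPs recover the whole diagonal while only ever holding vectors of length $n$. In JAX an HVP is commonly obtained by forward-over-reverse differentiation, `jax.jvp(jax.grad(f), (x,), (v,))[1]`. The pure-Python sketch below instead hand-derives the HVP of the `f` used above; the helpers `hvp_f` and `exact_diagonal_from_hvp` are hypothetical, and this is not a description of how duvida's `exact_diagonal` works internally.

```python
# Minimal sketch (not duvida's implementation): exact Hessian diagonal from
# Hessian-vector products (HVPs), one HVP per coordinate, never storing H.
from typing import Callable, List

Vector = List[float]
HVP = Callable[[Vector, Vector], Vector]


def hvp_f(x: Vector, v: Vector) -> Vector:
    """Hand-derived HVP for f(x) = sum(x**3 + x**2 + 4).

    The Hessian of f is diagonal with entries 6*x_i + 2, so (H v)_i = (6*x_i + 2) * v_i.
    """
    return [(6.0 * xi + 2.0) * vi for xi, vi in zip(x, v)]


def exact_diagonal_from_hvp(hvp: HVP, x: Vector) -> Vector:
    """Return diag(H) using H_ii = e_i^T H e_i = (H e_i)_i."""
    n = len(x)
    diagonal = []
    for i in range(n):
        e_i = [1.0 if j == i else 0.0 for j in range(n)]  # i-th standard basis vector
        diagonal.append(hvp(x, e_i)[i])  # keep only the i-th entry of H e_i
    return diagonal


print(exact_diagonal_from_hvp(hvp_f, [1.0, 2.0]))  # [8.0, 14.0]
```

The cost is $n$ HVPs, each a small constant multiple of a gradient evaluation, which is why this route scales to models where the $n^2$ Hessian entries could never be stored.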
Several approximations are also available:
```python
>>> from duvida.stateless.hessians import get_approximators
>>> get_approximators() # Use no arguments to show what's available
('squared_jacobian', 'exact_diagonal', 'bekas', 'rough_finite_difference')
```
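For intuition about what a finite-difference approximation does, here is a generic central-difference sketch, $H_{ii} \approx \left[f(x + h e_i) - 2 f(x) + f(x - h e_i)\right] / h^2$, which needs $2n + 1$ function evaluations and no derivatives at all. It is an assumption that this resembles duvida's `rough_finite_difference`; the helper `finite_difference_diagonal` is hypothetical.

```python
# Generic central-difference sketch of a Hessian diagonal. This is a standard
# textbook scheme; duvida's "rough_finite_difference" may be defined differently.
from typing import Callable, List

Vector = List[float]


def f(x: Vector) -> float:
    """The README's example function: f(x) = sum(x**3 + x**2 + 4)."""
    return sum(xi ** 3 + xi ** 2 + 4.0 for xi in x)


def finite_difference_diagonal(
    func: Callable[[Vector], float], x: Vector, h: float = 1e-3
) -> Vector:
    """Approximate H_ii by (f(x + h e_i) - 2 f(x) + f(x - h e_i)) / h**2."""
    fx = func(x)  # shared by every coordinate
    diagonal = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += h   # step forward along coordinate i
        x_minus[i] -= h  # step backward along coordinate i
        diagonal.append((func(x_plus) - 2.0 * fx + func(x_minus)) / h ** 2)
    return diagonal


print(finite_difference_diagonal(f, [1.0, 2.0]))  # approximately [8.0, 14.0]
```

Because `f` is cubic in each coordinate, the truncation error (proportional to $h^2$ times the fourth derivative) vanishes here and only floating-point rounding remains; for general functions, a smaller `h` reduces truncation error but amplifies rounding error.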
To use one of these approximators, pass its name to `get_approximators` and apply the returned transform:
```python
>>> approx_hessian_diag = get_approximators("bekas")
>>> g = lambda x: dnp.sum(dnp.sum(x) ** 3. + x ** 2. + 4.)
>>> a = dnp.array([1., 2.])
>>> dnp.diag(hessian(g)(a)) # Exact
Array([38., 38.], dtype=float64)
>>> approx_hessian_diag(g, n=1000)(a) # Less accurate when parameters interact
Array([38.52438307, 38.49679655], dtype=float64)
>>> approx_hessian_diag(g, n=1000, seed=1)(a) # Change the seed to alter the outcome
Array([39.07878869, 38.97796601], dtype=float64)
```
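Why the stochastic estimate degrades when parameters interact can be seen from the estimator of Bekas et al., which averages products with random probe vectors $v_k$: $\operatorname{diag}(H) \approx \left(\sum_k v_k \odot H v_k\right) \oslash \left(\sum_k v_k \odot v_k\right)$. With $\pm 1$ (Rademacher) probes, $v_j (H v)_j = H_{jj} + \sum_{k \neq j} H_{jk} v_j v_k$, so the off-diagonal terms average to zero but add noise, and they vanish entirely when $H$ is diagonal. The pure-Python sketch below assumes Rademacher probes and hand-derived HVPs (`hvp_f`, `hvp_g`, and `bekas_diagonal` are hypothetical helpers); duvida may use a different probe distribution, so its numbers will not match the output above.

```python
# Minimal sketch of the Bekas et al. diagonal estimator with Rademacher (+1/-1)
# probes. Not duvida's implementation; hvp_f and hvp_g are hypothetical helpers.
import random
from typing import Callable, List

Vector = List[float]
HVP = Callable[[Vector, Vector], Vector]


def hvp_f(x: Vector, v: Vector) -> Vector:
    """HVP for f(x) = sum(x**3 + x**2 + 4); its Hessian is diagonal, diag = 6*x + 2."""
    return [(6.0 * xi + 2.0) * vi for xi, vi in zip(x, v)]


def hvp_g(x: Vector, v: Vector) -> Vector:
    """HVP for g(x) = sum(sum(x)**3 + x**2 + 4).

    With s = sum(x) and n = len(x), g = n*s**3 + sum(x**2) + 4*n, so every
    Hessian entry is 6*n*s, plus 2 on the diagonal: (H v)_j = 6*n*s*sum(v) + 2*v_j.
    """
    n, s, sum_v = len(x), sum(x), sum(v)
    return [6.0 * n * s * sum_v + 2.0 * vj for vj in v]


def bekas_diagonal(hvp: HVP, x: Vector, n_probes: int = 1000, seed: int = 0) -> Vector:
    """Estimate diag(H) as sum_k v_k * (H v_k) / sum_k v_k * v_k, elementwise."""
    rng = random.Random(seed)
    numerator = [0.0] * len(x)
    denominator = [0.0] * len(x)
    for _ in range(n_probes):
        v = [rng.choice((-1.0, 1.0)) for _ in x]  # Rademacher probe vector
        hv = hvp(x, v)
        for j in range(len(x)):
            numerator[j] += v[j] * hv[j]
            denominator[j] += v[j] * v[j]
    return [num / den for num, den in zip(numerator, denominator)]


a = [1.0, 2.0]
print(bekas_diagonal(hvp_f, a))          # [8.0, 14.0]: exact, because H is diagonal
print(bekas_diagonal(hvp_g, a))          # noisy estimate of [38.0, 38.0]
print(bekas_diagonal(hvp_g, a, seed=1))  # a different seed draws different probes
```

For `g` at `[1., 2.]` every off-diagonal entry is 36, so each probe contributes $38 \pm 36$ and the average converges only at the usual $1/\sqrt{n}$ Monte Carlo rate.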
## Issues, problems, suggestions
Please report them on the [issue tracker](https://www.github.com/scbirlab/duvida/issues).
## Documentation
(Coming soon at [ReadTheDocs](https://duvida.readthedocs.org).)