https://github.com/bsc-quantic/tn4ml
Tensor Networks for Machine Learning
- Host: GitHub
- Owner: bsc-quantic
- License: mit
- Created: 2022-07-12T10:18:19.000Z
- Default Branch: master
- Last Pushed: 2025-06-02T15:47:50.000Z
- Topics: flax, jax, machine-learning, matrix-product-states, optax, python, python3, quimb, tensor-network, tensor-networks
- Language: Python
- Homepage: https://tn4ml.readthedocs.io
- Size: 38.2 MB
- Stars: 18
- Watchers: 1
- Forks: 5
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE

# Tensor Networks for Machine Learning


**tn4ml** is a Python library for building and training tensor networks for machine learning applications.
It is built on top of **Quimb**, for tensor network objects, and **JAX**, for the optimization pipeline.
For now, the library supports 1D tensor network structures:
- **Matrix Product State**
- **Matrix Product Operator**
- **Spaced Matrix Product Operator**

It supports different **embedding** functions, **initialization** techniques, **objective functions** and **optimization strategies**.
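
To make these terms concrete, here is a minimal NumPy sketch (not tn4ml's API) of how a Matrix Product State scores a single input. Each feature in [0, 1] is mapped by a trigonometric embedding, a common choice in the tensor-network ML literature, to a 2-dimensional vector, and the resulting product state is contracted site by site with randomly initialized MPS tensors. The function names `trig_embedding`, `random_mps` and `mps_overlap`, the tensor layout `(left, phys, right)` and the initialization scale are all hypothetical choices for illustration.

```python
import numpy as np


def trig_embedding(x):
    """Map features in [0, 1] to 2-dim vectors [cos(pi x / 2), sin(pi x / 2)].

    Returns an array of shape (n_features, 2); each row has unit norm.
    """
    x = np.asarray(x, dtype=np.float64)
    return np.stack([np.cos(np.pi * x / 2), np.sin(np.pi * x / 2)], axis=-1)


def random_mps(n_sites, phys_dim=2, bond_dim=4, seed=0):
    """Hypothetical random initialization: one tensor of shape (left, phys, right)
    per site, with trivial (size-1) bonds at both ends of the chain."""
    rng = np.random.default_rng(seed)
    tensors = []
    for i in range(n_sites):
        left = 1 if i == 0 else bond_dim
        right = 1 if i == n_sites - 1 else bond_dim
        tensors.append(rng.normal(scale=1 / np.sqrt(bond_dim), size=(left, phys_dim, right)))
    return tensors


def mps_overlap(tensors, embedded):
    """Contract the MPS with an embedded product state, left to right."""
    env = np.ones((1,))  # left boundary vector for the size-1 leftmost bond
    for A, phi in zip(tensors, embedded):
        # Contract the physical index with the embedded feature vector ...
        site_matrix = np.einsum("lpr,p->lr", A, phi)
        # ... then absorb the resulting (left, right) matrix into the environment.
        env = env @ site_matrix
    return env.item()  # the rightmost bond has size 1, so env has shape (1,)


x = np.array([0.1, 0.5, 0.9, 0.3])
mps = random_mps(n_sites=len(x))
score = mps_overlap(mps, trig_embedding(x))
print(score)
```

Contracting one site at a time keeps the cost linear in the number of features, instead of building the full $2^N$-dimensional tensor.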
## Installation
First, create a virtual environment (e.g. with `pyenv` or `conda`), then install the package and its dependencies.
**With** `pip` (tag v1.0.5):
```bash
pip install tn4ml
```
or **directly from GitHub**:
```bash
pip install -U git+https://github.com/bsc-quantic/tn4ml.git
```
If you want to test and edit the code, clone the repository and install it in editable mode:
```bash
git clone https://github.com/bsc-quantic/tn4ml.git
pip install -e tn4ml/
```
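
After any of the installation methods above, one way to confirm which version ended up in the active environment is to query the package metadata with the standard library's `importlib.metadata` (Python 3.8+). This does not rely on any tn4ml-specific attribute; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="tn4ml"):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


v = installed_version("tn4ml")
print(f"tn4ml version: {v}" if v else "tn4ml is not installed in this environment")
```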
To install the optional dependencies for *docs*, *test* or *examples*:
```zsh
pip install "tn4ml[docs]"
```
```zsh
pip install "tn4ml[test]"
```
```zsh
pip install "tn4ml[examples]"
```
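**Numerical precision**
(Optional) To improve numerical precision, set these flags at the start of your script (64-bit arithmetic is typically slower, particularly on GPUs):
```python
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_matmul_precision", "highest")
```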
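
A quick way to confirm that the flags are active is to inspect the default dtype of a freshly created array. This sketch assumes the flags are set at the very start of the program, before any JAX arrays exist; with `jax_enable_x64` on, floating-point arrays default to `float64`.

```python
import jax
import jax.numpy as jnp

# Set precision flags before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_matmul_precision", "highest")

x = jnp.linspace(0.0, 1.0, 4)
print(x.dtype)  # float64 once x64 is enabled (float32 otherwise)

# Machine epsilon shows the gain in resolution between the two float types.
print(jnp.finfo(jnp.float64).eps, jnp.finfo(jnp.float32).eps)
```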
**Running on GPU**
First, install a CUDA-enabled version of `JAX` so that it can run on GPU.
See the official installation instructions: [jax[cuda]](https://docs.jax.dev/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-via-pip-easier)
Next, at the beginning of your script, set:
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Use GPU 0 - or set any GPU ID
import jax
jax.config.update("jax_platform_name", 'gpu')
```
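
Before training, it can help to confirm that JAX actually picked up the GPU. This sketch uses only core JAX functions (`jax.default_backend`, `jax.devices`) and is meant to be run in a fresh process without forcing `jax_platform_name`, so that it reports a CPU fallback instead of failing when no GPU is found.

```python
import jax

# Report which backend JAX selected and which devices it can see.
backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
devices = jax.devices()
print(f"backend={backend}, devices={devices}")

if backend != "gpu":
    # A CPU fallback here usually means a CUDA-enabled jaxlib is not installed
    # or CUDA_VISIBLE_DEVICES hides all GPUs.
    print("Warning: JAX is not using a GPU.")
```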
Then, when training a `Model`, set:
```python
device = 'gpu'
model.configure(device=device)
```
## Documentation
Visit [tn4ml.readthedocs.io](https://tn4ml.readthedocs.io/en/latest/)
## Example notebooks
- [TN for Classification](docs/source/examples/mnist_classification.ipynb)
- [TN for Anomaly Detection](docs/source/examples/mnist_ad.ipynb)
- [TN for Anomaly Detection with DMRG-like method](docs/source/examples/mnist_ad_sweeps.ipynb)
## Examples from the paper
- [Breast Cancer Classification](docs/source/examples/supervised)
- [Unsupervised Learning with MNIST](docs/source/examples/unsupervised)
- [MPS for Anomaly Detection in the Latent Space of Proton Collision Events at the LHC](docs/source/examples/tnad_latent)
## Citation
If you use **tn4ml** in your work, please cite the following paper: [arXiv:2502.13090](https://arxiv.org/abs/2502.13090)
```bibtex
@article{puljak2025tn4mltensornetworktraining,
title={tn4ml: Tensor Network Training and Customization for Machine Learning},
author={Ema Puljak and Sergio Sanchez-Ramirez and Sergi Masot-Llima and Jofre Vallès-Muns and Artur Garcia-Saez and Maurizio Pierini},
year={2025},
eprint={2502.13090},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2502.13090},
}
```
## License
Released under the MIT license; see [LICENSE](LICENSE).