https://github.com/salamanderxing/jax-min-batch-kmeans
Jax implementation of Mini-batch K-Means algorithm
- Host: GitHub
- Owner: SalamanderXing
- License: mit
- Created: 2022-10-27T13:19:46.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2022-10-29T23:35:15.000Z (over 2 years ago)
- Last Synced: 2025-02-24T08:39:57.712Z (about 2 months ago)
- Topics: clustering-algorithm, jax, kmeans-algorithm, mini-batch-kmeans
- Language: Python
- Homepage:
- Size: 10.7 KB
- Stars: 4
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Mini-Batch KMeans written in JAX
A just-in-time compiled, accelerated ⚡ implementation of [Mini-Batch KMeans](https://doi.org/10.1145/1772690.1772862) [1].
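The algorithm in [1] can be sketched in a few lines of JAX. This is a minimal illustration, not this repository's code. The names `nearest_centroid`, `update_centroids` and `fit` are hypothetical. Mini-batches are assumed to be drawn uniformly with replacement, and centroids are assumed to be initialized from randomly chosen samples.

In Sculley's algorithm, each centroid c keeps a count v[c]. Every sample x assigned to c increments v[c] and moves the centroid by c ← (1 − η)c + ηx, where the per-centroid learning rate is η = 1/v[c]. Because assignments are computed for the whole batch before any centroid moves, the sketch applies all of a batch's updates at once as the equivalent weighted running mean.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper names: a sketch of Sculley (2010), not this repo's API.


def nearest_centroid(xs, centroids):
    # Squared Euclidean distance from every sample to every centroid: shape (n, k).
    d2 = jnp.sum((xs[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(d2, axis=1)


@jax.jit
def update_centroids(batch, centroids, counts):
    # Assign every batch sample to its nearest centroid before any update,
    # mirroring the cached assignments in Sculley's algorithm.
    labels = nearest_centroid(batch, centroids)
    k = centroids.shape[0]
    # Per-centroid number of assigned samples and sum of those samples.
    n = jnp.zeros(k, dtype=centroids.dtype).at[labels].add(1.0)
    s = jnp.zeros_like(centroids).at[labels].add(batch)
    new_counts = counts + n
    # Applying c <- (1 - 1/v) c + (1/v) x once per assigned x (incrementing v
    # each time) yields this weighted running mean.
    updated = (counts[:, None] * centroids + s) / jnp.maximum(new_counts, 1.0)[:, None]
    # Centroids that received no samples in this batch are left unchanged.
    new_centroids = jnp.where(n[:, None] > 0, updated, centroids)
    return new_centroids, new_counts


def fit(xs, k, batch_size, iters, seed=0):
    xs = jnp.asarray(xs, dtype=jnp.float32)
    key = jax.random.PRNGKey(seed)
    key, init_key = jax.random.split(key)
    # Assumed initialization: k distinct randomly chosen samples.
    init_idx = jax.random.choice(init_key, xs.shape[0], (k,), replace=False)
    centroids = xs[init_idx]
    counts = jnp.zeros(k, dtype=jnp.float32)
    for _ in range(iters):
        key, batch_key = jax.random.split(key)
        # Draw a mini-batch of indices uniformly at random (with replacement).
        idx = jax.random.randint(batch_key, (batch_size,), 0, xs.shape[0])
        centroids, counts = update_centroids(xs[idx], centroids, counts)
    return centroids
```

For example, `fit(xs, k=4, batch_size=1000, iters=1000)` returns a `(4, n_features)` array of centroids. Only `update_centroids` is jitted here. Its input shapes stay the same on every iteration, so it compiles once and is then reused inside the Python loop.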
## Requirements
- [JAX 😎](https://github.com/google/jax) >= 0.3.17
## Installation
```bash
git clone https://github.com/GiulioZani/jax-min-batch-kmeans
cd jax-min-batch-kmeans
```
## Usage
```python
import numpy as np
from mini_batch_kmeans import MiniBatchKMeans


def main():
    # A 2D array of shape (number of samples, number of features);
    # random placeholder data stands in for a real dataset here.
    xs = np.random.default_rng(0).normal(size=(10_000, 2))
    mini_batch_kmeans = MiniBatchKMeans(
        xs,  # can be a numpy or jax array
        k=4,  # number of clusters
        batch_size=1000,  # batch size
        iter=1000,  # number of iterations
        random_state=0,
    )
    mini_batch_kmeans.fit()
    print(f"{mini_batch_kmeans.centroids=}")


if __name__ == "__main__":
    main()
```

## References
[1] D. Sculley. 2010. Web-scale k-means clustering. In Proceedings of the 19th international conference on World wide web (WWW '10). Association for Computing Machinery, New York, NY, USA, 1177–1178. https://doi.org/10.1145/1772690.1772862