https://github.com/ryusudol/centered-kernel-alignment

CKA with Efficient Computation and Layer-wise Visualization for PyTorch

# pytorch-cka

[![PyPI](https://img.shields.io/pypi/v/pytorch-cka.svg)](https://pypi.org/project/pytorch-cka/)
[![Python](https://img.shields.io/badge/python-3.10%2B-blue)](https://pypi.org/project/pytorch-cka/)
[![PyPI Downloads](https://static.pepy.tech/personalized-badge/pytorch-cka?period=total&units=INTERNATIONAL_SYSTEM&left_color=GREY&right_color=RED&left_text=downloads)](https://pepy.tech/projects/pytorch-cka)

**The Fastest, Memory-efficient Python Library for computing layer-wise similarity between neural network models**





Figure: bar chart of benchmark results.


44x faster CKA computation across 18 representational layers of ResNet-18 models on CIFAR-10 using NVIDIA H100 GPUs

- ⚡️ Fastest among CKA libraries thanks to **vectorized ops** & **GPU acceleration**
- 📦 Efficient memory management with explicit deallocation
- 🧠 Supports HuggingFace models, DataParallel, and DDP
- 🎨 Customizable visualizations: heatmaps and line charts
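
For readers new to the metric, here is a minimal sketch of linear CKA as defined in Kornblith et al. (2019), written with vectorized PyTorch ops. It is an illustration of the quantity being computed, not this library's implementation; the function name `linear_cka` and the toy tensors are assumptions made for the example.

```python
import torch


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Linear CKA between activations x (n, p1) and y (n, p2) of the same n inputs."""
    # Center every feature (column) so the Gram matrices are implicitly centered.
    x = x - x.mean(dim=0, keepdim=True)
    y = y - y.mean(dim=0, keepdim=True)
    # CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = torch.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = torch.linalg.norm(x.T @ x, ord="fro")
    norm_y = torch.linalg.norm(y.T @ y, ord="fro")
    return cross / (norm_x * norm_y)


torch.manual_seed(0)
acts = torch.randn(256, 32, dtype=torch.float64)  # toy activations: 256 inputs, 32 features
q, _ = torch.linalg.qr(torch.randn(32, 32, dtype=torch.float64))  # random orthogonal matrix

same = linear_cka(acts, acts)               # identical representations
rotated = linear_cka(acts, 3.0 * acts @ q)  # rotated and rescaled copy
unrelated = linear_cka(acts, torch.randn(256, 32, dtype=torch.float64))
print(same.item(), rotated.item(), unrelated.item())
```

The first two scores equal 1 up to floating-point error, because linear CKA is invariant to orthogonal transformations and isotropic scaling, while independent random features score much lower.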

## 📦 Installation

Requires `Python 3.10+`

```bash
# Using pip
pip install pytorch-cka

# Using uv
uv add pytorch-cka
```

## 👟 Quick Start

### Basic Usage

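```python
import torch
from cka import compute_cka
from torch.utils.data import DataLoader
from torchvision.models import ResNet18_Weights, ResNet34_Weights, resnet18, resnet34

# `pretrained=True` is deprecated in torchvision >= 0.13; pass `weights=` instead
resnet_18 = resnet18(weights=ResNet18_Weights.DEFAULT)
resnet_34 = resnet34(weights=ResNet34_Weights.DEFAULT)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128  # example value; pick one that fits your memory

# your_dataset1..3 are placeholders for your own datasets
dataloader1 = DataLoader(your_dataset1, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader2 = DataLoader(your_dataset2, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader3 = DataLoader(your_dataset3, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = [dataloader1, dataloader2, dataloader3]

# Layer names present in both models
layers = [
    'conv1',
    'layer1.0.conv1',
    'layer2.0.conv1',
    'layer3.0.conv1',
    'layer4.0.conv1',
    'fc',
]

cka_matrices = compute_cka(
    resnet_18,
    resnet_34,
    dataloaders,
    layers=layers,
    device=device,
)

for cka_matrix in cka_matrices:
    print(cka_matrix)
```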
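
The strings in `layers` follow PyTorch's dotted module names. The sketch below shows how to list the names a model exposes via `named_modules()` and how activations for chosen layers can be captured with forward hooks. The toy model, the `save_output` helper, and the use of hooks are assumptions for illustration; how `compute_cka` resolves layer names internally is not stated here.

```python
import torch
from torch import nn

# Hypothetical toy model standing in for a real network such as ResNet-18.
model = nn.Sequential()
model.add_module("conv1", nn.Conv2d(3, 8, kernel_size=3, padding=1))
model.add_module("block", nn.Sequential(nn.ReLU(), nn.Conv2d(8, 4, kernel_size=3, padding=1)))
model.add_module("pool", nn.AdaptiveAvgPool2d(1))
model.add_module("flatten", nn.Flatten())
model.add_module("fc", nn.Linear(4, 10))

# Dotted names as reported by PyTorch; the root module has the empty name "".
available = [name for name, _ in model.named_modules() if name]
print(available)

activations = {}


def save_output(name):
    def hook(module, inputs, output):
        # Flatten everything except the batch dimension: (n, C, H, W) -> (n, C*H*W).
        activations[name] = output.detach().flatten(start_dim=1)
    return hook


layers = ["conv1", "block.1", "fc"]
modules = dict(model.named_modules())
handles = [modules[name].register_forward_hook(save_output(name)) for name in layers]

model.eval()
with torch.no_grad():
    model(torch.randn(5, 3, 16, 16))

for handle in handles:
    handle.remove()  # detach hooks so later forward passes are unaffected

print({name: tuple(act.shape) for name, act in activations.items()})
```

Running the same listing on your own models is a quick way to check that every name you pass in `layers` exists in both networks.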

### Visualization

**Heatmap**

```python
from cka import plot_cka_heatmap

# cka_matrix is one of the matrices returned by compute_cka
fig, ax = plot_cka_heatmap(
    cka_matrix,
    layers1=layers,
    layers2=layers,
    model1_name="ResNet-18 (pretrained)",
    model2_name="ResNet-18 (random init)",
    annot=False,     # Set to True to show values in cells
    cmap="inferno",  # Colormap
)
```



Example heatmaps: self-comparison and cross-model comparison.

**Trend Plot**

```python
import torch
from cka import plot_cka_layer_trend, plot_cka_trend

# Diagonal of a CKA matrix (self-similarity across layers)
diagonal = torch.diag(cka_matrix)

# layer_trends, epochs, RESNET18_LAYERS, and cka_loader_names are placeholders
# for your own data
fig, ax = plot_cka_trend(
    layer_trends,
    x_values=epochs,
    labels=RESNET18_LAYERS,
    markers=['o'],
    xlabel='Epoch',
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Epochs (ResNet-18)',
    legend=True,
)

fig, ax = plot_cka_layer_trend(
    cka_matrices,
    layers=RESNET18_LAYERS,
    labels=cka_loader_names,
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Layers (ResNet-18)',
    legend=True,
)
```
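
One plausible way to assemble per-layer trends from a list of per-epoch CKA matrices is to stack their diagonals, as sketched below. The random matrices and the name `cka_per_epoch` are hypothetical stand-ins, and the `(num_layers, num_epochs)` layout is an assumption; check the library documentation for the exact shape `plot_cka_trend` expects.

```python
import torch

# Hypothetical stand-ins: one 3x3 CKA matrix per epoch.
torch.manual_seed(0)
epochs = [1, 2, 3, 4]
cka_per_epoch = [torch.rand(3, 3) for _ in epochs]

# torch.diag on a 2-D tensor returns its main diagonal, i.e. the score
# between layer i of the first model and layer i of the second.
diagonals = torch.stack([torch.diag(m) for m in cka_per_epoch])  # (num_epochs, num_layers)

# Transpose so that each row follows one layer across epochs.
layer_trends = diagonals.T  # (num_layers, num_epochs)
print(layer_trends.shape)
```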


Example trend plots: CKA score across epochs and CKA score across layers.

## 📚 References

Kornblith, Simon, et al. ["Similarity of Neural Network Representations Revisited."](https://arxiv.org/abs/1905.00414) _ICML 2019._