https://github.com/ryusudol/centered-kernel-alignment
CKA with Efficient Computation and Layer-wise Visualization for PyTorch
- Host: GitHub
- URL: https://github.com/ryusudol/centered-kernel-alignment
- Owner: ryusudol
- License: mit
- Created: 2025-12-15T21:36:56.000Z
- Default Branch: main
- Last Pushed: 2026-01-14T09:06:25.000Z
- Last Synced: 2026-01-14T13:16:06.865Z
- Topics: centered-kernel-alignment, cka, machine-learning, neural-networks, pytorch, representation, similarity
- Language: Python
- Homepage:
- Size: 653 KB
- Stars: 3
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# pytorch-cka
[PyPI](https://pypi.org/project/pytorch-cka/) | [Downloads](https://pepy.tech/projects/pytorch-cka)
**The fastest, memory-efficient Python library for computing layer-wise similarity between neural network models**
44x faster CKA computation across 18 representational layers of ResNet-18 models on CIFAR-10 using NVIDIA H100 GPUs
- ⚡️ Fastest among CKA libraries thanks to **vectorized ops** & **GPU acceleration**
- 📦 Efficient memory management with explicit deallocation
- 🧠 Supports Hugging Face models, DataParallel, and DistributedDataParallel (DDP)
- 🎨 Customizable visualizations: heatmaps and line charts
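For intuition about the vectorization claim, here is a minimal plain-PyTorch sketch (not pytorch-cka's code) of linear CKA as defined by Kornblith et al. (2019). Because every layer's n×n Gram matrix has the same shape regardless of the layer's width, the Gram matrices can be stacked and all layer pairs scored with a single `einsum`. `gram_linear`, `center_gram`, and `linear_cka_all_pairs` are hypothetical helper names; the sketch assumes a linear kernel and that every activation tensor covers the same n samples in the same order.

```python
import torch


def gram_linear(x: torch.Tensor) -> torch.Tensor:
    """Linear-kernel Gram matrix X X^T for activations of shape (n, ...)."""
    x = x.flatten(start_dim=1).double()  # (n, features), float64 for stability
    return x @ x.T  # (n, n)


def center_gram(k: torch.Tensor) -> torch.Tensor:
    """Double-center a batch of Gram matrices, i.e. H K H with H = I - 11^T / n."""
    k = k - k.mean(dim=-1, keepdim=True)  # subtract row means
    return k - k.mean(dim=-2, keepdim=True)  # then subtract column means


def linear_cka_all_pairs(acts1, acts2) -> torch.Tensor:
    """Linear CKA for every (layer of model 1, layer of model 2) pair.

    acts1, acts2: lists of activation tensors that all share the same first
    dimension n (the same samples, in the same order).
    Returns a (len(acts1), len(acts2)) matrix.
    """
    # Every Gram matrix is (n, n) whatever the layer width, so they stack.
    k1 = center_gram(torch.stack([gram_linear(a) for a in acts1]))  # (L1, n, n)
    k2 = center_gram(torch.stack([gram_linear(a) for a in acts2]))  # (L2, n, n)
    # Unnormalized HSIC <K1c, K2c>_F for all layer pairs in one einsum.
    hsic = torch.einsum("aij,bij->ab", k1, k2)  # (L1, L2)
    # Normalize by sqrt(HSIC(K1, K1) * HSIC(K2, K2)) = ||K1c||_F * ||K2c||_F.
    norm1 = torch.linalg.norm(k1.flatten(start_dim=1), dim=1)  # (L1,)
    norm2 = torch.linalg.norm(k2.flatten(start_dim=1), dim=1)  # (L2,)
    return hsic / (norm1[:, None] * norm2[None, :])
```

Memory grows as L·n², so this only suits a modest number of samples; it illustrates why batching the arithmetic helps rather than reproducing the library's performance.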
## 📦 Installation
Requires `Python 3.10+`
```bash
# Using pip
pip install pytorch-cka
# Using uv
uv add pytorch-cka
```
## 👟 Quick Start
### Basic Usage
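```python
import torch
from cka import compute_cka
from torch.utils.data import DataLoader
from torchvision.models import resnet18, resnet34, ResNet18_Weights, ResNet34_Weights

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 128  # example value; pick what fits your GPU memory

# `pretrained=True` is deprecated in recent torchvision; use the weights enums instead
resnet_18 = resnet18(weights=ResNet18_Weights.DEFAULT)
resnet_34 = resnet34(weights=ResNet34_Weights.DEFAULT)

# Replace your_dataset1..3 with your own Dataset objects.
# Keep shuffle=False so both models see the same samples in the same order.
dataloader1 = DataLoader(your_dataset1, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader2 = DataLoader(your_dataset2, batch_size=batch_size, shuffle=False, num_workers=4)
dataloader3 = DataLoader(your_dataset3, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = [dataloader1, dataloader2, dataloader3]

# Module names as reported by model.named_modules(); all exist in both ResNet-18 and ResNet-34
layers = [
    'conv1',
    'layer1.0.conv1',
    'layer2.0.conv1',
    'layer3.0.conv1',
    'layer4.0.conv1',
    'fc',
]

cka_matrices = compute_cka(
    resnet_18,
    resnet_34,
    dataloaders,
    layers=layers,
    device=device,
)

# cka_matrices holds one layer-by-layer CKA matrix per dataloader
for cka_matrix in cka_matrices:
    print(cka_matrix)
```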
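For intuition about what a call like `compute_cka` has to do, the following is a minimal plain-PyTorch sketch, not pytorch-cka's actual implementation: forward hooks on the named submodules collect activations over the whole dataloader, and linear CKA is then computed for every layer pair. `collect_activations`, `linear_cka`, and `layerwise_cka` are hypothetical names; the sketch assumes each hooked module returns a single tensor and keeps every activation in memory, so it skips the memory management the library advertises.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def linear_cka(x: torch.Tensor, y: torch.Tensor) -> float:
    """Linear CKA between two activation tensors of shape (n, ...)."""
    x = x.flatten(start_dim=1).double()
    y = y.flatten(start_dim=1).double()
    x = x - x.mean(dim=0, keepdim=True)  # center each feature over the n samples
    y = y - y.mean(dim=0, keepdim=True)
    # ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    num = torch.linalg.norm(y.T @ x) ** 2
    den = torch.linalg.norm(x.T @ x) * torch.linalg.norm(y.T @ y)
    return (num / den).item()


@torch.no_grad()
def collect_activations(model: nn.Module, loader: DataLoader, layers, device="cpu"):
    """Return {layer name: activations for every sample in `loader`}."""
    store = {name: [] for name in layers}
    hooks = [
        # `name=name` binds the current layer name; assumes the module outputs one tensor
        model.get_submodule(name).register_forward_hook(
            lambda _module, _inputs, output, name=name: store[name].append(output.detach().cpu())
        )
        for name in layers
    ]
    try:
        model.eval().to(device)
        for batch in loader:
            # Accept both (inputs, labels) batches and bare input tensors
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(inputs.to(device))
    finally:
        for hook in hooks:
            hook.remove()  # always detach the hooks, even if a batch fails
    return {name: torch.cat(chunks) for name, chunks in store.items()}


def layerwise_cka(model1, model2, loader, layers1, layers2, device="cpu") -> torch.Tensor:
    """(len(layers1), len(layers2)) matrix of linear CKA scores on one dataloader."""
    acts1 = collect_activations(model1, loader, layers1, device)
    acts2 = collect_activations(model2, loader, layers2, device)
    return torch.tensor([[linear_cka(acts1[a], acts2[b]) for b in layers2] for a in layers1])
```

Comparing a model with itself, e.g. `layerwise_cka(net, net, loader, names, names)` for a small `nn.Sequential` whose submodules are named `"0"`, `"1"`, `"2"`, should give ones on the diagonal, which is a quick sanity check for any CKA pipeline.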
### Visualization
**Heatmap**
```python
from cka import plot_cka_heatmap

fig, ax = plot_cka_heatmap(
    cka_matrices[0],          # matrix for the first dataloader from the Quick Start
    layers1=layers,           # layer names of model 1
    layers2=layers,           # layer names of model 2
    model1_name="ResNet-18",
    model2_name="ResNet-34",
    annot=False,              # set True to print the value in each cell
    cmap="inferno",           # colormap
)
```
*Figures: example heatmaps, Self-comparison and Cross-model.*
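If you want a heatmap without the library's plotting helper, or just want to see what such a plot encodes, here is a minimal matplotlib sketch. `plain_cka_heatmap` is a hypothetical helper, not part of pytorch-cka; putting model 1's layers on the rows, model 2's layers on the columns, and using `origin="lower"` (early layers at the bottom-left) are this sketch's own conventions and may differ from `plot_cka_heatmap`.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend; remove this line in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plain_cka_heatmap(matrix, layers1, layers2, name1="Model 1", name2="Model 2"):
    """Plot a CKA matrix: rows = layers of model 1, columns = layers of model 2."""
    if hasattr(matrix, "detach"):  # torch tensor, possibly on GPU
        matrix = matrix.detach().cpu().numpy()
    matrix = np.asarray(matrix, dtype=float)
    fig, ax = plt.subplots(figsize=(6, 5))
    # CKA lies in [0, 1], so fix the color range for comparable plots
    im = ax.imshow(matrix, origin="lower", cmap="inferno", vmin=0.0, vmax=1.0)
    ax.set_xticks(range(len(layers2)), labels=layers2, rotation=90)
    ax.set_yticks(range(len(layers1)), labels=layers1)
    ax.set_xlabel(name2)
    ax.set_ylabel(name1)
    fig.colorbar(im, ax=ax, label="CKA")
    fig.tight_layout()
    return fig, ax
```

Fixing `vmin`/`vmax` to the full CKA range keeps colors comparable across several heatmaps, at the cost of less contrast when all scores are high.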
**Trend Plot**
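```python
import torch
from cka import plot_cka_trend, plot_cka_layer_trend

RESNET18_LAYERS = layers  # layer names the CKA matrices were computed on

# Trend across epochs
# Hypothetical input: `epoch_matrices` is a list of CKA matrices, one per
# fine-tuning epoch, each comparing the pretrained ResNet-18 with the
# fine-tuned checkpoint from that epoch on the same layers.
epochs = list(range(1, len(epoch_matrices) + 1))
# The diagonal holds same-layer scores (layer i vs. layer i)
diagonals = torch.stack([torch.diag(torch.as_tensor(m)) for m in epoch_matrices])
layer_trends = diagonals.T  # assumed format: (num_layers, num_epochs), one curve per layer

fig, ax = plot_cka_trend(
    layer_trends,
    x_values=epochs,
    labels=RESNET18_LAYERS,
    markers=['o'],
    xlabel='Epoch',
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Epochs (ResNet-18)',
    legend=True,
)

# Trend across layers: one curve per dataloader (cka_matrices holds one matrix each)
cka_loader_names = ['dataset 1', 'dataset 2', 'dataset 3']  # placeholder labels
fig, ax = plot_cka_layer_trend(
    cka_matrices,
    layers=RESNET18_LAYERS,
    labels=cka_loader_names,
    ylabel='CKA Score',
    title='Pretrained vs. Fine-tuned Across Layers (ResNet-18)',
    legend=True,
)
```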
*Figures: CKA Score Trend Across Epochs; CKA Score Trend Across Layers.*
## 📚 References
Kornblith, Simon, et al. ["Similarity of Neural Network Representations Revisited."](https://arxiv.org/abs/1905.00414) _ICML 2019._