https://github.com/finite-sample/act
Augmentation Consistency Training
https://github.com/finite-sample/act
Last synced: 11 months ago
JSON representation
Augmentation Consistency Training
- Host: GitHub
- URL: https://github.com/finite-sample/act
- Owner: finite-sample
- Created: 2025-07-07T23:21:40.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2025-07-08T20:48:23.000Z (11 months ago)
- Last Synced: 2025-07-08T21:38:12.523Z (11 months ago)
- Language: Jupyter Notebook
- Size: 127 KB
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
# Augmentation Consistency Training for Robust Visual Recognition
Standard data augmentation exposes models to diverse views of each object, improving generalization. However, augmentation typically happens **implicitly**—we simply train on the augmented dataset without explicitly enforcing that the model responds similarly to different views of the same image.
We propose **Augmentation Consistency (AC) training**: an explicit loss that directly penalizes prediction disagreement across multiple augmented views of each input. This should yield more robust internal representations and improved stability under input variations.
## Method
### Core Idea
For each training batch, generate $k$ label-preserving augmentations of each input and enforce that the model produces similar predictions across all views.
### Mathematical Formulation
Given input batch $\mathbf{x} \in \mathbb{R}^{B \times C \times H \times W}$ with labels $\mathbf{y}$, generate $k$ augmentations:
$\{\tilde{\mathbf{x}}_1, \tilde{\mathbf{x}}_2, \ldots, \tilde{\mathbf{x}}_k\} = \{T_1(\mathbf{x}), T_2(\mathbf{x}), \ldots, T_k(\mathbf{x})\}$
where $T_i$ are label-preserving transformations (rotation, flip, brightness, noise).
Obtain predictions for each view:
$$\mathbf{z}_i = f_\theta(\tilde{\mathbf{x}}_i) \in \mathbb{R}^{B \times C}$$
$$\mathbf{p}_i = \text{softmax}(\mathbf{z}_i) \in \mathbb{R}^{B \times C}$$
The **Augmentation Consistency Loss** combines standard cross-entropy with a consistency penalty:
$$\mathcal{L}_{\text{pred}} = \frac{1}{k} \sum_{i=1}^k \text{CrossEntropy}(\mathbf{z}_i, \mathbf{y})$$
$$\mathcal{L}_{\text{consistency}} = \frac{1}{\binom{k}{2}} \sum_{1 \leq i < j \leq k} \frac{1}{2}\left[ D_{\text{KL}}(\mathbf{p}_i \| \mathbf{p}_j) + D_{\text{KL}}(\mathbf{p}_j \| \mathbf{p}_i) \right]$$
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{pred}} + \lambda \mathcal{L}_{\text{consistency}}$$
where $D_{\text{KL}}$ is the Kullback-Leibler divergence and $\lambda$ controls the consistency strength.
### Implementation
```python
def augmentation_consistency_loss(model, x, y, k=4, lambda_consistency=0.15):
# Generate k augmented views
augmented_views = [augment(x) for _ in range(k)]
# Forward pass through all views
logits = [model(view) for view in augmented_views]
probs = [F.softmax(logit, dim=1) for logit in logits]
# Prediction loss (average across views)
pred_loss = torch.mean(torch.stack([
F.cross_entropy(logit, y) for logit in logits
]))
# Consistency loss (symmetric KL between all pairs)
consistency_loss = 0.0
n_pairs = 0
for i in range(k):
for j in range(i+1, k):
kl_ij = F.kl_div(F.log_softmax(logits[i], dim=1), probs[j], reduction='batchmean')
kl_ji = F.kl_div(F.log_softmax(logits[j], dim=1), probs[i], reduction='batchmean')
consistency_loss += 0.5 * (kl_ij + kl_ji)
n_pairs += 1
consistency_loss /= n_pairs
return pred_loss + lambda_consistency * consistency_loss
```
## Relationship to Prior Work
### Data Augmentation
Data augmentation—training on transformed versions of inputs—has been fundamental to computer vision since the 1990s. The standard approach simply expands the training set with augmented examples, implicitly hoping the model learns invariance to these transformations. However, there is no explicit mechanism ensuring the model actually produces consistent predictions across different views of the same input.
### Consistency Regularization
The idea of explicitly enforcing prediction consistency across input transformations emerged in the mid-2010s for semi-supervised learning:
**Sajjadi et al. (2016)** introduced consistency regularization, penalizing disagreement between predictions on original and augmented versions of unlabeled data: $\mathcal{L}_{\text{consistency}} = \mathbb{E}_{x,T}[||f(x) - f(T(x))||^2]$ where $T$ is a stochastic transformation.
**Laine & Aila (2017)** developed the **π-model**, which applies this principle systematically: for each unlabeled example $x$, generate augmented version $\tilde{x}$ and minimize $||p_\theta(y|x) - p_\theta(y|\tilde{x})||^2$. This forces the model to give consistent predictions even when it doesn't know the true label.
**Temporal Ensembling** (Laine & Aila, 2017) extends π-model by maintaining exponential moving averages of predictions, reducing noise in the consistency targets.
These methods target semi-supervised learning: leveraging abundant unlabeled data by enforcing consistency across augmentations.
### Contrastive Learning
A parallel development uses augmentation for representation learning rather than prediction consistency:
**SimCLR** (Chen et al., 2020) learns representations by pulling together embeddings of augmented views while pushing apart embeddings of different images. The goal is invariant representations, not consistent predictions.
**MoCo** (He et al., 2020) uses momentum updates and large queues for efficient contrastive learning.
These methods operate in representation space and typically require large negative sample sets, making them computationally expensive.
### Supervised Consistency Training
More recent work applies consistency regularization in supervised settings:
**UDA** (Xie et al., 2020) uses consistency regularization with sophisticated augmentations for supervised learning, focusing on accuracy improvements rather than stability.
**FixMatch** (Sohn et al., 2020) combines consistency regularization with pseudo-labeling for semi-supervised learning.
### Our Contribution
Our work differs from prior consistency methods in several key aspects:
1. **Multi-view consistency**: While π-model enforces consistency between original and single augmented version, we enforce agreement across all pairs of $k$ augmented views ($k(k-1)/2$ constraints vs. 1)
2. **Stability evaluation**: Previous work measures accuracy improvements; we systematically evaluate cross-training stability using Turney's protocol—whether models trained on different data splits give consistent predictions
3. **Supervised robustness focus**: Rather than leveraging unlabeled data (semi-supervised) or learning representations (contrastive), we use consistency to improve model robustness in standard supervised settings
4. **Systematic analysis**: We provide the first comprehensive study of how augmentation consistency affects model stability across different data regimes (small samples, label noise, etc.)
While our technical approach builds directly on π-model's consistency principle, our contribution is empirical: demonstrating that multi-view consistency improves cross-training stability, not just accuracy.
## Experiments
### Datasets
- **CIFAR-10 Small**: 2K samples (high sampling variance)
- **CIFAR-10 Noisy**: 15% corrupted labels (robustness test)
- **Fashion-MNIST**: 4K grayscale samples (baseline comparison)
### Augmentations ($k=4$)
1. Slight Gaussian noise ($\sigma = 0.02$)
2. 90° rotation + noise
3. Brightness/contrast adjustment + noise
4. Horizontal flip + noise
### Results
| Dataset | Standard Stability | AC Stability | Δ Stability | Standard Accuracy | AC Accuracy | Δ Accuracy |
|---------|-------------------|--------------|-------------|-------------------|-------------|------------|
| CIFAR-10 Small | 0.419 ± 0.034 | **0.467 ± 0.041** | **+0.048** | 0.518 ± 0.028 | 0.453 ± 0.031 | -0.065 |
| CIFAR-10 Noisy | 0.344 ± 0.029 | **0.366 ± 0.033** | **+0.021** | 0.422 ± 0.025 | 0.372 ± 0.028 | -0.050 |
| Fashion-MNIST | 0.874 ± 0.018 | 0.872 ± 0.021 | -0.002 | 0.884 ± 0.012 | **0.888 ± 0.014** | +0.004 |
*Stability measured as prediction agreement between models trained on disjoint data splits, evaluated on held-out test set.*
### Key Findings
- **Stability gains** on datasets with high variance or label noise (+5-11%)
- **Diminishing returns** when baseline stability is already high (Fashion-MNIST)
- **Accuracy trade-offs** in noisy settings where consistency fights corrupted labels
- **Computational overhead**: $k$× forward passes per batch
## Discussion
### When Does AC Help?
- **Small datasets**: Higher sampling variance → greater benefit from consistency constraints
- **Noisy labels**: Forces focus on robust visual features rather than memorizing noise
- **Complex augmentation spaces**: Rich transformation sets provide stronger consistency signals
### Comparison to Contrastive Learning
- **Contrastive**: Learns representations that cluster augmented views, separate different images
- **AC**: Directly optimizes prediction consistency across augmented views
- **Complementary**: Could combine both approaches for representation + prediction consistency
### Limitations
- **Domain-specific**: Requires meaningful label-preserving augmentations
- **Computational cost**: $k$× increase in forward passes
- **Hyperparameter sensitivity**: $\lambda$ requires tuning for each domain
## Future Directions
1. **Adaptive augmentation strength**: Vary transformation intensity during training
2. **Learnable augmentations**: Discover optimal transformations for each dataset
3. **Hybrid approaches**: Combine with dropout consistency for dual robustness
4. **Theoretical analysis**: Formal connections between input consistency and generalization
## Conclusion
Augmentation Consistency training provides a direct method for improving model robustness through explicit consistency constraints. While computationally more expensive than standard augmentation, it offers meaningful stability improvements in challenging scenarios with limited data or label noise. The approach complements existing consistency regularization and contrastive learning methods, opening avenues for more robust visual recognition systems.
---
### References
- Chen, T., Kornblith, S., Norouzi, M., & Hinton, G. (2020). A simple framework for contrastive learning of visual representations. *ICML*.
- He, K., Fan, H., Wu, Y., Xie, S., & Girshick, R. (2020). Momentum contrast for unsupervised visual representation learning. *CVPR*.
- Laine, S., & Aila, T. (2017). Temporal ensembling for semi-supervised learning. *ICLR*.
- LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. *Proceedings of the IEEE*.
- Xie, Q., Dai, Z., Hovy, E., Luong, M. T., & Le, Q. V. (2020). Unsupervised data augmentation for consistency training. *NeurIPS*.