https://github.com/developer0hye/torch-batchnorm-from-scratch
https://github.com/developer0hye/torch-batchnorm-from-scratch
Last synced: about 2 months ago
JSON representation
- Host: GitHub
- URL: https://github.com/developer0hye/torch-batchnorm-from-scratch
- Owner: developer0hye
- License: mit
- Created: 2022-03-20T04:31:44.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2022-03-20T04:32:21.000Z (about 4 years ago)
- Last Synced: 2025-10-19T00:53:36.734Z (6 months ago)
- Size: 2.93 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# Torch-BatchNorm-From-Scratch
```python
from typing import Optional, Any
import torch
from torch import Tensor
from torch import nn
class BatchNorm2d(nn.Module):
    """Batch normalization for 4D inputs laid out as ``(N, C, H, W)``.

    Each channel is normalized with a mean and (biased) variance computed
    over the ``N``, ``H`` and ``W`` dimensions, then optionally scaled and
    shifted by the learnable ``weight`` and ``bias``.

    Args:
        num_features: Number of channels ``C`` of the expected input.
        eps: Value added to the variance before the square root, for
            numerical stability.
        momentum: Weight given to the current batch when updating the
            running statistics. ``None`` selects a cumulative (simple)
            moving average based on ``num_batches_tracked``.
        affine: If ``True``, learn a per-channel ``weight`` (initialized to
            one) and ``bias`` (initialized to zero).
        track_running_stats: If ``True``, maintain ``running_mean``,
            ``running_var`` and ``num_batches_tracked`` buffers and use the
            running estimates in eval mode. If ``False``, the buffers are
            ``None`` and batch statistics are used in both modes.
        device: Device on which parameters and buffers are created.
        dtype: Floating-point dtype of the parameters and running statistics.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(BatchNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if self.affine:
            self.weight = nn.Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # The step counter is always an integer tensor, so only the
            # device (not the floating-point dtype) is forwarded.
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long, device=device))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)

        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset the running mean to 0, running variance to 1 and counter to 0."""
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # if self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        """Reset the running statistics and, if affine, ``weight`` to 1 and ``bias`` to 0."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def _update_running_stats(self, x: Tensor, batch_mean: Tensor, batch_var: Tensor) -> None:
        """Fold the current batch statistics into the running buffers in place."""
        with torch.no_grad():
            if self.num_batches_tracked is not None:
                self.num_batches_tracked.add_(1)

            if self.momentum is not None:
                factor = self.momentum
            elif self.num_batches_tracked is not None:
                # Cumulative moving average: every batch seen so far gets equal weight.
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = 0.0

            # Normalization uses the biased variance, but the running estimate
            # stores the unbiased one (Bessel's correction). The max() guards
            # keep the correction finite when there is a single value per channel.
            count = x.numel() // max(x.size(1), 1)
            unbiased_var = batch_var * (count / max(count - 1, 1))

            self.running_mean.mul_(1 - factor).add_(batch_mean.reshape(-1), alpha=factor)
            self.running_var.mul_(1 - factor).add_(unbiased_var.reshape(-1), alpha=factor)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` per channel and apply the optional affine transform.

        Args:
            x: Input tensor of shape ``(N, num_features, H, W)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        channel_shape = (1, self.num_features, 1, 1)
        has_running_stats = self.running_mean is not None and self.running_var is not None
        # Batch statistics are used while training, and also in eval mode
        # when no running estimates are available to fall back on.
        use_batch_stats = self.training or not has_running_stats

        if use_batch_stats:
            mean = x.mean(dim=[0, 2, 3], keepdim=True)
            var = ((x - mean) ** 2).mean(dim=[0, 2, 3], keepdim=True)
            if self.training and self.track_running_stats and has_running_stats:
                self._update_running_stats(x, mean, var)
        else:
            mean = self.running_mean.view(channel_shape)
            var = self.running_var.view(channel_shape)

        x = (x - mean) / torch.sqrt(var + self.eps)
        if self.affine:
            x = self.weight.view(channel_shape) * x + self.bias.view(channel_shape)
        return x
```