An open API service indexing awesome lists of open source software.

https://github.com/developer0hye/torch-batchnorm-from-scratch


https://github.com/developer0hye/torch-batchnorm-from-scratch

Last synced: about 2 months ago
JSON representation

Awesome Lists containing this project

README

          

# Torch-BatchNorm-From-Scratch

```python

from typing import Optional, Any

import torch
from torch import Tensor
from torch import nn

class BatchNorm2d(nn.Module):
    """Batch normalization over the channel axis of a 4D ``(N, C, H, W)`` input.

    A from-scratch re-implementation that follows the conventions of
    ``torch.nn.BatchNorm2d``: per-channel statistics are computed over the
    batch and spatial axes, optionally tracked as running estimates, and an
    optional learnable per-channel scale/shift is applied afterwards.

    Args:
        num_features: Number of channels ``C`` of the expected input.
        eps: Value added to the variance for numerical stability.
        momentum: Weight of the newest batch in the running estimates. When
            ``None``, a cumulative (simple) moving average is used instead.
        affine: If ``True``, learn a per-channel ``weight`` and ``bias``.
        track_running_stats: If ``True``, keep ``running_mean``,
            ``running_var`` and ``num_batches_tracked`` buffers and use them
            in eval mode. If ``False``, batch statistics are always used.
        device: Device on which parameters and buffers are created.
        dtype: Dtype of the parameters and of the floating-point buffers.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = nn.Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            # The batch counter is always an integer tensor, so the requested
            # dtype is deliberately not forwarded to it.
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        """Reset the running estimates to mean 0, variance 1, zero batches seen."""
        if self.track_running_stats:
            # The buffers are only real tensors when track_running_stats is
            # on; otherwise they are registered as None.
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        """Reset running statistics and set the affine transform to identity."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def _update_running_stats(self, x: Tensor, mean: Tensor, var: Tensor) -> None:
        """Fold the per-channel batch ``mean``/``var`` into the running buffers."""
        with torch.no_grad():
            self.num_batches_tracked.add_(1)  # type: ignore[union-attr]
            if self.momentum is None:
                # Cumulative moving average: every batch seen so far gets
                # the same weight.
                factor = 1.0 / float(self.num_batches_tracked)
            else:
                factor = self.momentum

            # Normalization uses the biased variance, but the running
            # estimate stores the unbiased one (Bessel's correction).
            count = x.size(0) * x.size(2) * x.size(3)
            if count > 1:
                var = var * (count / (count - 1))

            # In-place updates keep the registered buffer objects alive, so
            # external references to them stay valid.
            self.running_mean.mul_(1 - factor).add_(mean, alpha=factor)  # type: ignore[union-attr]
            self.running_var.mul_(1 - factor).add_(var, alpha=factor)  # type: ignore[union-attr]

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` per channel and apply the optional affine transform.

        Batch statistics are used in training mode and whenever no running
        statistics are tracked; otherwise the running estimates are used.

        Args:
            x: Input tensor; reduced over dims 0, 2 and 3, with dim 1
                expected to have ``num_features`` entries.

        Returns:
            A tensor with the same shape as ``x``.
        """
        stat_shape = (1, self.num_features, 1, 1)
        use_batch_stats = (
            self.training
            or self.running_mean is None
            or self.running_var is None
        )

        if use_batch_stats:
            mean = x.mean(dim=(0, 2, 3))
            var = ((x - mean.view(stat_shape)) ** 2).mean(dim=(0, 2, 3))
            if self.training and self.track_running_stats:
                self._update_running_stats(x, mean, var)
        else:
            mean, var = self.running_mean, self.running_var

        x = (x - mean.view(stat_shape)) / torch.sqrt(var.view(stat_shape) + self.eps)
        if self.affine:
            x = self.weight.view(stat_shape) * x + self.bias.view(stat_shape)
        return x
```