https://github.com/chanlumerico/hierarchical-vae
Generalization of Hierarchical VAE and its Implementation
Last synced: 4 months ago
- Host: GitHub
- URL: https://github.com/chanlumerico/hierarchical-vae
- Owner: ChanLumerico
- License: apache-2.0
- Created: 2025-06-13T05:17:34.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2025-06-13T10:18:42.000Z (12 months ago)
- Last Synced: 2025-08-02T03:54:17.680Z (10 months ago)
- Language: Python
- Size: 1.19 MB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Hierarchical VAE
[Velog - [DL] Generalization of Hierarchical VAE and Its Implementation](https://velog.io/@lumerico284/DL-%EA%B3%84%EC%B8%B5%ED%98%95-VAE%EC%9D%98-%EC%9D%BC%EB%B0%98%ED%99%94%EC%99%80-%EA%B5%AC%ED%98%84)
## Implementation
A **2-layer VAE** implemented in PyTorch.
### Encoder module
```py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Maps an input to the mean and standard deviation of a diagonal Gaussian."""

    def __init__(self, in_dim: int, hidden_dim: int, latent_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, hidden_dim)
        self.linear_mu = nn.Linear(hidden_dim, latent_dim)
        self.linear_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h = F.relu(self.linear(x))
        mu = self.linear_mu(h)
        logvar = self.linear_logvar(h)
        # The network predicts log(sigma^2); convert it to sigma.
        sigma = torch.exp(0.5 * logvar)
        return mu, sigma
```
### Decoder module
```py
class Decoder(nn.Module):
    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        out_dim: int,
        use_sigmoid: bool = False,
    ) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(latent_dim, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, out_dim)
        self.use_sigmoid = use_sigmoid

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.linear_1(z))
        h = self.linear_2(h)
        return torch.sigmoid(h) if self.use_sigmoid else h
```
### Reparameterization trick function
```py
def reparameterize(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    eps = torch.randn_like(sigma)
    return mu + eps * sigma
```
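As a quick sanity check on the function above, the sketch below draws many samples and compares their empirical mean and standard deviation with `mu` and `sigma`. It also confirms that gradients reach both parameters, which is the point of the trick. The seed, sample count, and parameter values are illustrative choices, not taken from the repository.

```python
import torch


def reparameterize(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Same function as above, repeated so this snippet runs on its own.
    eps = torch.randn_like(sigma)
    return mu + eps * sigma


torch.manual_seed(0)  # illustrative seed

# Leaf tensors so gradients can be inspected; the values are arbitrary.
mu = torch.tensor([1.0, -2.0], requires_grad=True)
sigma = torch.tensor([0.5, 2.0], requires_grad=True)

# Draw 100,000 samples per dimension by broadcasting mu and sigma.
n_samples = 100_000
z = reparameterize(mu.expand(n_samples, 2), sigma.expand(n_samples, 2))

sample_mean = z.mean(dim=0)
sample_std = z.std(dim=0)

# z is a deterministic function of (mu, sigma, eps),
# so backpropagation reaches both parameters.
z.sum().backward()

print("empirical mean:", sample_mean.detach())
print("empirical std: ", sample_std.detach())
print("grad wrt mu:   ", mu.grad)
```

Sampling `z` directly from a distribution object would hide `mu` and `sigma` from autograd; writing `z` as `mu + eps * sigma` keeps the randomness in `eps` only.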
### ⭐ Hierarchical VAE class
```py
class HierarchicalVAE(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        latent_dim: int,
        num_layers: int = 2,
        use_bce: bool = True,
    ) -> None:
        super().__init__()
        assert num_layers >= 1, "Number of layers must be >= 1"
        self.num_layers = num_layers
        self.use_bce = use_bce

        # Build encoders and decoders dynamically
        dims = [input_dim] + [latent_dim] * (num_layers - 1)
        self.encoders = nn.ModuleList(
            [Encoder(dims[i], hidden_dim, latent_dim) for i in range(num_layers)]
        )

        self.decoders = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                # decoder for z1 -> x
                self.decoders.append(
                    Decoder(latent_dim, hidden_dim, input_dim, use_sigmoid=True)
                )
            else:
                # decoder for z_{i+1} -> z_i
                self.decoders.append(Decoder(latent_dim, hidden_dim, latent_dim))

    def get_loss(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)

        # Encoding pass: x -> z1 -> z2 -> ... (each level encodes the previous sample)
        mus, sigmas, zs = [], [], []
        h = x
        for enc in self.encoders:
            mu, sigma = enc(h)
            z = reparameterize(mu, sigma)
            mus.append(mu)
            sigmas.append(sigma)
            zs.append(z)
            h = z

        # Decoding pass: z_hats[i] is the prior mean for zs[i], decoded from zs[i + 1]
        x_hat = self.decoders[0](zs[0])
        z_hats = [None] * (self.num_layers - 1)
        for level in range(self.num_layers, 1, -1):
            idx = level - 1
            z_hats[idx - 1] = self.decoders[idx](zs[idx])

        # Reconstruction loss (BCE expects x in [0, 1])
        if self.use_bce:
            L_recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
        else:
            L_recon = F.mse_loss(x_hat, x, reduction="sum")

        # KL divergence
        # Note: the closed-form Gaussian KL carries a factor of 0.5 that is
        # omitted here, so the KL terms are effectively weighted by 2.
        # Top-level prior N(0,I)
        mu_T, sigma_T = mus[-1], sigmas[-1]
        L_kl = -torch.sum(1 + torch.log(sigma_T.pow(2)) - mu_T.pow(2) - sigma_T.pow(2))

        # Intermediate levels prior N(z_hat, I)
        for i in range(self.num_layers - 1):
            mu_i, sigma_i = mus[i], sigmas[i]
            z_hat_i = z_hats[i]
            L_kl += -torch.sum(
                1 + torch.log(sigma_i.pow(2)) - (mu_i - z_hat_i).pow(2) - sigma_i.pow(2)
            )

        return (L_recon + L_kl) / batch_size
```
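The KL terms in `get_loss` follow the closed form for a diagonal Gaussian posterior against a unit-variance Gaussian prior with mean $m$ (here $m = 0$ at the top level and $m = \hat{z}_i$ at intermediate levels):

$$\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2 I)\,\|\,\mathcal{N}(m, I)\big) = \tfrac{1}{2}\sum_j \big(\sigma_j^2 + (\mu_j - m_j)^2 - 1 - \log \sigma_j^2\big)$$

The sketch below compares the expression used in `get_loss` with the reference value from `torch.distributions`. The tensor shapes and random values are arbitrary and only serve the comparison.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)  # illustrative seed

# Arbitrary values: a posterior N(mu, sigma^2) and a prior N(m, I).
mu = torch.randn(4, 20)
sigma = torch.rand(4, 20) + 0.1  # keep sigma strictly positive
m = torch.randn(4, 20)

# Expression as written in get_loss (no 0.5 factor).
kl_as_written = -torch.sum(
    1 + torch.log(sigma.pow(2)) - (mu - m).pow(2) - sigma.pow(2)
)

# Reference KL from torch.distributions, summed over all elements.
prior = Normal(m, torch.ones_like(sigma))
kl_reference = kl_divergence(Normal(mu, sigma), prior).sum()

ratio = (kl_as_written / kl_reference).item()
print(f"get_loss expression / reference KL = {ratio:.4f}")
```

The ratio comes out at 2 up to floating-point error, so the loss as written weights the KL terms twice as heavily as the standard ELBO would.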
> 💡 The reconstruction loss comes in two variants, BCE and MSE; `use_bce=True` selects BCE.
## Training & Evaluation
### Training setup
```py
input_dim = 784
hidden_dim = 100
latent_dim = 20
num_layers = 2
epochs = 30
learning_rate = 1e-3
batch_size = 64
optimizer = optim.Adam(...)
```
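These dimensions fix the model size. As a cross-check, the sketch below tallies the weights and biases of every `nn.Linear` layer implied by the `Encoder`, `Decoder`, and `HierarchicalVAE` definitions. The helper function names are hypothetical and exist only for this illustration.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Parameters of a fully connected layer: weights plus biases."""
    return n_in * n_out + n_out


def encoder_params(in_dim: int, hidden_dim: int, latent_dim: int) -> int:
    # One hidden layer, then separate heads for mu and log-variance.
    return linear_params(in_dim, hidden_dim) + 2 * linear_params(hidden_dim, latent_dim)


def decoder_params(latent_dim: int, hidden_dim: int, out_dim: int) -> int:
    # Two stacked linear layers.
    return linear_params(latent_dim, hidden_dim) + linear_params(hidden_dim, out_dim)


input_dim, hidden_dim, latent_dim, num_layers = 784, 100, 20, 2

# Level 0 works on the input; every higher level works on a latent vector.
level_dims = [input_dim] + [latent_dim] * (num_layers - 1)

total = sum(encoder_params(d, hidden_dim, latent_dim) for d in level_dims)
total += sum(decoder_params(latent_dim, hidden_dim, d) for d in level_dims)

print(f"{total:,}")  # 174,084
```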
> Total number of parameters: *174,084*
### Training results
#### 1️⃣ Training loss curve
#### 2️⃣ Random digit image generation
64 samples are drawn from the standard normal distribution $\mathcal{N}(\mathbf{x};\mathbf{0},\mathbf{I})$ and fed into the trained hierarchical VAE.
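The README does not include the sampling code, so the following is only a sketch of one plausible procedure for the 2-layer case: draw $z_2$ from the standard normal prior, decode it to the prior mean of $z_1$, then decode that to an image. The stand-in decoders are untrained `nn.Sequential` stacks with the same layer shapes as `Decoder`. Using the mean of $z_1$ directly instead of sampling around it, and reshaping 784 values to 28×28, are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative seed

latent_dim, hidden_dim, input_dim = 20, 100, 784

# Untrained stand-ins with the same layer shapes as the README's Decoder;
# with a trained model these would be model.decoders[1] and model.decoders[0].
decoder_z2_to_z1 = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, latent_dim),
)
decoder_z1_to_x = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, input_dim),
    nn.Sigmoid(),
)

with torch.no_grad():
    z2 = torch.randn(64, latent_dim)  # 64 draws from N(0, I)
    z1 = decoder_z2_to_z1(z2)         # assumed: use the prior mean of z1 directly
    x = decoder_z1_to_x(z1)           # pixel intensities in [0, 1]

# Assumed image layout: 784 = 28 * 28, one channel.
images = x.view(64, 1, 28, 28)
print(images.shape)
```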
#### 3️⃣ Latent space t-SNE visualization
The distribution of the last latent variable, $q_{\phi_2}(z_2|z_1)$, is reduced to two dimensions with t-SNE and visualized.
#### 4️⃣ Projection of the learned manifold in latent space
The first and second latent dimensions are swept over the range $[-3, 3]$, projecting the manifold learned in the high-dimensional latent space onto two dimensions.
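A sweep of this kind could be sketched as follows. The grid resolution, holding the remaining latent dimensions at zero, and the untrained stand-in decoder are all assumptions, since the README does not show the plotting code or state which latent level is swept.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # illustrative seed

latent_dim, hidden_dim, input_dim = 20, 100, 784
grid_size = 15  # illustrative; the resolution is not stated in the README

# Untrained stand-in for a decoder that maps a latent vector to an image.
decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim),
    nn.ReLU(),
    nn.Linear(hidden_dim, input_dim),
    nn.Sigmoid(),
)

# Regular grid over [-3, 3] x [-3, 3] for the first two latent dimensions.
axis = torch.linspace(-3.0, 3.0, grid_size)
g0, g1 = torch.meshgrid(axis, axis, indexing="ij")

# Assumption: every other latent dimension is held at zero.
z = torch.zeros(grid_size * grid_size, latent_dim)
z[:, 0] = g0.reshape(-1)
z[:, 1] = g1.reshape(-1)

with torch.no_grad():
    tiles = decoder(z).view(grid_size, grid_size, 28, 28)

# Stitch the tiles into one (grid_size * 28) x (grid_size * 28) canvas,
# so tile (i, j) occupies rows i*28..i*28+27 and columns j*28..j*28+27.
canvas = tiles.permute(0, 2, 1, 3).reshape(grid_size * 28, grid_size * 28)
print(canvas.shape)
```

The resulting `canvas` can be passed to an image plotting call such as `matplotlib.pyplot.imshow` to view the grid as a single picture.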