https://github.com/chanlumerico/hierarchical-vae

Generalization of Hierarchical VAE and its Implementation

# Hierarchical VAE

[🔗Velog - [DL] Generalization and Implementation of Hierarchical VAE](https://velog.io/@lumerico284/DL-%EA%B3%84%EC%B8%B5%ED%98%95-VAE%EC%9D%98-%EC%9D%BC%EB%B0%98%ED%99%94%EC%99%80-%EA%B5%AC%ED%98%84)

## Implementation

Implements a **2-layer VAE** in PyTorch (the class below generalizes to `num_layers` levels).
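For orientation, the objective that `get_loss` computes in the two-layer case can be summarized as follows. This is a reading of the code, not the derivation from the linked post: the notation $\theta_i$ for decoder parameters and $\hat{z}_1(z_2)$ for the output of `decoders[1]` is introduced here, $z_1, z_2$ are single reparameterized samples, and the MSE option matches a Gaussian likelihood only up to scale and an additive constant.

```math
\begin{aligned}
q_\phi(z_1, z_2 \mid x) &= q_{\phi_1}(z_1 \mid x)\, q_{\phi_2}(z_2 \mid z_1), \qquad q_{\phi_i} = \mathcal{N}\big(\mu_i, \operatorname{diag}(\sigma_i^2)\big) \\
p_\theta(x, z_1, z_2) &= p_{\theta_1}(x \mid z_1)\, \mathcal{N}\big(z_1; \hat{z}_1(z_2), I\big)\, \mathcal{N}(z_2; 0, I) \\
\mathcal{L}(x) &= \underbrace{-\log p_{\theta_1}(x \mid z_1)}_{\text{BCE or MSE}}
 + D_{\mathrm{KL}}\big(q_{\phi_1}(z_1 \mid x) \,\|\, \mathcal{N}(\hat{z}_1, I)\big)
 + D_{\mathrm{KL}}\big(q_{\phi_2}(z_2 \mid z_1) \,\|\, \mathcal{N}(0, I)\big) \\
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(m, I)\big) &= \tfrac{1}{2} \sum_j \big(\sigma_j^2 + (\mu_j - m_j)^2 - 1 - \log \sigma_j^2\big)
\end{aligned}
```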

### 🔷 Encoder Module

```py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Maps an input to the mean and standard deviation of a diagonal Gaussian."""

    def __init__(self, in_dim: int, hidden_dim: int, latent_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_dim, hidden_dim)
        self.linear_mu = nn.Linear(hidden_dim, latent_dim)
        self.linear_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h = F.relu(self.linear(x))
        mu = self.linear_mu(h)
        logvar = self.linear_logvar(h)

        # Predicting log-variance keeps sigma = exp(0.5 * logvar) strictly positive
        sigma = torch.exp(0.5 * logvar)
        return mu, sigma
```

### 🔶 Decoder Module

```py
class Decoder(nn.Module):
    """Two-layer MLP mapping a latent vector to a reconstruction (or to a lower-level latent)."""

    def __init__(
        self,
        latent_dim: int,
        hidden_dim: int,
        out_dim: int,
        use_sigmoid: bool = False,
    ) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(latent_dim, hidden_dim)
        self.linear_2 = nn.Linear(hidden_dim, out_dim)
        self.use_sigmoid = use_sigmoid

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = F.relu(self.linear_1(z))
        h = self.linear_2(h)
        # Sigmoid squashes outputs to [0, 1] (needed for pixel outputs with BCE)
        return torch.sigmoid(h) if self.use_sigmoid else h
```

### 🎲 Reparameterization Trick Function

```py
def reparameterize(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # z = mu + eps * sigma with eps ~ N(0, I): the sample stays differentiable in mu and sigma
    eps = torch.randn_like(sigma)
    return mu + eps * sigma
```
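A quick way to see why the trick matters: because the randomness lives in `eps`, gradients reach `mu` and `sigma`, and repeated draws still follow $\mathcal{N}(\mu, \sigma^2)$. The sketch below repeats the function so it runs on its own; the parameter values are arbitrary illustrations, not taken from the repository.

```python
import torch


def reparameterize(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # Same trick as above: z = mu + eps * sigma with eps ~ N(0, I)
    eps = torch.randn_like(sigma)
    return mu + eps * sigma


torch.manual_seed(0)

# Illustrative parameters for a 20-dimensional latent (values are arbitrary)
mu = torch.full((20,), 1.5, requires_grad=True)
sigma = torch.full((20,), 0.5, requires_grad=True)

# 1) Gradients flow back to mu and sigma because the sampling noise is external
z = reparameterize(mu, sigma)
z.sum().backward()
print(mu.grad[:3])     # d(sum z)/d(mu) is 1 for every entry
print(sigma.grad[:3])  # d(sum z)/d(sigma) equals the sampled eps

# 2) Many draws reproduce the target distribution N(mu, sigma^2)
with torch.no_grad():
    samples = reparameterize(mu.expand(100_000, 20), sigma.expand(100_000, 20))
print(samples.mean().item(), samples.std().item())  # close to 1.5 and 0.5
```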

### ⭐ Hierarchical VAE Class

```py
class HierarchicalVAE(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        latent_dim: int,
        num_layers: int = 2,
        use_bce: bool = True,
    ) -> None:
        super().__init__()
        assert num_layers >= 1, "Number of layers must be >= 1"
        self.num_layers = num_layers
        self.use_bce = use_bce

        # Build encoders and decoders dynamically
        # encoders[0]: x -> z1, encoders[i]: z_i -> z_{i+1}
        dims = [input_dim] + [latent_dim] * (num_layers - 1)
        self.encoders = nn.ModuleList(
            [Encoder(dims[i], hidden_dim, latent_dim) for i in range(num_layers)]
        )
        self.decoders = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                # decoder for z1 -> x (sigmoid keeps outputs in [0, 1] for BCE)
                self.decoders.append(
                    Decoder(latent_dim, hidden_dim, input_dim, use_sigmoid=True)
                )
            else:
                # decoder for z_{i+1} -> z_i
                self.decoders.append(Decoder(latent_dim, hidden_dim, latent_dim))

    def get_loss(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)

        # Encoding pass: x -> z1 -> z2 -> ... (one reparameterized sample per level)
        mus, sigmas, zs = [], [], []
        h = x
        for enc in self.encoders:
            mu, sigma = enc(h)
            z = reparameterize(mu, sigma)
            mus.append(mu)
            sigmas.append(sigma)
            zs.append(z)
            h = z

        # Decoding pass: z1 -> x_hat, and z_{i+1} -> z_hat_i for the intermediate priors
        x_hat = self.decoders[0](zs[0])
        z_hats = [None] * (self.num_layers - 1)
        for level in range(self.num_layers, 1, -1):
            idx = level - 1
            z_hats[idx - 1] = self.decoders[idx](zs[idx])

        # Reconstruction loss
        if self.use_bce:
            L_recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
        else:
            L_recon = F.mse_loss(x_hat, x, reduction="sum")

        # KL divergence (closed form for diagonal Gaussians, including the factor 1/2)
        # Top-level prior N(0, I)
        mu_T, sigma_T = mus[-1], sigmas[-1]
        L_kl = -0.5 * torch.sum(
            1 + torch.log(sigma_T.pow(2)) - mu_T.pow(2) - sigma_T.pow(2)
        )

        # Intermediate levels prior N(z_hat, I)
        for i in range(self.num_layers - 1):
            mu_i, sigma_i = mus[i], sigmas[i]
            z_hat_i = z_hats[i]
            L_kl += -0.5 * torch.sum(
                1 + torch.log(sigma_i.pow(2)) - (mu_i - z_hat_i).pow(2) - sigma_i.pow(2)
            )

        # Negative ELBO averaged over the batch
        return (L_recon + L_kl) / batch_size
```
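The closed-form KL between a diagonal Gaussian and a unit-variance Gaussian carries a factor of 1/2. The sketch below checks that formula against `torch.distributions.kl_divergence` for both prior types used in `get_loss` (mean `z_hat` and mean zero). `gaussian_kl` is a hypothetical helper, and the tensors are random stand-ins shaped like a batch of 64 with `latent_dim = 20`.

```python
import torch
from torch.distributions import Normal, kl_divergence


def gaussian_kl(mu: torch.Tensor, sigma: torch.Tensor, prior_mean: torch.Tensor) -> torch.Tensor:
    """Closed-form KL( N(mu, diag(sigma^2)) || N(prior_mean, I) ), summed over all entries."""
    var = sigma.pow(2)
    return 0.5 * torch.sum(var + (mu - prior_mean).pow(2) - 1.0 - torch.log(var))


torch.manual_seed(0)
mu = torch.randn(64, 20)           # batch of posterior means
sigma = torch.rand(64, 20) + 0.1   # strictly positive standard deviations
z_hat = torch.randn(64, 20)        # stand-in for a decoder prediction z_hat

# Intermediate level: prior N(z_hat, I); top level: prior N(0, I)
kl_mid = gaussian_kl(mu, sigma, z_hat)
kl_top = gaussian_kl(mu, sigma, torch.zeros_like(mu))

# Reference values from torch.distributions (computed elementwise, then summed)
ref_mid = kl_divergence(Normal(mu, sigma), Normal(z_hat, torch.ones_like(z_hat))).sum()
ref_top = kl_divergence(Normal(mu, sigma), Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()

print(torch.allclose(kl_mid, ref_mid, rtol=1e-4), torch.allclose(kl_top, ref_top, rtol=1e-4))  # True True
```

Dropping the 1/2 doubles the KL term relative to the reconstruction term, which acts like a KL weight of 2.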

> 💡 The reconstruction loss comes in two variants, BCE and MSE (BCE is used when `use_bce=True`).

## Training & Evaluation

### 🌐 Training Environment

```py
input_dim = 784
hidden_dim = 100
latent_dim = 20
num_layers = 2

epochs = 30
learning_rate = 1e-3
batch_size = 64

optimizer = optim.Adam(...)
```
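A minimal sketch of how these settings could drive training, assuming a loader that yields `(images, labels)` batches of MNIST digits scaled to $[0, 1]$ (for example via `torchvision`'s `ToTensor()`) and a model exposing `get_loss(x)` as in the class above. `train_one_epoch` and `_ToyModel` are hypothetical names; the toy model and fake batches exist only so the loop runs without downloading data.

```python
import math

import torch
from torch import nn, optim


def train_one_epoch(model: nn.Module, loader, optimizer: optim.Optimizer) -> float:
    """One pass over `loader`; returns the mean of the per-batch losses.

    Assumes `loader` yields (images, labels) with pixel values in [0, 1] and that
    `model.get_loss(x)` returns a scalar loss already averaged over the batch.
    """
    model.train()
    total, n_batches = 0.0, 0
    for images, _ in loader:
        x = images.flatten(1)  # flatten 1x28x28 images to 784-dim vectors
        loss = model.get_loss(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    return total / max(n_batches, 1)


class _ToyModel(nn.Module):
    """Hypothetical stand-in exposing get_loss(); swap in HierarchicalVAE for real training."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Linear(784, 784)

    def get_loss(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.mse_loss(torch.sigmoid(self.net(x)), x)


torch.manual_seed(0)
fake_loader = [(torch.rand(64, 1, 28, 28), torch.zeros(64)) for _ in range(4)]  # batch_size = 64
model = _ToyModel()
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # learning_rate = 1e-3
for epoch in range(2):  # the README trains for epochs = 30
    mean_loss = train_one_epoch(model, fake_loader, optimizer)
    print(epoch, round(mean_loss, 4), math.isfinite(mean_loss))
```

For the hierarchical model, `_ToyModel()` would be replaced by `HierarchicalVAE(input_dim, hidden_dim, latent_dim, num_layers)`.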

> Total number of parameters: *174,084*
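The parameter count follows directly from the layer shapes with the hyperparameters above. The plain-Python sketch below reproduces it; the helper names are illustrative. With the model instantiated, `sum(p.numel() for p in model.parameters())` should give the same figure.

```python
def linear_params(in_features: int, out_features: int) -> int:
    # nn.Linear stores a weight of shape (out, in) plus a bias of shape (out,)
    return in_features * out_features + out_features


input_dim, hidden_dim, latent_dim = 784, 100, 20


def encoder_params(in_dim: int) -> int:
    # Encoder: linear, linear_mu, linear_logvar
    return linear_params(in_dim, hidden_dim) + 2 * linear_params(hidden_dim, latent_dim)


def decoder_params(out_dim: int) -> int:
    # Decoder: linear_1, linear_2
    return linear_params(latent_dim, hidden_dim) + linear_params(hidden_dim, out_dim)


total = (
    encoder_params(input_dim)     # q(z1 | x):   82,540
    + encoder_params(latent_dim)  # q(z2 | z1):   6,140
    + decoder_params(input_dim)   # z1 -> x_hat: 81,284
    + decoder_params(latent_dim)  # z2 -> z1_hat: 4,120
)
print(f"{total:,}")  # 174,084
```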

### 📃 Training Results

#### 1️⃣ Training Loss Curve



#### 2️⃣ Random Digit Image Generation

Drew 64 samples from the standard normal distribution $\mathcal{N}(\mathbf{z};\mathbf{0},\mathbf{I})$ and fed them into the trained hierarchical VAE to generate digit images.
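One way to generate from a two-layer model is ancestral sampling: draw the top latent $z_2 \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, map it through the second decoder to $\hat{z}_1$, optionally add unit-variance noise (the prior $\mathcal{N}(\hat{z}_1, \mathbf{I})$ used in the loss), then decode to pixels. The README does not say exactly which latent the 64 samples were fed into, so this is a sketch under that assumption; `dec_z2_to_z1` and `dec_z1_to_x` are hypothetical, untrained stand-ins with the same layer shapes as `decoders[1]` and `decoders[0]`.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_dim, latent_dim, input_dim = 100, 20, 784

# Hypothetical stand-ins shaped like decoders[1] (z2 -> z1) and decoders[0] (z1 -> x)
dec_z2_to_z1 = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, latent_dim)
)
dec_z1_to_x = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Sigmoid()
)


@torch.no_grad()
def sample_images(n: int, use_prior_noise: bool = False) -> torch.Tensor:
    z2 = torch.randn(n, latent_dim)     # top-level prior N(0, I)
    z1 = dec_z2_to_z1(z2)               # mean z1_hat of p(z1 | z2) = N(z1_hat, I)
    if use_prior_noise:
        z1 = z1 + torch.randn_like(z1)  # draw from N(z1_hat, I) instead of using the mean
    x = dec_z1_to_x(z1)                 # pixel intensities in [0, 1]
    return x.view(n, 1, 28, 28)         # reshape to MNIST image layout


images = sample_images(64)
print(images.shape)  # torch.Size([64, 1, 28, 28])
```

With a trained model, `model.decoders[1]` and `model.decoders[0]` would take the place of the stand-ins, and `torchvision.utils.make_grid(images, nrow=8)` can tile the 64 images into an 8×8 grid.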



#### 3️⃣ t-SNE Visualization of the Latent Space

Reduced the distribution of the last latent variable, $q_{\phi_2}(z_2|z_1)$, to two dimensions with t-SNE and visualized it.
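A possible way to produce such a plot: encode a batch of test images, take the mean $\mu_2$ of $q_{\phi_2}(z_2|z_1)$ for each image, and embed those vectors with scikit-learn's `TSNE`. The sketch assumes scikit-learn is installed and uses random 20-dimensional vectors as stand-ins for the encoded means; whether the author embedded means or samples is not stated.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)

# Stand-in for the top-level posterior means mu_2 of N encoded test images.
# With a trained model these could be collected by running the encoders:
#   mu_1, sigma_1 = model.encoders[0](x); z_1 = reparameterize(mu_1, sigma_1)
#   mu_2, _ = model.encoders[1](z_1)
n_points, latent_dim = 300, 20
mu_2 = rng.normal(size=(n_points, latent_dim)).astype(np.float32)
labels = rng.integers(0, 10, size=n_points)  # stand-in digit labels, used only to color the plot

# perplexity must be smaller than the number of points
tsne = TSNE(n_components=2, perplexity=30.0, init="pca", random_state=0)
embedding = tsne.fit_transform(mu_2)
print(embedding.shape)  # (300, 2)

# Plotting, e.g. with matplotlib:
#   plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap="tab10", s=5)
```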



#### 4️⃣ Projection of the Learned Manifold in the Latent Space

Swept the first and second dimensions of the latent space over the range $[-3,3]$ to project the manifold learned in the high-dimensional latent space onto two dimensions.
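A sketch of the sweep: build a grid over $[-3, 3]$ for the first two latent dimensions, hold the remaining dimensions at zero (an assumption; the README does not say what they were set to), decode every grid point, and tile the results. The stand-in decoder is hypothetical and shaped like `decoders[0]`, i.e. it treats the swept space as $z_1$; sweeping $z_2$ instead would add a pass through `decoders[1]`.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim, hidden_dim, input_dim = 20, 100, 784
n = 15  # grid points per axis

# Hypothetical stand-in with the shape of decoders[0] (z1 -> x)
decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Sigmoid()
)

values = torch.linspace(-3.0, 3.0, n)
grid_a, grid_b = torch.meshgrid(values, values, indexing="ij")

z = torch.zeros(n * n, latent_dim)  # remaining dimensions fixed at 0
z[:, 0] = grid_a.reshape(-1)        # first latent dimension
z[:, 1] = grid_b.reshape(-1)        # second latent dimension

with torch.no_grad():
    images = decoder(z).view(n, n, 28, 28)

# Tile into one (n*28) x (n*28) canvas: block (i, j) holds the image for (values[i], values[j])
canvas = images.permute(0, 2, 1, 3).reshape(n * 28, n * 28)
print(canvas.shape)  # torch.Size([420, 420])
```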