Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/aicenter/GenerativeModels.jl
Generative Models with trainable conditional distributions in Julia!
- Host: GitHub
- URL: https://github.com/aicenter/GenerativeModels.jl
- Owner: aicenter
- License: mit
- Created: 2019-08-22T15:52:23.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2022-09-21T03:13:32.000Z (about 2 years ago)
- Last Synced: 2024-08-10T03:40:48.756Z (4 months ago)
- Language: Julia
- Size: 411 KB
- Stars: 31
- Watchers: 7
- Forks: 3
- Open Issues: 9
Metadata Files:
- Readme: README.jmd
- License: LICENSE
Awesome Lists containing this project
- awesome-generative-ai-meets-julia-language - GenerativeModels.jl - Useful library to train more traditional generative models like VAEs. It's built on top of Flux.jl. (Noteworthy Mentions / Generative AI - Previous Generation)
README
---
weave_options:
  doctype: github
---

![Run tests](https://github.com/aicenter/GenerativeModels.jl/workflows/Run%20tests/badge.svg)
[![codecov](https://codecov.io/gh/aicenter/GenerativeModels.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/aicenter/GenerativeModels.jl)

# GenerativeModels.jl
This library contains a collection of generative models. It uses trainable
[`ConditionalDists.jl`](https://github.com/aicenter/ConditionalDists.jl) distributions that
can be used in conjunction with [`Flux.jl`](https://github.com/FluxML/Flux.jl)
models. Probability measures such as the KL divergence are defined in
[`IPMeasures.jl`](https://github.com/aicenter/IPMeasures.jl). This package aims
to make experimenting with new models as easy as possible.

As an example, check out below how to build a conventional variational
autoencoder (VAE) that reconstructs MNIST.
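For intuition: the KL term that such models penalize has a closed form when the
approximate posterior is a diagonal Gaussian and the prior is standard normal.
A minimal plain-Julia sketch (the helper `kl_to_stdnormal` is ours for
illustration, not part of any of these packages):

```julia
# KL(N(μ, diag(σ²)) ‖ N(0, I)) for a diagonal Gaussian, summed over dimensions
kl_to_stdnormal(μ, σ²) = 0.5f0 * sum(σ² .+ μ .^ 2 .- 1 .- log.(σ²))

kl_to_stdnormal(zeros(Float32, 2), ones(Float32, 2))  # ≈ 0 when q equals the prior
```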
# Reconstructing MNIST
First we load the MNIST training dataset
```julia results="hidden"
using MLDatasets, Flux

# load the MNIST digits and flatten each 28x28 image into a 784-element column
train_x, _ = MNIST.traindata(Float32)
flat_x = reshape(train_x, :, size(train_x,3)) |> gpu
data = Flux.Data.DataLoader(flat_x, batchsize=200, shuffle=true);
```
and define some parameters for a VAE with an input length `xlength` and a
latent vector of length `zlength`.
```julia results="hidden"
using ConditionalDists

xlength = size(flat_x, 1)
zlength = 2
hdim = 512
hd2 = Int(hdim/2)
```
We define an `encoder` with diagonal variance on the latent dimension,
which is just a Flux model wrapped in a `ConditionalMvNormal`. The Flux model
must return a tuple with the appropriate number of parameters; in the case of
an `MvNormal`, two: mean and variance. Hence, the `SplitLayer` returns two
vectors of length `zlength`, one of which (the variance) is constrained to be
positive.
```julia results="hidden"
using ConditionalDists: SplitLayer

# mapping that will be trained to output mean and variance
enc_map = Chain(Dense(xlength, hdim, relu),
                Dense(hdim, hd2, relu),
                SplitLayer(hd2, [zlength,zlength], [identity,softplus]))
# conditional encoder (can be called e.g. like `rand(encoder,x)`, see ConditionalDists.jl)
encoder = ConditionalMvNormal(enc_map)
```
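As a quick sanity check of the (still untrained) encoder, we can push a small
batch through it. This sketch assumes the `rand`/`mean` interface of
ConditionalDists mentioned in the comment above, with columns treated as
samples:

```julia
xs = flat_x[:, 1:10]    # a batch of ten flattened digits
zs = rand(encoder, xs)  # sampled latent codes; should have size (zlength, 10)
μs = mean(encoder, xs)  # posterior means of the same size
```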
The decoder will return a Multivariate Normal with scalar variance:
```julia results="hidden"
dec_map = Chain(Dense(zlength, hd2, relu),
                Dense(hd2, hdim, relu),
                SplitLayer(hdim, [xlength,1], σ))
decoder = ConditionalMvNormal(dec_map)
```
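Chaining the two gives a (for now untrained) deterministic reconstruction; it
is the same `mean`-through-`mean` pipeline used for the reconstruction plots
further below:

```julia
# decoder mean of the encoder mean, i.e. a deterministic reconstruction
x̂ = mean(decoder, mean(encoder, flat_x[:, 1:10]))
```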
Now we can create the VAE model and train it to maximize the ELBO.
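Recall that for a single sample `x` the ELBO is
`E_{q(z|x)}[log p(x|z)] - KL(q(z|x) ‖ p(z))`, so minimizing `-elbo(model, x)`
fits the decoder likelihood while keeping the approximate posterior close to
the prior.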
```julia results="markdown"; cache=true
using GenerativeModels

model = VAE(zlength, encoder, decoder) |> gpu
loss(x) = -elbo(model,x)

ps = Flux.params(model)
opt = ADAM()

for e in 1:50
    @info "Epoch $e" loss(flat_x)
    Flux.train!(loss, ps, data, opt)
end
```
```julia; echo=false; eval=true; results="hidden"
using Plots
using Distributions
pyplot()
function plot_reconstructions(model, test_x)
    plts = []
    for i in 1:size(test_x,3)
        img = test_x[:,:,i]
        s1 = heatmap(img'[end:-1:1,:], title="truth")
        push!(plts, s1)

        x = reshape(img,:)
        # deterministic reconstruction: decoder mean of the encoder mean
        rec(x) = mean(model.decoder, mean(model.encoder, x))
        x̂ = rec(x)
        rec_img = reshape(x̂, size(img))'[end:-1:1,:]
        s2 = heatmap(rec_img, title="rec")
        push!(plts, s2)
    end
    p1 = plot(plts..., size=(800,600))
end
```
Some test reconstructions and the corresponding latent space are shown below:
```julia
model = model |> cpu
test_x, test_y = MNIST.testdata(Float32)
p1 = plot_reconstructions(model, test_x[:,:,1:6])
```
```julia echo=false
z = mean(model.encoder, reshape(test_x,:,size(test_x,3)))
p2 = plot(size=(500,500), title="Latent space")
for nr in 0:9
    mask = test_y .== nr
    scatter!(p2, z[1,mask], z[2,mask], c=test_y[mask], size=(500,500), label=nr)
end
display(p2)
```