Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/kylesayrs/sketchgeneration
Probabilistic sequence generation of sketch drawings which builds on top of Google Brain's "A Neural Representation of Sketch Drawings"
- Host: GitHub
- URL: https://github.com/kylesayrs/sketchgeneration
- Owner: kylesayrs
- Created: 2024-03-22T16:57:02.000Z (10 months ago)
- Default Branch: master
- Last Pushed: 2024-06-11T02:28:00.000Z (7 months ago)
- Last Synced: 2024-06-11T03:59:40.392Z (7 months ago)
- Topics: gaussian-mixture-models, generative-ai, magenta, pytorch, sketch-rnn
- Language: Python
- Homepage:
- Size: 3.83 MB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Sketch Generation #
The work in this repo is based on Google Brain's "[A Neural Representation of Sketch Drawings](https://arxiv.org/pdf/1704.03477)", which models sketch generation as probabilistic sequence prediction. Notable differences include:
* Replacement of HyperLSTM backbone with multi-headed transformer layers
* Replacement of NLLLoss with focal loss to more quickly learn underrepresented pen states
* Replacement of exponential activation of sigma variables with ELU to promote GMM training stability
## Lessons Learned ##
1. Gaussian mixture models do not train stably
The first aspect that makes GMM sequence models difficult to train is the tendency of GMMs to collapse onto sample points and produce exploding gradients (see [GMMPyTorch](https://github.com/kylesayrs/GMMPytorch)). The original authors partially address this by imposing a `gradient_clip` parameter which limits the maximum magnitude of any one gradient step. While this technique is effective, I found that the model still produced extreme positive loss spikes, which made analysis difficult and rendered some batches useless or counterproductive. My solution was, in addition to gradient clipping, to replace the `exp` activation, which is subject to gradient explosion on both sides, with `ELU` plus a very small constant on the diagonal sigmas. This bounds the minimum standard deviation, limiting the collapsing effect and improving sample efficiency (a sketch of this activation follows the list).
2. Autoregressive training and inference are highly dependent on temperature
The second difficult aspect is the autoregressive nature of the sequence model. Like any autoregressive model, this model is prone to autoregressive drift if it produces predictions that fall outside the dataset distribution. To limit this effect, a `temperature` parameter is required at inference time to reduce the chance of sampling positions and pen states outside of the expected distribution (see the sampling sketch below).
3. Focal loss over NLLLoss
The original paper uses a softmax activation followed by `NLLLoss` to learn the pen state classes. While this technique works, I found that I was able to learn the minority class (pen_up) much faster using focal loss with a `gamma` value of 2 (see the sketch after this list).
4. Train on a toy dataset first
When training highly custom models like this one, it is much better to start with a toy dataset rather than the full dataset. Training on just a single sample was vital for catching bugs and served as a basis for judging expected model output beyond the loss alone.
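The following is a minimal sketch of the sigma activation from lesson 1, assuming the network emits an unconstrained tensor of diagonal sigmas; the function and parameter names (`sigma_activation`, `min_sigma`) are illustrative rather than this repo's actual API.

```python
import torch
import torch.nn.functional as F

def sigma_activation(raw_sigma: torch.Tensor, min_sigma: float = 1e-4) -> torch.Tensor:
    """Map unconstrained network outputs to positive standard deviations.

    ELU(x) + 1 stays positive and grows only linearly for large inputs
    (unlike exp), and the `min_sigma` floor keeps any mixture component
    from collapsing onto a single sample point.
    """
    return F.elu(raw_sigma) + 1.0 + min_sigma
```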
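Lesson 2's temperature can be applied both to the categorical choices (mixture component and pen state) and to the Gaussian variance, in the spirit of the original sketch-rnn sampling procedure. This sketch assumes a diagonal GMM with per-component `mu`/`sigma` tensors of shape `(num_components, 2)`; the names and shapes are assumptions, not this repo's exact interface.

```python
import torch

def sample_step(pi_logits, mu, sigma, pen_logits, temperature: float = 0.4):
    """Sample one (dx, dy, pen_state) step from the predicted distributions.

    Lower temperature concentrates the mixture choice, the offset noise, and
    the pen-state choice on high-probability outcomes, reducing drift.
    """
    pi = torch.softmax(pi_logits / temperature, dim=-1)          # (num_components,)
    pen_probs = torch.softmax(pen_logits / temperature, dim=-1)  # (num_pen_states,)

    k = torch.multinomial(pi, 1).item()                          # pick a component
    offset = torch.normal(mu[k], sigma[k] * temperature ** 0.5)  # scaled variance

    pen_state = torch.multinomial(pen_probs, 1).item()
    return offset, pen_state
```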
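And a minimal sketch of the pen-state loss from lesson 3, written from the standard focal loss formulation without a class-weighting term; the function name and tensor shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def pen_focal_loss(pen_logits: torch.Tensor, pen_targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Focal loss over the pen-state classes.

    Down-weights already well-classified (mostly pen_down) steps by
    (1 - p_t) ** gamma so the rare pen_up class contributes a larger
    share of the gradient than it would under plain NLLLoss.
    """
    log_probs = F.log_softmax(pen_logits, dim=-1)                       # (N, num_pen_states)
    log_pt = log_probs.gather(1, pen_targets.unsqueeze(1)).squeeze(1)   # (N,)
    pt = log_pt.exp()
    return (-((1.0 - pt) ** gamma) * log_pt).mean()
```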
## Future work ##
One notable flaw in the original architecture is that the pen state and pen position are predicted independently of one another. This can lead to situations where the model infers a split pen position distribution that attempts to match both the pen being raised and the pen continuing to be drawn. This could be resolved by either coupling a new pen distribution to each component of the Gaussian mixture or explicitly conditioning the GMM components on the pen prediction, as sketched below.
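As an illustration of the second option, here is a rough sketch of a prediction head where the GMM parameters are computed from both the hidden state and the pen-state distribution; the class and layer names are hypothetical, not code from this repo.

```python
import torch
import torch.nn as nn

class PenConditionedGMMHead(nn.Module):
    """Position GMM whose parameters are conditioned on the pen prediction."""

    def __init__(self, hidden_dim: int, num_components: int, num_pen_states: int = 3):
        super().__init__()
        self.pen_head = nn.Linear(hidden_dim, num_pen_states)
        # the GMM heads see both the hidden state and the pen distribution
        gmm_in = hidden_dim + num_pen_states
        self.pi_head = nn.Linear(gmm_in, num_components)
        self.mu_head = nn.Linear(gmm_in, num_components * 2)
        self.sigma_head = nn.Linear(gmm_in, num_components * 2)

    def forward(self, hidden: torch.Tensor):
        pen_logits = self.pen_head(hidden)
        conditioned = torch.cat([hidden, torch.softmax(pen_logits, dim=-1)], dim=-1)
        pi_logits = self.pi_head(conditioned)
        mu = self.mu_head(conditioned)
        sigma_raw = self.sigma_head(conditioned)   # pass through a positive activation downstream
        return pen_logits, pi_logits, mu, sigma_raw
```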