https://github.com/astupidbear/reversediffflux.jl
- Host: GitHub
- URL: https://github.com/astupidbear/reversediffflux.jl
- Owner: AStupidBear
- License: MIT
- Created: 2020-06-18T11:43:04.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2021-09-07T04:30:45.000Z (about 4 years ago)
- Last Synced: 2025-01-10T03:11:15.192Z (9 months ago)
- Language: Julia
- Size: 13.7 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
# ReverseDiffFlux
[Build Status](https://github.com/AStupidBear/ReverseDiffFlux.jl/actions)
[Coverage](https://codecov.io/gh/AStupidBear/ReverseDiffFlux.jl)

## Example
```julia
using Statistics
using ReverseDiffFlux
using Flux

# Synthetic data: 10 features, batch of 1, 100 time steps.
x = randn(Float32, 10, 1, 100)
y = mean(x, dims = 1)

# Wrap the model so gradients are taken with ReverseDiff.
model = Chain(LSTM(10, 100), LSTM(100, 1)) |> ReverseDiffFlux.track

# Mean squared error averaged over the time steps.
function loss(x, y)
    xs = Flux.unstack(x, 3)
    ys = Flux.unstack(y, 3)
    ŷs = model.(xs)
    l = 0f0
    for t in 1:length(ŷs)
        l += Flux.mse(ys[t], ŷs[t])
    end
    return l / length(ŷs)
end

ps = Flux.params(model)
data = repeat([(x, y)], 100)
opt = ADAMW(1e-3, (0.9, 0.999), 1e-4)
cb = () -> Flux.reset!(model)  # reset LSTM state between sequences
ReverseDiffFlux.overload_gradient()
Flux.@epochs 10 Flux.train!(loss, ps, data, opt, cb = cb)
```
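Once trained, the wrapped model is used like any other Flux model for inference. The sketch below is not part of the original README; it reuses the names and the Flux API version assumed by the example above, and only calls functions that already appear there:

```julia
using Flux
using ReverseDiffFlux

# Rebuild the model and input from the example above (illustrative only).
x = randn(Float32, 10, 1, 100)
model = Chain(LSTM(10, 100), LSTM(100, 1)) |> ReverseDiffFlux.track

# Clear the LSTM hidden state before feeding a fresh sequence,
# then run the model step by step over the time dimension.
Flux.reset!(model)
ŷs = model.(Flux.unstack(x, 3))  # one 1×1 prediction per time step
```

Because the LSTM layers are stateful, `Flux.reset!` must be called before each new sequence, just as the training callback does.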