# DeepQLearning

[![Build status](https://github.com/JuliaPOMDP/DeepQLearning.jl/workflows/CI/badge.svg)](https://github.com/JuliaPOMDP/DeepQLearning.jl/actions/workflows/CI.yml)
[![codecov](https://codecov.io/github/JuliaPOMDP/DeepQLearning.jl/branch/master/graph/badge.svg?token=EfDZPMisVB)](https://codecov.io/github/JuliaPOMDP/DeepQLearning.jl)

This package provides an implementation of the Deep Q-learning algorithm for solving MDPs. For more information see https://arxiv.org/pdf/1312.5602.pdf.
It uses POMDPs.jl and Flux.jl.

It supports the following innovations:
- Target network
- Prioritized replay https://arxiv.org/pdf/1511.05952.pdf
- Dueling https://arxiv.org/pdf/1511.06581.pdf
- Double Q http://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/download/12389/11847
- Recurrent Q Learning

## Installation

```julia
using Pkg
Pkg.add("DeepQLearning")
```

## Usage

```julia
using DeepQLearning
using POMDPs
using Flux
using POMDPModels
using POMDPTools

# load MDP model from POMDPModels or define your own!
mdp = SimpleGridWorld();

# Define the Q-network (see the Flux.jl documentation)
# The gridworld state is represented by a 2-dimensional vector.
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork = model, max_steps = 10000,
                             exploration_policy = exploration,
                             learning_rate = 0.005, log_freq = 500,
                             recurrence = false, double_q = true, dueling = true, prioritized_replay = true)
policy = solve(solver, mdp)

sim = RolloutSimulator(max_steps=30)
r_tot = simulate(sim, mdp, policy)
println("Total discounted reward for 1 simulation: $r_tot")
```

## Specifying exploration / evaluation policy

An exploration policy and evaluation policy can be specified in the solver parameters.

An **exploration policy** can be provided in the form of a function that must return an action. The function is called as `f(policy, env, obs, global_step, rng)`, where `policy` is the NN policy being trained, `env` is the environment, `obs` is the observation at which to take the action, `global_step` is the interaction step of the solver, and `rng` is a random number generator. By default, this package provides an epsilon-greedy policy whose epsilon decreases linearly with `global_step`.
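
For illustration, here is a minimal sketch of a custom exploration function matching that signature. The linear epsilon schedule and the use of `actions(env)` and `action(policy, obs)` are assumptions made for this example, not requirements of the package.

```julia
# Hypothetical custom exploration function (sketch only).
# Signature: f(policy, env, obs, global_step, rng) -> action
function my_exploration(policy, env, obs, global_step, rng)
    eps = max(0.01, 1.0 - global_step / 10_000)   # assumed linear decay schedule
    if rand(rng) < eps
        return rand(rng, actions(env))            # random action with probability eps (assumes `actions(env)` is defined)
    else
        return action(policy, obs)                # greedy action from the current network
    end
end
```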

An **evaluation policy** can be provided in a similar manner. The function is called as `f(policy, env, n_eval, max_episode_length, verbose)`, where `policy` is the NN policy being trained, `env` is the environment, `n_eval` is the number of evaluation episodes, `max_episode_length` is the maximum number of steps in one episode, and `verbose` is a boolean that enables printing. The evaluation function must return three elements (a sketch follows the list below):
- Average total reward (Float), the average score per episode
- Average number of steps (Float), the average number of steps taken per episode
- Info, a dictionary mapping `String` to `Float` that can be used to log custom scalar values.
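
The sketch below shows what such a function could look like. The environment interface used here (`reset!`, `observe`, `act!`, `terminated`) follows CommonRLInterface.jl conventions and is an assumption about how `env` behaves; adapt it to your setup.

```julia
using CommonRLInterface   # assumed environment interface (reset!, observe, act!, terminated)

# Hypothetical evaluation function (sketch only), returning
# (average total reward, average number of steps, info dictionary).
function my_evaluation(policy, env, n_eval, max_episode_length, verbose)
    total_r, total_steps = 0.0, 0.0
    for _ in 1:n_eval
        reset!(env)
        resetstate!(policy)                       # clear hidden state if the policy is recurrent
        obs = observe(env)
        steps = 0
        while !terminated(env) && steps < max_episode_length
            a = action(policy, obs)               # greedy action from the trained network
            total_r += act!(env, a)               # assumed to return the reward
            obs = observe(env)
            steps += 1
        end
        total_steps += steps
    end
    avg_r = total_r / n_eval
    avg_steps = total_steps / n_eval
    verbose && println("Evaluation: average reward = $avg_r over $n_eval episodes")
    return avg_r, avg_steps, Dict{String, Float64}("eval_reward" => avg_r)
end
```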

## Q-Network

The `qnetwork` option of the solver accepts any `Chain` object. It is expected to be a multi-layer perceptron or a stack of convolutional layers followed by dense layers. If the network ends with dense layers, the `dueling` option will split all the dense layers at the end of the network.

If the observation is a multi-dimensional array (e.g. an image), one can use the `flattenbatch` function to flatten all the dimensions except the batch dimension. This is useful for connecting convolutional layers to dense layers, for example.

The input size of the network is problem dependent and must be specified when you create the Q-network.
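
For example, a Q-network for image observations could be built as in the following sketch; the 84x84x4 input shape, the layer sizes, and the number of actions are assumptions used only to make the example concrete.

```julia
using Flux
using DeepQLearning   # provides flattenbatch

n_actions = 4         # e.g. length(actions(mdp)) for your problem

# Convolutional feature extractor followed by dense layers.
# With an 84x84x4 input, the second convolution outputs a 9x9x32 feature map,
# so the first dense layer takes 9*9*32 = 2592 inputs; adjust this for your input size.
cnn_qnetwork = Chain(
    Conv((8, 8), 4 => 16, relu; stride = 4),
    Conv((4, 4), 16 => 32, relu; stride = 2),
    flattenbatch,                       # flatten everything except the batch dimension
    Dense(2592, 256, relu),
    Dense(256, n_actions)
)
```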

This package exports the type `AbstractNNPolicy`, which represents neural-network-based policies. In addition to the functions from `POMDPs.jl`, `AbstractNNPolicy` objects support the following:
- `getnetwork(policy)`: returns the value network of the policy
- `resetstate!(policy)`: resets the hidden states of the policy (does nothing if the network is not recurrent)
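
For instance, assuming `policy` and `mdp` are the ones returned in the usage example above, a minimal illustration looks like this:

```julia
net = getnetwork(policy)         # the underlying Flux Chain producing the Q-values
resetstate!(policy)              # clear hidden state before a new episode (no-op for a non-recurrent network)
s = rand(initialstate(mdp))      # sample a state from the initial state distribution
a = action(policy, s)            # standard POMDPs.jl action query
```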

## Saving/Reloading model

See [Flux.jl documentation](http://fluxml.ai/Flux.jl/stable/saving.html) for saving and loading models. The DeepQLearning solver saves the weights of the Q-network as a `bson` file in `solver.logdir/"qnetwork.bson"`.
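
As a sketch, the weights of a trained policy could also be saved and reloaded manually with BSON.jl; the file name and the variable name `qnet` are arbitrary choices for this example.

```julia
using BSON
using DeepQLearning   # for getnetwork

qnet = getnetwork(policy)                 # extract the Flux Chain from the trained policy
BSON.@save "my_qnetwork.bson" qnet        # write the network and its weights to disk

# later, e.g. in a new session
BSON.@load "my_qnetwork.bson" qnet        # brings `qnet` back into scope
```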

## Logging

Logging is done through [TensorBoardLogger.jl](https://github.com/PhilipVinc/TensorBoardLogger.jl). A log directory can be specified in the solver options; to disable logging, set the `logdir` option to `nothing`.
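
For example (assuming `model` and `exploration` from the usage section; the directory path is an arbitrary choice):

```julia
# log training metrics to a custom directory...
solver = DeepQLearningSolver(qnetwork = model, max_steps = 10_000,
                             exploration_policy = exploration,
                             logdir = "log/gridworld_run/")

# ...or disable logging entirely
solver_nolog = DeepQLearningSolver(qnetwork = model, max_steps = 10_000,
                                   exploration_policy = exploration,
                                   logdir = nothing)
```

The resulting event files can then be inspected with TensorBoard, e.g. `tensorboard --logdir log/`.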

## GPU Support

`DeepQLearning.jl` should support running the computations on a GPU through the [CuArrays.jl](https://github.com/JuliaGPU/CuArrays.jl) package.
You must check out the `gpu-support` branch. Note that it has not been tested thoroughly.
To run the solver on a GPU, first load `CuArrays` and then proceed as usual.

```julia
using CuArrays
using DeepQLearning
using POMDPs
using Flux
using POMDPModels

mdp = SimpleGridWorld();

# the model weights will be sent to the GPU in the call to solve
model = Chain(Dense(2, 32), Dense(32, length(actions(mdp))))

exploration = EpsGreedyPolicy(mdp, LinearDecaySchedule(start=1.0, stop=0.01, steps=10000/2));

solver = DeepQLearningSolver(qnetwork = model, max_steps = 10000,
                             exploration_policy = exploration,
                             learning_rate = 0.005, log_freq = 500,
                             recurrence = false, double_q = true, dueling = true, prioritized_replay = true)
policy = solve(solver, mdp)
```

## Solver Options

**Fields of the Q Learning solver:**
- `qnetwork::Any = nothing` Specify the architecture of the Q network
- `exploration_policy::