Lukas Schwarz

Explanation and Implementation of DQN with Tensorflow and Keras


Deep Q-Learning (DQN) is a family of algorithms used in reinforcement learning to find an optimal policy. It is based on Q-Learning, which estimates the action-value function for the optimal policy by iterating the Bellman optimality equation. Knowing the action-value function, the optimal policy can then be obtained by choosing the action with the highest value. Instead of using a simple table to represent the value function as in traditional Q-Learning, DQN trains a neural network and uses it as a function approximator. The algorithm was introduced by Mnih et al. [1-2], who used additional techniques like experience replay and a separate target network to stabilize the training. Since then, further improvements have been added to the algorithm, for example Double DQN [3], prioritized experience replay [4] or dueling networks [5], to mention just a few. Some of these extensions are discussed and combined in [6]. In this text, I first explain the involved algorithms and then implement DQN with experience replay and a separate target network using TensorFlow, Keras and the Gym API for the environment.


Markov Decision Process (MDP)

To understand Q-Learning, we first have to define the reinforcement learning setting. The reinforcement learning problem is based on a Markov decision process, which generates a sequence $(S_0,A_0,R_1,S_1,A_1,\ldots,S_T)$ of states $S_t$, actions $A_t$ and rewards $R_t$ at timestep $t$. Here, we only consider finite sequences, called episodes, that end in a terminal state $S_T$. The state $S_t$ can be represented as a multidimensional vector of discrete or continuous numbers. Furthermore, we assume that there is only a finite number $N_a$ of discrete actions, i.e. $A_t \in \{1,\ldots,N_a\}$. An episode starts in an initial state $S_0$. An agent then selects an action according to a policy $\pi(s,a)$, which is the probability of choosing action $a$ in state $s$. The environment returns a reward $R_1$ and a new state $S_1$. This process of taking an action and then observing a reward and a new state continues until the terminal state $S_T$ is reached. This is schematically shown in the following figure.

The goal of reinforcement learning is to find a policy $\pi_*$ which maximizes the expected return, where the return $G_t$ is defined as the total discounted future reward from timestep $t$ $$ \begin{aligned} G_t = \sum_{n=0}^{T-t-1} \gamma^n R_{t+n+1} \end{aligned} $$ Here, the discount factor $\gamma \in [0,1]$ determines how strongly future rewards are discounted, i.e. for $\gamma < 1$, future rewards are valued less than immediate rewards. To determine the optimal policy, one can define an action-value function $$ \begin{aligned} q_\pi(s,a) = \mathbb E_\pi[G_t|S_t = s, A_t = a] \end{aligned} $$ which is the expected return when starting in state $s$, taking action $a$ and following policy $\pi$ afterwards. The optimal action-value function $q_*$ is the maximum over all policies $$ \begin{aligned} q_*(s,a) = \operatorname{max}_\pi q_\pi(s,a) \end{aligned} $$ If one knows the optimal action-value function $q_*$, the optimal policy $\pi_*$ is easily obtained by choosing the action $a$ which maximizes $q_*$ $$ \begin{aligned} \pi_*(s,a) = \begin{cases} 1 & a = \operatorname{argmax}_{a'} q_*(s,a')\\ 0 & \text{else} \end{cases} \end{aligned} $$ There are also methods that determine the optimal policy directly without finding a value function first, e.g. policy gradient methods. However, I will not consider these methods here.
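Both definitions can be checked on a small worked example. The helper names below are illustrative and not part of the original implementation; rewards holds $R_{t+1},\ldots,R_T$ of a finished episode, and q_values holds $q_*(s,a)$ for all actions of a single state.

```python
import numpy as np


def discounted_return(rewards, gamma):
    """G_t = sum_n gamma^n R_{t+n+1} for rewards [R_{t+1}, ..., R_T]."""
    return sum(gamma**n * r for n, r in enumerate(rewards))


def greedy_policy(q_values):
    """Deterministic policy pi_*(s, .): probability 1 for argmax_a q(s, a)."""
    probs = np.zeros_like(q_values, dtype=float)
    probs[np.argmax(q_values)] = 1.0
    return probs


print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1 + 0.5 + 0.25 = 1.75
print(greedy_policy(np.array([0.2, 1.3, 0.7])))       # [0. 1. 0.]
```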

Bellman Equation and Iteration

The question is, of course, how do we find $q_*$? To get to an answer, one can make use of a recursive relation for the value function, the Bellman equation. It can be derived by inserting the expression for the return $G_t$ into the expectation value and separating the immediate reward $R_{t+1}$ from the rest of the sum $$ \begin{aligned} q_\pi(s,a) &= \mathbb E_\pi [G_t | S_t = s, A_t = a]\\ &= \mathbb E_\pi \left[\sum_{n=0}^{T-t-1} \gamma^n R_{t+n+1} | S_t = s, A_t = a\right]\\ &= \mathbb E_\pi \left[R_{t+1} + \gamma\sum_{n=0}^{T-t-2} \gamma^n R_{t+n+2} | S_t = s, A_t = a\right]\\ &= \mathbb E_\pi \left[R_{t+1} + \gamma G_{t+1} | S_t = s, A_t = a\right]\\ &= \mathbb E_\pi \left[R_{t+1} + \gamma q_\pi(S_{t+1},A_{t+1}) | S_t = s, A_t = a\right]\\ \end{aligned} $$ The last step follows by conditioning the expectation of $G_{t+1}$ on $S_{t+1}$ and $A_{t+1}$, with the convention that the value of the terminal state is zero. Given a policy $\pi$ and experience $(S_t,A_t,R_{t+1},S_{t+1})$, one can turn the Bellman equation into a fixed-point iteration to obtain an estimate $Q_\pi(s,a)$ of the value function $$ \begin{aligned} Q_\pi(S_t,A_t) \leftarrow Q_\pi(S_t,A_t) + \alpha \Big( \mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] - Q_\pi(S_t,A_t)\Big) \end{aligned} $$ Here, $\alpha$ is a step-size parameter which controls the rate of learning. For $\alpha = 1$, the update simply replaces the old estimate by the right-hand side of the Bellman equation. If $Q_\pi = q_\pi$, the error $\Big(\mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] - Q_\pi(S_t,A_t)\Big)$ is zero according to the Bellman equation and the iteration has converged.
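The effect of the step size can be seen on a minimal example with a fixed policy and deterministic transitions, an assumption chosen so that the expectation reduces to a single sample. The two-state episode below is hypothetical; repeating the update drives the estimates to the solution of the Bellman equation.

```python
# Hypothetical two-step episode under a fixed policy:
# (S_0, A_0) -> R_1 = 1, S_1;  (S_1, A_1) -> R_2 = 2, terminal.
gamma, alpha = 0.9, 0.5
transitions = {
    "s0": (1.0, "s1"),   # state -> (reward, next state)
    "s1": (2.0, None),   # None marks the terminal state
}
Q = {"s0": 0.0, "s1": 0.0}  # one action per state, so Q is indexed by state

for sweep in range(50):
    for s, (r, s_next) in transitions.items():
        # The value of the terminal state is zero by definition
        target = r + (gamma * Q[s_next] if s_next is not None else 0.0)
        Q[s] += alpha * (target - Q[s])

# Bellman equation: q(s1) = 2 and q(s0) = 1 + 0.9 * 2 = 2.8
print(round(Q["s0"], 6), round(Q["s1"], 6))
```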

Two questions remain. First, how do we get from the action-value function $Q_\pi$ of some policy $\pi$ to the optimal action-value function $Q_*$ of the optimal policy $\pi_*$? Second, how do we evaluate the expectation value in the iteration?

To answer the first question, one can use the generalized policy iteration (GPI) scheme [7]. The idea is to start from a random policy and a random value function estimate. Then, the policy is evaluated via iteration of the Bellman equation to obtain a more accurate value function estimate for the current policy. Here, the iteration does not have to be run until convergence; a single improvement step might be enough. After that, the policy is improved by making it greedy with respect to the current value function estimate, i.e. by selecting the action with the highest value. This process is repeated many times. During this iteration, the value function is always changed to fit the current policy better, while the policy is always changed to be better according to the current value function. Thus, the iteration chases two moving targets, and its only fixed point is reached when both the value function and the policy are optimal. This process is schematically shown in the following figure.

Figure adapted from Sutton and Barto [7]
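The alternation of evaluation and improvement can be sketched for a tiny tabular problem. The MDP below is hypothetical, and for brevity the evaluation step uses the known deterministic transitions instead of sampled experience. Each round performs a single sweep of the Bellman equation for the current policy, followed by a greedy improvement step.

```python
import numpy as np

# Hypothetical deterministic MDP: states 0, 1, 2 plus terminal state 3.
# Action 0 moves left (bounded at 0), action 1 moves right; reward -1 per step.
N_STATES, TERMINAL, GAMMA = 3, 3, 0.9


def next_state(s, a):
    return max(0, s - 1) if a == 0 else s + 1


Q = np.zeros((N_STATES, 2))             # arbitrary initial value estimate
policy = np.zeros(N_STATES, dtype=int)  # arbitrary initial policy: always left

for iteration in range(20):
    # Policy evaluation: a single (truncated) sweep of the Bellman equation
    Q_new = np.empty_like(Q)
    for s in range(N_STATES):
        for a in range(2):
            s2 = next_state(s, a)
            bootstrap = 0.0 if s2 == TERMINAL else GAMMA * Q[s2, policy[s2]]
            Q_new[s, a] = -1.0 + bootstrap
    Q = Q_new
    # Policy improvement: act greedily with respect to the current estimate
    policy = np.argmax(Q, axis=1)

print(policy)  # [1 1 1]: always move right
```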

To ensure that the iteration converges, one has to ensure that all state-action pairs keep being visited. Therefore, an $\epsilon$-soft policy should be used, where the policy does not always act greedily but chooses a random action with probability $\epsilon$. Such exploratory actions are suboptimal with respect to the current estimate, but they are crucial for learning and reflect the exploration-exploitation tradeoff inherent in reinforcement learning.
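An $\epsilon$-greedy selection, the simplest $\epsilon$-soft policy, takes only a few lines. In this sketch, rng is a NumPy random generator passed in for reproducibility (an illustrative choice). The random branch can also pick the greedy action, so every action is chosen with probability at least $\epsilon/N_a$.

```python
import numpy as np


def epsilon_greedy(q_values, epsilon, rng):
    """Return a random action with probability epsilon, else the greedy one."""
    if rng.random() < epsilon:
        return int(rng.integers(len(q_values)))
    return int(np.argmax(q_values))


rng = np.random.default_rng(seed=0)
q = np.array([0.1, 0.5, 0.2])
actions = [epsilon_greedy(q, epsilon=0.1, rng=rng) for _ in range(10_000)]
# The greedy action 1 is chosen with probability 1 - epsilon + epsilon/3
print(actions.count(1) / len(actions))
```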

For the second question, the evaluation of the expectation value, there are several answers, which lead to different algorithms.

Dynamic Programming
If the model of the environment, in particular the transition probability $P(s,s',a)$ of moving from state $s$ to state $s'$ under action $a$, is known, then the expectation value can be computed exactly $$ \begin{aligned} \mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] = \sum_{s'} P(s,s',a) \left( r(s') + \gamma \sum_{a'} \pi(s',a') Q_\pi(s',a') \right) \end{aligned} $$ This method is called dynamic programming, and the optimal policy can be obtained via policy iteration or value iteration. However, perfect knowledge of the environment is rarely available, so these methods are restricted to small and special cases.
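For a known model, the expectation is a plain weighted sum. The NumPy sketch below evaluates the right-hand side above for one fixed pair $(s,a)$; all arrays and values are hypothetical, with P_sa[s2] standing for $P(s,s',a)$ and the rows of pi and Q indexed by the successor state $s'$.

```python
import numpy as np


def expected_target(P_sa, r, pi, Q, gamma):
    """Exact expectation of R_{t+1} + gamma * Q(S_{t+1}, A_{t+1}) for fixed (s, a).

    P_sa[s2]   : probability P(s, s2, a) of moving to s2
    r[s2]      : reward r(s2) received on arrival in s2
    pi[s2, a2] : policy probability of choosing a2 in s2
    Q[s2, a2]  : current action-value estimate
    """
    v_next = np.sum(pi * Q, axis=1)  # sum_{a'} pi(s', a') Q(s', a')
    return np.sum(P_sa * (r + gamma * v_next))


# Hypothetical numbers: two successor states, two actions
P_sa = np.array([0.8, 0.2])
r = np.array([1.0, 0.0])
pi = np.array([[0.5, 0.5], [1.0, 0.0]])
Q = np.array([[2.0, 4.0], [1.0, 3.0]])
# 0.8 * (1 + 0.5 * 3) + 0.2 * (0 + 0.5 * 1) = 2.1
print(expected_target(P_sa, r, pi, Q, gamma=0.5))
```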
Monte Carlo Methods
The simplest approach for model-free learning from experience alone is to use a sample return as an estimate of the expectation value, i.e. to run a complete episode and then perform an update with the total return $G_t$ $$ \begin{aligned} \mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] \approx G_t \end{aligned} $$ Such an approach is called a Monte Carlo method. Its disadvantage is that one has to wait until the end of an episode before an update can be performed.
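Since $G_t = R_{t+1} + \gamma G_{t+1}$, all Monte Carlo targets of a finished episode can be computed in a single backward pass. A minimal sketch, where rewards holds $R_1,\ldots,R_T$ and the function name is illustrative:

```python
def monte_carlo_returns(rewards, gamma):
    """Compute G_t for every step of a finished episode, rewards = [R_1, ..., R_T].

    Uses the recursion G_t = R_{t+1} + gamma * G_{t+1} with G_T = 0,
    so all returns are obtained in one backward pass.
    """
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


print(monte_carlo_returns([1.0, 2.0, 3.0], gamma=0.5))  # [2.75, 3.5, 3.0]
```

Each $G_t$ would then be used in the update $Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \alpha \left(G_t - Q(S_t,A_t)\right)$.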
Temporal Difference Learning
A second approach is to approximate the expectation value with a single sample at each time step. Such methods are called temporal-difference (TD) learning. They use the current estimate for the next state, $Q_\pi(S_{t+1},A_{t+1})$, which is called bootstrapping. Compared to Monte Carlo methods, the update can be made online, without waiting until the end of the episode. If one uses the actual quintuple $(S_t,A_t,R_{t+1},S_{t+1},A_{t+1})$, the algorithm is called SARSA $$ \begin{aligned} \mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] \approx R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1}) \end{aligned} $$ To approximate $q_*$ directly, one can instead use the maximum over all actions in the next state, regardless of the action actually taken. This algorithm is called Q-Learning $$ \begin{aligned} \mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] \approx R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a) \end{aligned} $$ A generalization of SARSA and Q-Learning is Expected SARSA, which averages over the next actions weighted by the policy. It is computationally more costly, as it requires a sum over all actions $$ \begin{aligned} \mathbb E_\pi[R_{t+1} + \gamma Q_\pi(S_{t+1},A_{t+1})] \approx R_{t+1} + \gamma \sum_a \pi(S_{t+1},a) Q_\pi(S_{t+1},a) \end{aligned} $$
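The three TD targets differ only in how they bootstrap from $S_{t+1}$. A minimal sketch with hypothetical numbers, where q_next holds $Q(S_{t+1},a)$ for all actions and pi_next the policy probabilities $\pi(S_{t+1},a)$:

```python
import numpy as np


def sarsa_target(r, q_next, a_next, gamma):
    # Uses the action A_{t+1} actually taken by the behavior policy
    return r + gamma * q_next[a_next]


def q_learning_target(r, q_next, gamma):
    # Bootstraps from the greedy action in S_{t+1}
    return r + gamma * np.max(q_next)


def expected_sarsa_target(r, q_next, pi_next, gamma):
    # Averages over all actions, weighted by the policy pi(S_{t+1}, .)
    return r + gamma * np.dot(pi_next, q_next)


q_next = np.array([1.0, 3.0])     # Q(S_{t+1}, a) for two actions
pi_next = np.array([0.25, 0.75])  # hypothetical policy in S_{t+1}
r, gamma = 1.0, 0.5
print(sarsa_target(r, q_next, 0, gamma))                 # 1 + 0.5*1 = 1.5
print(q_learning_target(r, q_next, gamma))               # 1 + 0.5*3 = 2.5
print(expected_sarsa_target(r, q_next, pi_next, gamma))  # 1 + 0.5*2.5 = 2.25
```

With a greedy policy, i.e. pi_next = [0, 1] in this example, the Expected SARSA target coincides with the Q-Learning target.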

Note that this list is not exhaustive; only the basic algorithms are presented. In the following, I concentrate on Q-Learning in more detail.

Tabular Q-Learning

Q-Learning is an off-policy temporal-difference control algorithm to estimate the optimal action-value function $q_*$. As discussed above, the update rule for the value function estimate $Q(S_t,A_t)$ at time step $t$ reads $$ \begin{aligned} Q(S_t,A_t) \leftarrow Q(S_t,A_t) + \alpha \left[ R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a) - Q(S_t,A_t) \right] \end{aligned} $$ Here, $Q(s,a)$ is a tabular representation of the action-value function for state $s$ and action $a$. The so-called TD target in the equation above $$ \begin{aligned} R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a) \end{aligned} $$ estimates the return by bootstrapping from the value of the next state $S_{t+1}$ under a greedy policy, hence the $\operatorname{max}_a$ operation over all actions. Therefore, the learned action-value function directly estimates the value of the optimal policy $\pi_*$, independent of the policy that generates the behavior, which is why Q-Learning is called off-policy. For the behavior, an $\epsilon$-soft policy should be used to ensure exploration. A basic implementation of the algorithm is shown below (see also )

import numpy as np
from collections import defaultdict

def qlearning(env, alpha, gamma, epsilon, num_episodes):
    # Initialize Q table to zero
    Q = defaultdict(lambda: np.zeros(env.action_space.n))

    # Run several episodes
    for i_episode in range(num_episodes):
        state = env.reset()
        done = False
        while not done:
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = np.argmax(Q[state])

            # Take action and obtain observation and reward
            state_new, reward, done, info = env.step(action)

            # Update of Q table; no bootstrapping from a terminal state
            target = reward
            if not done:
                target += gamma*np.max(Q[state_new])
            Q[state][action] += alpha*(target - Q[state][action])
            state = state_new
    return Q

The Q-table is implemented as a defaultdict, where the key is the state or observation as returned by the environment and the value is a numpy array with the size of the number of actions.
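The behavior of the defaultdict table can be illustrated in isolation. The sketch below uses hypothetical values: unseen states are added lazily with a zero row, keys must be hashable (so array-valued observations would first have to be converted, for example to a tuple), and the greedy policy and the state values $\max_a Q(s,a)$ can be read directly from the learned table.

```python
from collections import defaultdict

import numpy as np

n_actions = 2
Q = defaultdict(lambda: np.zeros(n_actions))

# Accessing an unseen state lazily creates a zero row for it
print(Q[3])  # [0. 0.]

# Keys must be hashable: integer states work directly, while an array-valued
# observation (hypothetical here) has to be converted, e.g. to a tuple
obs = np.array([1, 0])
Q[tuple(obs.tolist())][1] += 0.5

# Greedy policy and state values max_a Q(s, a) read from the table
policy = {s: int(np.argmax(q)) for s, q in Q.items()}
values = {s: float(np.max(q)) for s, q in Q.items()}
print(policy)  # {3: 0, (1, 0): 1}
```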

Q-Learning with Function Approximation

For large or continuous state spaces, representing the action-value function as a table is no longer feasible. Instead, the value function can be approximated by a parametrized function $Q(s,a,\vec w)$, whose parameters $\vec w$ are adjusted during learning such that the function represents the value function as well as possible. The parameters can be optimized with stochastic gradient descent methods, such as basic SGD or more advanced optimizers like Adam. For basic SGD, the usual update rule for the weights $\vec w$ reads $$ \begin{aligned} \vec w \leftarrow \vec w - \alpha \nabla_{\vec w}L \end{aligned} $$ where the loss function is given by the squared TD error $$ \begin{aligned} L &= \frac 1 2 \Big( R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a,\vec w) - Q(S_t,A_t,\vec w) \Big)^2 \end{aligned} $$ In principle, the gradient would read $$ \begin{aligned} \nabla_{\vec w}L &= \Big( R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a,\vec w) - Q(S_t,A_t,\vec w) \Big) \\ &\qquad \times \Big( \gamma \nabla_{\vec w}\operatorname{max}_a Q(S_{t+1},a,\vec w) - \nabla_{\vec w}Q(S_t,A_t,\vec w) \Big) \end{aligned} $$ where two gradient terms occur, since the TD target itself depends on the weights $\vec w$. However, to obtain a simple update rule, one usually neglects the gradient of the TD target, i.e. the term $\gamma \nabla_{\vec w}\operatorname{max}_a Q(S_{t+1},a,\vec w)$ in the expression above. The update rule then reads $$ \begin{aligned} \vec w \leftarrow \vec w + \alpha \left[ R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a,\vec w) - Q(S_t,A_t,\vec w) \right] \cdot \nabla_{\vec w} Q(S_t,A_t,\vec w) \end{aligned} $$ Because of this approximation, the method is called a semi-gradient method. Since the update does not follow the true gradient of $L$, semi-gradient methods do not enjoy the general convergence guarantees of true gradient descent, and in combination with off-policy learning and function approximation they can even diverge [7]. On the other hand, TD methods often learn faster than Monte Carlo methods, which would not need this approximation. In practice, the semi-gradient update works quite well.
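A scalar example makes the difference between the two gradients concrete. The linear model $Q(s,a,w) = w\,x(s,a)$ with a single weight and all numbers below are hypothetical. In this case the full gradient vanishes, because changing $w$ moves the TD target by exactly as much as the prediction, while the semi-gradient still pushes the prediction towards the target, which it treats as fixed.

```python
import numpy as np

# Hypothetical linear model Q(s, a, w) = w * x(s, a) with a scalar weight w
x_t = 1.0                       # feature x(S_t, A_t)
x_next = np.array([0.5, 2.0])   # features x(S_{t+1}, a) for both actions
r, gamma, alpha, w = 1.0, 0.5, 0.1, 1.0

q_t = w * x_t
q_next = w * x_next
a_max = np.argmax(q_next)
td_error = r + gamma * q_next[a_max] - q_t   # 1 + 0.5*2 - 1 = 1

# Full gradient of L = 0.5 * td_error**2 (the max selects action a_max)
full_grad = td_error * (gamma * x_next[a_max] - x_t)
# Semi-gradient: the TD target is treated as a constant
semi_grad = -td_error * x_t

print(full_grad, semi_grad)      # 0.0 -1.0
w_new = w - alpha * semi_grad    # equals w + alpha * td_error * x_t
print(w_new)                     # 1.1
```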
A basic implementation of such a semi-gradient Q-Learning algorithm is shown below (see also )

import numpy as np

def semi_gradient_qlearning(env, qfunc, qfunc_deriv, w,
        gamma, alpha, epsilon, num_episodes):
    # Run several episodes
    for i_episode in range(num_episodes):
        state = env.reset()

        # Action values for current state
        Qs = [qfunc(state, a, w) for a in range(env.action_space.n)]
        done = False
        while not done:
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = np.argmax(Qs)

            # Take action and obtain observation and reward
            state_new, reward, done, info = env.step(action)

            # Gradient of current state-action pair
            dQsa = qfunc_deriv(state, action, w)

            # Construct target; bootstrap only from nonterminal states
            target = reward
            if not done:
                # Action values for new state (reused in the next step)
                Qs_new = [qfunc(state_new, a, w)
                    for a in range(env.action_space.n)]
                target += gamma*np.max(Qs_new)

            # Semi-gradient update of the weights
            w += alpha*(target - Qs[action])*dQsa

            state = state_new
            if not done:
                Qs = Qs_new
    return w

In the function above, qfunc(s, a, w) is an arbitrary differentiable function of state s, action a and parameters w representing the value function, while qfunc_deriv(s, a, w) is its analytic gradient with respect to w, which has the same shape as w.
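As an illustration, a linear function of features is the simplest choice for qfunc and qfunc_deriv. The feature map below (a bias term plus the raw state vector) is an assumption; with one weight row per action, the gradient is simply the feature vector placed in the row of the chosen action. A finite-difference check confirms that the analytic derivative matches. With these definitions, w would be a float array of shape (env.action_space.n, number of features), so that the in-place update of w in the function above works as written.

```python
import numpy as np


def features(state):
    # Hypothetical feature map: a bias term followed by the raw state vector
    return np.concatenate(([1.0], np.asarray(state, dtype=float)))


def qfunc(state, action, w):
    # Linear action-value function: one weight row per action
    return float(w[action] @ features(state))


def qfunc_deriv(state, action, w):
    # Gradient w.r.t. w: features in the row of the chosen action, zeros elsewhere
    grad = np.zeros_like(w)
    grad[action] = features(state)
    return grad


# Finite-difference check of the analytic gradient
rng = np.random.default_rng(1)
w = rng.normal(size=(2, 3))  # 2 actions, 3 features (bias + 2-dim state)
state, action, eps = np.array([0.3, -1.2]), 1, 1e-6
numeric = np.zeros_like(w)
for idx in np.ndindex(w.shape):
    w_plus = w.copy()
    w_plus[idx] += eps
    numeric[idx] = (qfunc(state, action, w_plus) - qfunc(state, action, w)) / eps
print(np.allclose(numeric, qfunc_deriv(state, action, w), atol=1e-5))  # True
```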

Q-Learning with Neural Networks (DQN)

The parametrized function $Q(s,a,\vec w)$ used above as an estimate of the action-value function can be any differentiable function. When a neural network is used for this purpose, the algorithm is generally called deep Q-Learning.

Neural Network Architecture

The neural network could in principle have one of the following two architectures:

  • a) The network has as many inputs as the state $s$ has dimensions, plus an additional input for the action $a$, and a single output representing $Q(s, a)$. This corresponds to the definition in the previous section.
  • b) The network has as many inputs as the state $s$ has dimensions and as many outputs as there are different actions, i.e. each output $i$ represents $Q(s, a=i)$.

The architecture b) has the advantage that a single forward pass of the network directly returns all action-values of a single state and it is therefore computationally advantageous. Thus, it is the commonly used choice and will also be used in the following.
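The practical difference between the two architectures shows up in how many forward passes are needed to find $\operatorname{argmax}_a Q(s,a)$. The following NumPy sketch uses randomly initialized, untrained weights and a single hidden layer purely to compare output shapes and the number of passes; the layer sizes are arbitrary assumptions and the sketch is not part of the DQN implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
state_dim, num_actions, hidden = 4, 3, 8
state = rng.normal(size=state_dim)


def mlp(x, W1, W2):
    # One hidden ReLU layer followed by a linear output layer
    return np.maximum(0.0, x @ W1) @ W2


# a) state plus the action index as one extra input, a single output Q(s, a)
W1a = rng.normal(size=(state_dim + 1, hidden))
W2a = rng.normal(size=(hidden, 1))
q_a = np.array([mlp(np.append(state, a), W1a, W2a)[0]
                for a in range(num_actions)])  # num_actions forward passes

# b) state only as input, one output per action
W1b = rng.normal(size=(state_dim, hidden))
W2b = rng.normal(size=(hidden, num_actions))
q_b = mlp(state, W1b, W2b)  # a single forward pass

print(q_a.shape, q_b.shape)  # (3,) (3,)
```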

Let us look at an example implementation with TensorFlow and Keras. We first define a Keras model with the correct input and output dimensions. As an example, we consider a network with two hidden layers of 100 units each, ReLU activations, and Adam as the optimizer

import numpy as np
import tensorflow as tf

state_dim = env.observation_space.shape[0]
num_actions = env.action_space.n
num_units = 100
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(state_dim,)))
model.add(tf.keras.layers.Dense(num_units, activation="relu"))
model.add(tf.keras.layers.Dense(num_units, activation="relu"))
model.add(tf.keras.layers.Dense(num_actions))  # linear output: one Q-value per action
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam())

Given a single state stored in the variable state, the values $Q(s,a)$ for all actions can be computed via

Qs = model.predict(np.array([state]))[0]

To perform an update, we use the model.fit() function, which computes the gradients with TensorFlow's automatic differentiation and performs a single update step with the already defined optimizer for a given $(S_t,A_t,R_{t+1},S_{t+1})$ tuple

model.fit(x, y, epochs=1, verbose=0)

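in order to minimize the squared TD error as defined above $$ \begin{aligned} L &= \frac 1 2 \Big( R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a,\vec w) - Q(S_t,A_t,\vec w) \Big)^2 \end{aligned} $$ up to a constant factor: Keras' "mse" loss averages the squared error over all $N_a$ outputs, and, as explained below, only the output of the taken action has a nonzero error, so the two losses differ only by a rescaling of the gradient.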

Here, x corresponds to the state $S_t$ for which we want to update the value function $Q(S_t,A_t)$

x = np.array([state])

Furthermore, y corresponds to the TD target $R_{t+1} + \gamma \operatorname{max}_a Q(S_{t+1},a,\vec w)$. Since we want to update the value function only for the single action $A_t$, while the neural network computes the values for all actions $a$, we simply set y to the current prediction $Q(S_t,a)$ for all $a \neq A_t$, such that their TD error is zero and they do not contribute to the update. Because y is passed to model.fit() as a fixed array, no gradient flows through the TD target, which is exactly the semi-gradient update discussed above. Only for $a = A_t$ do we set y to the TD target

y = model.predict(x)[0]
y[action] = reward
if not done:
    y[action] += gamma*np.max(model.predict(np.array([state_new]))[0])

Here, we additionally make sure that we add the estimate of the next state only for a nonterminal state. Summarizing all the ingredients, a basic implementation of the DQN algorithm looks like

import numpy as np

def dqn(env, model, gamma, epsilon, N_episodes):
    # Run several episodes
    for i_episode in range(N_episodes):
        state = env.reset()

        # Action values for current state
        Qs = model.predict(np.array([state]))[0]
        done = False
        while not done:
            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = np.argmax(Qs)

            # Take action and obtain observation and reward
            state_new, reward, done, info = env.step(action)

            # Construct training example and update weights
            x = np.array([state])
            y = Qs.copy()  # copy, since slicing a NumPy array only creates a view
            y[action] = reward
            if not done:
                Qs = model.predict(np.array([state_new]))[0]
                y[action] += gamma*np.max(Qs)
            y = np.array([y])
            model.fit(x, y, epochs=1, verbose=0)
            state = state_new
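The construction of the training label y in the function above relies on the fact that outputs whose label equals the prediction receive no error signal. This can be checked by hand for a mean-squared error, whose gradient with respect to the network outputs is $\frac{2}{N_a}(\hat y - y)$; the numbers below are hypothetical. Backpropagation starts from this output gradient, so only the output of the taken action contributes to the weight update.

```python
import numpy as np

# Hypothetical network output Q(S_t, .) for one state and the TD target for action 1
q_pred = np.array([0.4, 1.0, -0.3])
action, td_target = 1, 1.8

# Training label as in the function above: a copy of the prediction,
# with only the entry of the taken action replaced by the TD target
y = q_pred.copy()
y[action] = td_target

# Gradient of mean((q_pred - y)**2) with respect to the outputs q_pred
grad = 2.0 / len(q_pred) * (q_pred - y)
print(grad)  # nonzero only for the taken action
```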

Unfortunately, the direct use of neural networks as shown above does not work well. It is data-inefficient, since every experience tuple of an episode is used only once for training, and the training is very unstable. As a consequence, several additional techniques have been developed to improve the learning process. Some of these extensions are described for example in [6], where it was shown that they complement each other and can be combined to achieve better performance. Here, I make use of two basic improvements, namely Experience Replay and a Separate Target Network.

Experience Replay

A first useful concept is experience replay, where instead of an update of the weights with a single example, a random minibatch of examples from a buffer of experiences is used. The original paper [1] states two main advantages of this approach:

  • Data efficiency: Each experience can be used for multiple weight updates.
  • Breaking correlations: Sampling random examples from the buffer removes the correlations between consecutive experiences.

The procedure is as follows. We first define a buffer of finite length to store the experiences. For this, we use Python's deque data structure, which automatically discards the oldest elements once the maximum number of stored experiences is exceeded.

from collections import deque
memory = deque(maxlen=buffer_size)

For each step of the episode, the experience tuple is added to the memory

memory.append((state,action,reward,state_new,done))
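The memory handling can also be wrapped in a small class. The sketch below is illustrative and not part of the original code (ReplayBuffer is a hypothetical name); it shows how the oldest experiences are evicted once maxlen is reached and how a batch is drawn uniformly with random.sample.

```python
import random
from collections import deque


class ReplayBuffer:
    """Minimal experience replay memory (illustrative, not from the original code)."""

    def __init__(self, buffer_size):
        # A deque with maxlen drops the oldest experience when it is full
        self.memory = deque(maxlen=buffer_size)

    def append(self, state, action, reward, state_new, done):
        self.memory.append((state, action, reward, state_new, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement breaks temporal correlations
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


random.seed(0)
buffer = ReplayBuffer(buffer_size=3)
for t in range(5):
    buffer.append(t, 0, 0.0, t + 1, False)
print(len(buffer), [e[0] for e in buffer.memory])  # 3 [2, 3, 4]
print(len(buffer.sample(2)))                        # 2
```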

As soon as the number of stored experiences exceeds the batch size, i.e. enough experience has been collected, experience replay is performed in each step of an episode. Thus, the algorithm simultaneously collects new experiences, which are added to the memory, and optimizes the network weights using batches from the memory. For this, a random batch of experiences is sampled from the memory. Then, the same update as shown before for the single current $(S_t,A_t,R_{t+1},S_{t+1})$ tuple is performed, but now for all experiences of the batch at once.

if len(memory) > batch_size:
    experience_sample = random.sample(memory, batch_size)
    x = np.array([e[0] for e in experience_sample])

    # Construct target
    y = model.predict(x)
    x2 = np.array([e[3] for e in experience_sample])
    Q2 = gamma*np.max(model.predict(x2), axis=1)
    for i,(s,a,r,s2,d) in enumerate(experience_sample):
        y[i][a] = r
        if not d:
            y[i][a] += Q2[i]

    # Update
    model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)
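The Python loop that fills y can also be expressed with NumPy indexing, which avoids per-sample Python overhead for larger batches. The function name build_targets is hypothetical; it expects the predictions for the current and the next states (e.g. from model.predict(x) and model.predict(x2)), together with the actions, rewards and done flags of the batch converted to arrays.

```python
import numpy as np


def build_targets(q_pred, q_next, actions, rewards, dones, gamma):
    """Vectorized version of the target loop for a batch of experiences.

    q_pred : (batch, num_actions) predictions Q(S_t, .) of the trained model
    q_next : (batch, num_actions) predictions Q(S_{t+1}, .) used for bootstrapping
    """
    y = q_pred.copy()
    # Bootstrap only for nonterminal transitions (dones is 1.0 for terminal ones)
    td_targets = rewards + gamma * np.max(q_next, axis=1) * (1.0 - dones)
    y[np.arange(len(actions)), actions] = td_targets
    return y


q_pred = np.array([[0.0, 1.0], [2.0, 3.0]])
q_next = np.array([[1.0, 5.0], [4.0, 0.0]])
y = build_targets(q_pred, q_next, actions=np.array([1, 0]),
                  rewards=np.array([1.0, -1.0]), dones=np.array([0.0, 1.0]),
                  gamma=0.9)
print(y)  # row 0: 1 + 0.9*5 = 5.5 for action 1; row 1: terminal, -1 for action 0
```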

Separate Target Network

A second improvement, introduced in [2], is to use a separate model for predicting the learning target. It has exactly the same architecture as the prediction model. This target network is updated at a much slower rate than the prediction network, which is used to choose actions in the episode. This stabilizes the learning target, which would otherwise change in each training step. The implementation is straightforward. The target model is cloned from the prediction model at the beginning

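target_model = tf.keras.models.clone_model(model)
# clone_model copies only the architecture with freshly initialized weights,
# so the weights are copied explicitly
target_model.set_weights(model.get_weights())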

Then, target_model is used instead of model when constructing the learning target. Finally, the learned weights of model are copied to target_model every target_update_freq steps

if steps % target_update_freq == 0:
    target_model.set_weights(model.get_weights())
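The effect of target_update_freq can be illustrated without TensorFlow. In the sketch below, NumPy arrays stand in for the weights of the two models (an assumption made to keep the example self-contained), and each step increments the online weights as a placeholder for a training update. The target stays constant between synchronizations, so the TD targets computed from it change only every target_update_freq steps.

```python
import numpy as np

# NumPy arrays stand in for the weights of the two Keras models (assumption)
online_w = np.zeros(2)
target_w = online_w.copy()
target_update_freq = 3

history = []
for steps in range(7):
    online_w += 1.0  # placeholder for one training update of the online model
    if steps % target_update_freq == 0:
        # Hard update, analogous to target_model.set_weights(model.get_weights())
        target_w = online_w.copy()
    history.append(float(target_w[0]))

print(history)  # [1.0, 1.0, 1.0, 4.0, 4.0, 4.0, 7.0]
```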

Final algorithm

To complete the implementation, we add $\epsilon$-decay to the algorithm, i.e. we reduce the exploration parameter over time so that the behavior approaches the greedy, optimal policy. A simple approach is to multiply $\epsilon$ by a factor smaller than 1 after every episode, while making sure that it does not drop below a threshold value to retain a minimal amount of exploration

epsilon *= epsilon_decay
epsilon = max(epsilon_min, epsilon)
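With a per-episode decay factor $\delta$, the number of episodes until $\epsilon$ reaches its lower bound follows from $\epsilon_0\,\delta^n \le \epsilon_{\min}$, i.e. $n = \lceil \ln(\epsilon_{\min}/\epsilon_0)/\ln\delta \rceil$. The values in this sketch are illustrative and not the ones used for any particular experiment.

```python
import math

epsilon, epsilon_decay, epsilon_min = 1.0, 0.99, 0.01

# Simulate the per-episode decay until the lower bound is reached
episodes = 0
while epsilon > epsilon_min:
    epsilon *= epsilon_decay
    epsilon = max(epsilon_min, epsilon)
    episodes += 1

# Closed form: smallest n with epsilon_0 * decay**n <= epsilon_min
n = math.ceil(math.log(epsilon_min / 1.0) / math.log(epsilon_decay))
print(episodes, n)  # 459 459
```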

Additionally, we want to monitor the learning progress. For this, we can use the tf.summary module. First, we define a file writer

import tensorflow as tf
from datetime import datetime
timestr = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/" + timestr
file_writer = tf.summary.create_file_writer(logdir)
file_writer.set_as_default()

In the algorithm, we then log the value of $\epsilon$ and the episode return

tf.summary.scalar("epsilon", epsilon, step=i_episode)
...
tf.summary.scalar("return", episode_return, step=i_episode)

and also a histogram of the model weights after each update

for layer in model.layers:
    for weight in layer.weights:
        weight_name = weight.name.replace(':', '_')
        tf.summary.histogram(weight_name, weight, step=steps)

Finally, we also save a checkpoint of the model every checkpoint_freq steps

if steps % checkpoint_freq == 0:
    model.save_weights("{}/weights-{:08d}-{:08d}".format(
        checkpoint_path, i_episode, steps))

The final implementation of DQN then reads (see also )

import random
from collections import deque

import numpy as np
import tensorflow as tf

def dqn(env, model, gamma, epsilon, epsilon_decay, epsilon_min, episodes,
    buffer_size, batch_size, target_update_freq, checkpoint_freq,
    checkpoint_path):
    steps = 0
    memory = deque(maxlen=buffer_size)
    target_model = tf.keras.models.clone_model(model)
    for i_episode in range(episodes):
        tf.summary.scalar("epsilon", epsilon, step=i_episode)
        episode_return = 0
        state = env.reset()
        done = False
        while not done:
            print("\r> DQN: Episode {}/{}, Step {}, Return {}".format(
                i_episode+1, episodes, steps, episode_return), end="")

            # Epsilon-greedy action selection
            if np.random.rand() < epsilon:
                action = env.action_space.sample()
            else:
                action = np.argmax(model.predict(np.array([state]), verbose=0)[0])

            # Take action and obtain observation and reward
            state_new, reward, done, info = env.step(action)
            episode_return += reward

            # Save experience
            memory.append((state,action,reward,state_new,done))

            # Experience replay
            if len(memory) > batch_size:
                experience_sample = random.sample(memory, batch_size)
                x = np.array([e[0] for e in experience_sample])

                # Construct target
                y = model.predict(x, verbose=0)
                x2 = np.array([e[3] for e in experience_sample])
                Q2 = gamma*np.max(target_model.predict(x2, verbose=0), axis=1)
                for i,(s,a,r,s2,d) in enumerate(experience_sample):
                    y[i][a] = r
                    if not d:
                        y[i][a] += Q2[i]

                # Update
                model.fit(x, y, batch_size=batch_size, epochs=1, verbose=0)

                # Save weight histogram
                for layer in model.layers:
                    for weight in layer.weights:
                        weight_name = weight.name.replace(':', '_')
                        tf.summary.histogram(weight_name, weight, step=steps)

            # Update of target model
            if steps % target_update_freq == 0:
                target_model.set_weights(model.get_weights())

            # Save model checkpoint
            if steps % checkpoint_freq == 0:
                model.save_weights("{}/weights-{:08d}-{:08d}".format(
                    checkpoint_path, i_episode, steps))

            state = state_new
            steps += 1

        # Epsilon decay
        epsilon *= epsilon_decay
        epsilon = max(epsilon_min, epsilon)

        tf.summary.scalar("return", episode_return, step=i_episode)
        tf.summary.flush()

    # Save final weights, unless the last step was already checkpointed
    if (steps-1) % checkpoint_freq != 0:
        model.save_weights("{}/weights-{:08d}-{:08d}".format(
            checkpoint_path, i_episode, steps-1))
    print()
    return model

This implementation of the DQN algorithm is quite general: it can be used with any Gym environment that has a vector-valued observation and a discrete action space, and with any Keras model that maps a state to one output per action. As an example, I applied the algorithm to classic control problems.
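The code in this text uses the classic Gym interface, in which env.reset() returns only the observation and env.step() returns (observation, reward, done, info). Gym 0.26 and later, as well as its successor Gymnasium, instead return (observation, info) from reset() and (observation, reward, terminated, truncated, info) from step(). A thin adapter can translate between the two; the sketch below is an assumption-based illustration with hypothetical class names, not part of the original implementation. With Gymnasium installed, it would be used as env = ClassicGymAPI(gymnasium.make("CartPole-v1")).

```python
class ClassicGymAPI:
    """Hypothetical adapter exposing the classic Gym interface used in this text."""

    def __init__(self, env):
        self.env = env
        self.observation_space = env.observation_space
        self.action_space = env.action_space

    def reset(self):
        # New interface: reset() returns (observation, info)
        observation, _info = self.env.reset()
        return observation

    def step(self, action):
        # New interface: step() returns (observation, reward, terminated, truncated, info)
        observation, reward, terminated, truncated, info = self.env.step(action)
        # Simplification: a time-limit truncation is treated like a terminal
        # state, so the DQN code above will not bootstrap from it
        return observation, reward, terminated or truncated, info


class FakeNewStyleEnv:
    """Stand-in environment with the new interface, for demonstration only."""
    observation_space = None
    action_space = None

    def reset(self):
        return [0.0], {}

    def step(self, action):
        return [1.0], 1.0, False, True, {}


env = ClassicGymAPI(FakeNewStyleEnv())
state = env.reset()
state_new, reward, done, info = env.step(0)
print(state, state_new, reward, done)  # [0.0] [1.0] 1.0 True
```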

References

  • [1] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller, Playing Atari with Deep Reinforcement Learning (2013), [arXiv]
  • [2] Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Andrei A. Rusu, Joel Veness, Marc G. Bellemare, Alex Graves, Martin Riedmiller, Andreas K. Fidjeland, Georg Ostrovski, Stig Petersen, Charles Beattie, Amir Sadik, Ioannis Antonoglou, Helen King, Dharshan Kumaran, Daan Wierstra, Shane Legg, Demis Hassabis, Human-level control through deep reinforcement learning, Nature, 518, 529 (2015)
  • [3] Hado van Hasselt, Arthur Guez, David Silver, Deep Reinforcement Learning with Double Q-learning, Proc. AAAI 30, 2094 (2016), [arXiv]
  • [4] Tom Schaul, John Quan, Ioannis Antonoglou, David Silver, Prioritized Experience Replay (2015), [arXiv]
  • [5] Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas, Dueling Network Architectures for Deep Reinforcement Learning, PMLR 48, 1995 (2016), [arXiv]
  • [6] Matteo Hessel, Joseph Modayil, Hado van Hasselt, Tom Schaul, Georg Ostrovski, Will Dabney, Dan Horgan, Bilal Piot, Mohammad Azar, David Silver, Rainbow: Combining Improvements in Deep Reinforcement Learning, Proc. AAAI 32, 3215 (2018), [arXiv]
  • [7] Richard S. Sutton, Andrew G. Barto, Reinforcement Learning: An Introduction, MIT Press 2nd ed. (2018)