A Comprehensive Overview of Q-Learning and Actor-Critic Methods
Table of Contents
- 1. Q-Learning: A Foundational Approach
- 2. Basic Actor-Critic
- 3. Neural Network Parameterization
- 4. Deep Deterministic Policy Gradient (DDPG)
- 5. Twin Delayed Deep Deterministic Policy Gradient (TD3)
- 6. Proximal Policy Optimization (PPO)
- 7. Soft Actor-Critic (SAC)
- 8. Asynchronous Advantage Actor-Critic (A3C)
- 9. A3C Full Example
- 10. Evolution Timeline of the Methods
- 11. Concluding Remarks
1. Q-Learning: A Foundational Approach
Q-Learning attempts to learn the optimal state-action value function:

$$Q^*(s, a) \;=\; \mathbb{E}\!\left[\, r + \gamma \max_{a'} Q^*(s', a') \;\middle|\; s, a \right]$$

For a discrete action space, we can keep a table or a neural network to represent $Q(s, a)$.
1.1 Bellman Update
The core Bellman optimality update is:

$$Q(s, a) \;\leftarrow\; Q(s, a) + \alpha \left[\, r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$
Python Snippet
import numpy as np
# Suppose Q is a table (2D array: Q[state, action])
def q_learning_update(Q, s, a, r, s_next, alpha, gamma):
    # Q[s, a] += alpha * (r + gamma * max_a' Q[s_next, a'] - Q[s, a])
    Q[s, a] += alpha * (r + gamma * np.max(Q[s_next]) - Q[s, a])

1.2 Deep Q-Network (DQN)
In deep Q-learning, we approximate $Q(s, a)$ with a neural network $Q_\theta(s, a)$. The loss to minimize is:

$$L(\theta) \;=\; \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q_{\theta^-}(s', a') - Q_\theta(s, a) \right)^{2} \right]$$

Here, $\theta$ and $\theta^-$ denote the online and target network parameters.
Python Snippet
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)  # Q(s) -> R^action_dim

    def forward(self, s):
        x = torch.relu(self.fc1(s))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

def dqn_loss(q_network, target_network, batch, gamma):
    # batch includes (s, a, r, s_next, done)
    s, a, r, s_next, done = batch

    # Q(s, a)
    q_values = q_network(s).gather(1, a.unsqueeze(1)).squeeze(1)

    # target = r + gamma * max_a' Q^-(s_next, a') if not done
    with torch.no_grad():
        q_next = target_network(s_next).max(1)[0]
        q_target = r + gamma * q_next * (1 - done.float())

    loss = nn.MSELoss()(q_values, q_target)
    return loss

Thus, Q-learning (and its deep counterpart) is typically off-policy because it learns about the greedy policy while potentially following a different data-collecting policy (e.g., $\epsilon$-greedy).
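To make that data-collection side concrete, here is a minimal sketch of an $\epsilon$-greedy behavior policy and a periodic hard update of the target network, assuming the QNetwork above; the helper names (select_action_epsilon_greedy, sync_target) and the epsilon schedule are illustrative choices, not part of any particular library.

Python Snippet (illustrative)

import random
import torch

def select_action_epsilon_greedy(q_network, state, epsilon, action_dim):
    # With probability epsilon, take a uniformly random action (exploration);
    # otherwise act greedily with respect to the current Q estimates.
    if random.random() < epsilon:
        return random.randrange(action_dim)
    with torch.no_grad():
        q_values = q_network(state.unsqueeze(0))  # state: 1D tensor; add a batch dimension
    return int(q_values.argmax(dim=1).item())

def sync_target(q_network, target_network):
    # Periodic hard update: copy the online parameters into the target network.
    target_network.load_state_dict(q_network.state_dict())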
2. Basic Actor-Critic
Unlike Q-learning, actor-critic methods maintain:
- A policy $\pi_\theta(a \mid s)$ (the “actor”).
- A value function $V_\phi(s)$ or $Q$-function (the “critic”).
2.1 Policy Gradient Theory
We want to maximize the expected return:

$$J(\theta) \;=\; \mathbb{E}_{\tau \sim \pi_\theta}\!\left[ \sum_{t} \gamma^{t} r_t \right]$$

The policy gradient theorem says:

$$\nabla_\theta J(\theta) \;=\; \mathbb{E}_{\pi_\theta}\!\left[ \nabla_\theta \log \pi_\theta(a \mid s)\, \big( Q^{\pi_\theta}(s, a) - b(s) \big) \right]$$

where $b(s)$ is a baseline (often the value function $V(s)$).
Python Snippet
import torch
def policy_gradient_loss(log_probs, returns, baselines):
    # log_probs: tensor of log pi(a|s) for each step
    # returns: tensor of G_t
    # baselines: tensor of b(s), often V(s)
    advantage = returns - baselines
    # actor loss = -E[ log pi(a|s) * advantage ]
    loss = -(log_probs * advantage).mean()
    return loss

2.2 Critic Objective
The critic (value-based) is learned via MSE:

$$L(\phi) \;=\; \mathbb{E}\!\left[ \big( V_\phi(s) - G_t \big)^{2} \right]$$
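Both snippets in this section consume discounted returns $G_t$. As a minimal sketch (the helper name discounted_returns is illustrative), they can be computed backwards from a list of per-step rewards:

Python Snippet (illustrative)

import torch

def discounted_returns(rewards, gamma, bootstrap_value=0.0):
    # G_t = r_t + gamma * G_{t+1}, accumulated backwards over the trajectory;
    # bootstrap_value is V(s_T) if the trajectory was cut off before termination.
    returns = []
    g = bootstrap_value
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    returns.reverse()
    return torch.tensor(returns, dtype=torch.float32)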
Python Snippet
import torch.nn as nn
def value_loss(value_net, states, returns):
    # value_net(s) -> scalar V_phi(s)
    v = value_net(states)
    loss = nn.MSELoss()(v.squeeze(), returns)
    return loss

3. Neural Network Parameterization
Below is a typical two-layer MLP for both actor and critic, with explicit shapes:
- Actor $\pi_\theta(a \mid s)$
  - First hidden layer: $\mathbb{R}^{d_s} \to \mathbb{R}^{h}$ (e.g., $h = 64$).
  - Second hidden layer: $\mathbb{R}^{h} \to \mathbb{R}^{h}$.
  - Output layer depends on discrete vs. continuous actions.
- Critic $V_\phi(s)$ or $Q_\phi(s, a)$
  - Similarly a two-layer MLP.
  - Output dimension = 1 (scalar).
Python Snippet for a 2-Layer MLP
class MLPActor(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=64, discrete=False):
        super(MLPActor, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # Discrete: outputs action_dim logits
        # Continuous: outputs mean, possibly log std
        self.discrete = discrete
        if discrete:
            self.fc_out = nn.Linear(hidden_dim, action_dim)
        else:
            # For continuous, let's say we output the mean only
            self.fc_mean = nn.Linear(hidden_dim, action_dim)
            # could also have self.log_std = nn.Parameter(...)

    def forward(self, s):
        x = torch.relu(self.fc1(s))
        x = torch.relu(self.fc2(x))
        if self.discrete:
            logits = self.fc_out(x)
            return logits  # use softmax outside
        else:
            mean = self.fc_mean(x)
            return mean

4. Deep Deterministic Policy Gradient (DDPG)
4.1 Deterministic Actor, Q-Critic
- Actor: a deterministic policy $a = \mu_\theta(s)$.
- Critic: $Q_\phi(s, a)$.
- Critic Loss:

$$L(\phi) \;=\; \mathbb{E}\!\left[ \Big( Q_\phi(s, a) - \big( r + \gamma\, Q_{\phi^-}\!\big(s', \mu_{\theta^-}(s')\big) \big) \Big)^{2} \right]$$
Python Snippet
def ddpg_critic_loss(q_net, target_q_net, batch, gamma, actor, target_actor):
    s, a, r, s_next, done = batch
    q_vals = q_net(s, a)
    with torch.no_grad():
        a_next = target_actor(s_next)
        q_next = target_q_net(s_next, a_next)
        q_target = r + gamma * q_next * (1 - done)
    loss = nn.MSELoss()(q_vals, q_target)
    return loss

- Actor Update uses the deterministic policy gradient:

$$\nabla_\theta J(\theta) \;=\; \mathbb{E}_{s \sim \mathcal{D}}\!\left[ \nabla_a Q_\phi(s, a)\big|_{a = \mu_\theta(s)} \; \nabla_\theta \mu_\theta(s) \right]$$
Python Snippet
def ddpg_actor_loss(q_net, states, actor):
    # actor(s) -> a
    a = actor(states)
    # compute d/d(a) of Q(s, a), then chain rule through the actor
    q_vals = q_net(states, a)
    # we want to maximize Q, so minimize the negative
    loss = -q_vals.mean()
    return loss

DDPG is off-policy and uses a replay buffer plus target networks to improve stability.
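As a rough sketch of those two ingredients, here is a minimal FIFO replay buffer and a Polyak ("soft") target update; the names ReplayBuffer and soft_update and the default tau are illustrative choices, not part of the snippets above.

Python Snippet (illustrative)

import random
from collections import deque

import torch

class ReplayBuffer:
    """A minimal FIFO buffer of (s, a, r, s_next, done) transitions."""
    def __init__(self, capacity=100_000):
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_next, done):
        self.buffer.append((s, a, r, s_next, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        s, a, r, s_next, done = zip(*batch)
        return (torch.stack(s), torch.stack(a), torch.tensor(r, dtype=torch.float32),
                torch.stack(s_next), torch.tensor(done, dtype=torch.float32))

def soft_update(target_net, online_net, tau=0.005):
    # Polyak averaging: theta_target <- tau * theta_online + (1 - tau) * theta_target
    for tp, p in zip(target_net.parameters(), online_net.parameters()):
        tp.data.mul_(1.0 - tau).add_(tau * p.data)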
5. Twin Delayed Deep Deterministic Policy Gradient (TD3)
5.1 Twin Critics
To reduce overestimation in DDPG, TD3 uses two critics, $Q_{\phi_1}(s, a)$ and $Q_{\phi_2}(s, a)$.
The critic target is:

$$y \;=\; r + \gamma \min_{i = 1, 2} Q_{\phi_i^-}\!\big(s', \mu_{\theta^-}(s')\big)$$

(TD3 also adds small clipped noise to the target action, known as target policy smoothing; the snippet below omits it for brevity.)
Python Snippet
def td3_critic_loss(q1, q2, q1_target, q2_target, batch, gamma, actor_target):
    s, a, r, s_next, done = batch

    with torch.no_grad():
        a_next = actor_target(s_next)
        q1_next = q1_target(s_next, a_next)
        q2_next = q2_target(s_next, a_next)
        q_next_min = torch.min(q1_next, q2_next)
        q_target = r + gamma * q_next_min * (1 - done)

    loss1 = nn.MSELoss()(q1(s, a), q_target)
    loss2 = nn.MSELoss()(q2(s, a), q_target)
    return loss1 + loss2

5.2 Delayed Updates
TD3 updates the actor (and the target networks) only once every few critic updates; letting the value estimates settle between policy updates reduces the variance of the actor's gradient.
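A minimal sketch of that schedule, assuming td3_critic_loss from above, a soft_update helper like the one sketched in the DDPG section, and an (assumed, typical) policy_delay of 2:

Python Snippet (illustrative)

def td3_update(step, batch, q1, q2, q1_target, q2_target,
               actor, actor_target, critic_opt, actor_opt,
               gamma=0.99, policy_delay=2, tau=0.005):
    # The twin critics are updated on every step.
    critic_loss = td3_critic_loss(q1, q2, q1_target, q2_target,
                                  batch, gamma, actor_target)
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()

    # The actor and all target networks are updated only every `policy_delay` steps.
    if step % policy_delay == 0:
        s = batch[0]
        actor_loss = -q1(s, actor(s)).mean()  # same form as the DDPG actor loss
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()

        for target, online in ((q1_target, q1), (q2_target, q2), (actor_target, actor)):
            soft_update(target, online, tau)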
6. Proximal Policy Optimization (PPO)
6.1 Probability Ratio and Clipping
PPO is an on-policy method. We define the probability ratio:

$$r_t(\theta) \;=\; \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$

The clipped objective is:

$$L^{\text{CLIP}}(\theta) \;=\; \mathbb{E}_t\!\left[ \min\!\Big( r_t(\theta)\, \hat{A}_t,\; \operatorname{clip}\!\big(r_t(\theta),\, 1 - \epsilon,\, 1 + \epsilon\big)\, \hat{A}_t \Big) \right]$$
Python Snippet
def ppo_clip_loss(pi_new, pi_old, actions, advantages, epsilon=0.2):
    # pi_new, pi_old: probability of the taken actions under the new/old policy
    ratio = pi_new / (pi_old + 1e-8)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    loss = -torch.min(unclipped, clipped).mean()
    return loss

7. Soft Actor-Critic (SAC)
7.1 Maximum Entropy RL
SAC encourages exploration via an entropy term $\mathcal{H}\big(\pi(\cdot \mid s)\big)$. The objective is:

$$J(\pi) \;=\; \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \pi}\!\left[ r(s_t, a_t) + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]$$
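For the actor, SAC minimizes $\mathbb{E}_{s}\big[\alpha \log \pi_\theta(a \mid s) - \min_i Q_{\phi_i}(s, a)\big]$ with $a$ sampled via the reparameterization trick. Below is a minimal sketch, assuming an actor exposing a sample(states) method that returns the action and its log-probability (that interface is an assumption, not defined elsewhere in this article):

Python Snippet (illustrative)

def sac_actor_loss(q1, q2, actor, states, alpha):
    # a ~ pi(.|s) via the reparameterization trick, together with log pi(a|s)
    a, logp_a = actor.sample(states)
    q_min = torch.min(q1(states, a), q2(states, a))
    # maximize E[min_i Q_i(s, a) - alpha * log pi(a|s)]  ->  minimize the negative
    loss = (alpha * logp_a - q_min).mean()
    return loss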
7.2 Two Critics
Like TD3, SAC uses twin critics $Q_{\phi_1}$ and $Q_{\phi_2}$. The target is:

$$y \;=\; r + \gamma \Big( \min_{i = 1, 2} Q_{\phi_i^-}(s', a') - \alpha \log \pi_\theta(a' \mid s') \Big)$$

with $a' \sim \pi_\theta(\cdot \mid s')$.
Python Snippet
def sac_critic_loss(q1, q2, q1_target, q2_target, batch, alpha, gamma, actor_target):
    s, a, r, s_next, done = batch
    with torch.no_grad():
        # sample a new action from the (target) actor
        a_next = actor_target(s_next)
        # log pi(a_next | s_next)
        logp_a_next = actor_target.log_prob(s_next, a_next)
        q1_next = q1_target(s_next, a_next)
        q2_next = q2_target(s_next, a_next)
        q_next_min = torch.min(q1_next, q2_next)
        q_target = r + gamma * (q_next_min - alpha * logp_a_next) * (1 - done)

    loss1 = nn.MSELoss()(q1(s, a), q_target)
    loss2 = nn.MSELoss()(q2(s, a), q_target)
    return loss1 + loss2

8. Asynchronous Advantage Actor-Critic (A3C)
8.1 Parallelization Insight
A3C runs multiple worker processes, each with local copies of the policy $\pi_\theta$ and the value function $V_\phi$. They asynchronously update the shared global parameters.
8.2 Advantage Actor-Critic Loss
A typical A3C loss (value-based critic) is:

$$L \;=\; -\log \pi_\theta(a_t \mid s_t)\, \big( G_t - V_\phi(s_t) \big) \;+\; \big( G_t - V_\phi(s_t) \big)^{2}$$
Python Snippet
def a3c_loss(policy_net, value_net, states, actions, returns):
    # policy_net -> log pi(a|s), value_net -> V(s)
    log_probs = policy_net.log_prob(states, actions)
    values = value_net(states).squeeze()
    advantage = returns - values

    # Detach the advantage in the actor term so the policy gradient
    # does not back-propagate into the critic.
    actor_loss = -(log_probs * advantage.detach()).mean()
    critic_loss = advantage.pow(2).mean()
    return actor_loss + critic_loss

A3C's asynchronous updates help decorrelate data and speed up training on CPUs.
9. A3C Full Example
# -*- coding: utf-8 -*-
"""
Refactored A3C Example:
 - Renamed classes, functions, and variables for clarity.
 - Maintains the original multi-process A3C logic with shared parameters.

Author: <Your Name>
"""

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.distributions import Categorical
##############################################################################
# 1) UTILITY: A safer environment "step" handling older/newer Gym returns.
##############################################################################
def safe_step_env(env, action):
    """
    Step the environment while handling Gym/Gymnasium return formats:
      - Some return (obs, reward, done, info)
      - Gymnasium can return (obs, reward, done, truncated, info)
    """
    results = env.step(action)
    if len(results) == 5:
        next_obs, reward, done, truncated, info = results
        done = done or truncated
    else:
        next_obs, reward, done, info = results
    return next_obs, reward, done, info


##############################################################################
# 2) SHARED OPTIMIZER
##############################################################################
class SharedAdam(torch.optim.Adam):
    """
    A custom Adam optimizer that uses shared memory for the moving averages,
    allowing multiple processes to update a shared global set of parameters.
    """

    def __init__(self, parameters, lr=1e-3, betas=(0.9, 0.99), eps=1e-8,
                 weight_decay=0):
        super(SharedAdam, self).__init__(parameters, lr=lr, betas=betas,
                                         eps=eps, weight_decay=weight_decay)
        for group in self.param_groups:
            for param in group['params']:
                state = self.state[param]
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(param.data)
                state['exp_avg_sq'] = torch.zeros_like(param.data)
                # Share memory so that parallel processes can update
                # the optimizer states
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
##############################################################################
# 3) A3C NETWORK (Actor + Critic)
##############################################################################
class A3CNetwork(nn.Module):
    """
    An Actor-Critic network that outputs:
      - A policy distribution (pi) for selecting actions.
      - A value function (v) for estimating the state value.
    """

    def __init__(self, state_dim, num_actions, gamma=0.99):
        super(A3CNetwork, self).__init__()
        self.gamma = gamma

        self.actor_hidden = nn.Linear(*state_dim, 128)
        self.critic_hidden = nn.Linear(*state_dim, 128)

        self.actor_out = nn.Linear(128, num_actions)
        self.critic_out = nn.Linear(128, 1)

        # We'll store trajectories in memory buffers between updates
        self.memory_states = []
        self.memory_actions = []
        self.memory_rewards = []

    def store_experience(self, state, action, reward):
        """Cache a transition (state, action, reward)."""
        self.memory_states.append(state)
        self.memory_actions.append(action)
        self.memory_rewards.append(reward)

    def reset_memory(self):
        """Clear the trajectory buffers."""
        self.memory_states = []
        self.memory_actions = []
        self.memory_rewards = []

    def forward(self, state):
        """
        Forward pass:
          state: a torch.Tensor of shape [batch_size, *state_dim].
        Returns:
          - policy_logits (for the actor)
          - state_value (scalar per batch item, for the critic)
        """
        actor_hidden_out = F.relu(self.actor_hidden(state))
        critic_hidden_out = F.relu(self.critic_hidden(state))

        policy_logits = self.actor_out(actor_hidden_out)
        state_value = self.critic_out(critic_hidden_out)
        return policy_logits, state_value

    def compute_returns(self, done):
        """
        Compute discounted returns for each step in the trajectory.
        If the episode ended (done=True), the final state's value is 0.
        Otherwise, we bootstrap from the last state's value.
        """
        states_t = torch.tensor(self.memory_states, dtype=torch.float)
        _, values_t = self.forward(states_t)
        values_t = values_t.squeeze(-1)  # shape [trajectory_len]

        # If done, the final state's value is 0; else we take the last state's value
        final_value = 0.0 if done else values_t[-1].item()

        returns = []
        discounted_return = final_value
        for reward in reversed(self.memory_rewards):
            discounted_return = reward + self.gamma * discounted_return
            returns.append(discounted_return)
        returns.reverse()
        return torch.tensor(returns, dtype=torch.float)

    def compute_loss(self, done):
        """
        Compute the combined Actor + Critic loss over the stored trajectory.
        """
        states_t = torch.tensor(self.memory_states, dtype=torch.float)
        actions_t = torch.tensor(self.memory_actions, dtype=torch.long)
        returns_t = self.compute_returns(done)

        policy_logits, values_t = self.forward(states_t)
        values_t = values_t.squeeze(-1)  # shape [trajectory_len]

        # Critic loss (Mean Squared Error)
        critic_loss = (returns_t - values_t) ** 2

        # Actor loss (REINFORCE with advantage)
        probabilities = F.softmax(policy_logits, dim=1)
        dist = Categorical(probabilities)
        log_probs = dist.log_prob(actions_t)

        advantages = returns_t - values_t.detach()  # no grad wrt. values for the advantage
        actor_loss = -log_probs * advantages

        total_loss = (critic_loss + actor_loss).mean()
        return total_loss

    def select_action(self, observation):
        """
        Sample an action according to the policy's distribution.
          observation: raw state from the environment (np array or float list).
        """
        observation_t = torch.tensor([observation], dtype=torch.float)
        policy_logits, _ = self.forward(observation_t)
        probabilities = F.softmax(policy_logits, dim=1)
        dist = Categorical(probabilities)
        action = dist.sample().item()
        return action
##############################################################################
# 4) WORKER AGENT (One process per agent)
##############################################################################
class A3CWorker(mp.Process):
    """
    Each worker interacts with an environment instance, accumulates experience,
    and updates the global A3C network's parameters.
    """

    def __init__(self, global_network, global_optimizer, state_dim, num_actions,
                 gamma, lr, worker_id, global_episode_counter, env_id,
                 max_episodes, update_interval):
        super(A3CWorker, self).__init__()
        self.local_network = A3CNetwork(state_dim, num_actions, gamma)
        self.global_network = global_network
        self.global_optimizer = global_optimizer

        self.worker_name = f"worker_{worker_id:02d}"
        self.episode_counter = global_episode_counter

        self.env = gym.make(env_id)
        self.max_episodes = max_episodes
        self.update_interval = update_interval

    def run(self):
        step_count = 1
        while self.episode_counter.value < self.max_episodes:
            state, _info = self._reset_env()  # handle Gymnasium reset
            done = False
            episode_return = 0.0

            self.local_network.reset_memory()

            while not done:
                action = self.local_network.select_action(state)
                next_state, reward, done, info = safe_step_env(self.env, action)

                episode_return += reward
                self.local_network.store_experience(state, action, reward)

                # Update the global network after 'update_interval' steps or on episode end
                if step_count % self.update_interval == 0 or done:
                    loss = self.local_network.compute_loss(done)

                    self.global_optimizer.zero_grad()
                    loss.backward()

                    # Copy grads from local to global
                    for local_param, global_param in zip(
                            self.local_network.parameters(),
                            self.global_network.parameters()):
                        global_param._grad = local_param.grad

                    self.global_optimizer.step()

                    # Sync the local network with the updated global parameters
                    self.local_network.load_state_dict(self.global_network.state_dict())
                    self.local_network.reset_memory()

                step_count += 1
                state = next_state

            # Increment the global episode counter
            with self.episode_counter.get_lock():
                self.episode_counter.value += 1

            print(f"{self.worker_name} | Episode: {self.episode_counter.value} "
                  f"| Return: {episode_return:.1f}")

    def _reset_env(self):
        """
        Reset the environment (Gym/Gymnasium).
        Handles the newer reset() returning (obs, info).
        """
        initial_obs = self.env.reset()
        if isinstance(initial_obs, tuple) and len(initial_obs) == 2:
            obs, info = initial_obs
        else:
            obs = initial_obs
            info = {}
        return obs, info
##############################################################################
# 5) WATCH A TRAINED AGENT
##############################################################################
def watch_agent(global_network, env_id="CartPole-v1", episodes_to_watch=5):
    """
    Renders a few episodes using the global A3C network's parameters.
    """
    env = gym.make(env_id, render_mode="human")

    # Local copy for inference
    local_network = A3CNetwork([4], 2)  # For CartPole: state_dim=[4], num_actions=2
    local_network.load_state_dict(global_network.state_dict())

    for ep in range(episodes_to_watch):
        state, _info = env.reset()
        done = False
        episode_return = 0.0
        while not done:
            # For older Gym versions, you might call env.render() here
            action = local_network.select_action(state)
            state, reward, done, info = safe_step_env(env, action)
            episode_return += reward

        print(f"Watch Episode {ep + 1}, Return: {episode_return:.1f}")

    env.close()
##############################################################################
# 6) MAIN: TRAINING LOGIC
##############################################################################
if __name__ == '__main__':
    LEARNING_RATE = 1e-4
    ENV_ID = "CartPole-v1"
    env = gym.make(ENV_ID)

    # Observation space shape => typically something like (4,)
    state_dim = env.observation_space.shape

    # Action space => for Discrete(n), .n is the number of possible actions
    num_actions = env.action_space.n
    MAX_EPISODES = 3000
    UPDATE_INTERVAL = 500

    # Create the global (shared) A3C network
    shared_network = A3CNetwork(state_dim, num_actions)
    shared_network.share_memory()

    # Create the shared optimizer
    shared_optimizer = SharedAdam(shared_network.parameters(), lr=LEARNING_RATE,
                                  betas=(0.92, 0.999))

    global_episode_counter = mp.Value('i', 0)

    # Spawn one worker process per CPU core
    num_cpus = mp.cpu_count()
    workers = []
    for cpu_id in range(num_cpus):
        worker = A3CWorker(
            global_network=shared_network,
            global_optimizer=shared_optimizer,
            state_dim=state_dim,
            num_actions=num_actions,
            gamma=0.99,
            lr=LEARNING_RATE,
            worker_id=cpu_id,
            global_episode_counter=global_episode_counter,
            env_id=ENV_ID,
            max_episodes=MAX_EPISODES,
            update_interval=UPDATE_INTERVAL
        )
        workers.append(worker)

    # Start and join each worker
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    print("Training complete. Now let's watch the agent in action!")
    watch_agent(shared_network, env_id=ENV_ID, episodes_to_watch=5000)

10. Evolution Timeline of the Methods
- **Q-Learning**:
  - Learns $Q(s, a)$ using the Bellman update.
  - Great for discrete action spaces (DQN for the deep version).
  - Off-policy; can be inefficient for large continuous spaces.
- **Actor-Critic (baseline)**:
  - Combines the policy gradient with a critic to reduce variance.
  - Works in both discrete and continuous settings.
- **DDPG (2015)**:
  - Deterministic policy + replay buffer + target networks for continuous control.
  - Issues: overestimation, sensitive hyperparameters.
- **A3C (2016)**:
  - Multiple asynchronous workers for faster training.
  - No replay buffer, but can have higher variance.
- **TD3 (2018)**:
  - Twin critics + delayed updates to reduce overestimation in DDPG.
  - Deterministic, so it needs exploration noise.
- **PPO (2017)**:
  - On-policy with a clipped objective for stable learning.
  - Popular and relatively easy to tune.
- **SAC (2018)**:
  - Maximum entropy RL for robust exploration.
  - Twin critics to reduce overestimation.
  - Often state-of-the-art in continuous control tasks.
Hence, each method emerges to address specific challenges:
- Overestimation (TD3, SAC).
- Exploration (SAC’s entropy).
- Stability (PPO clipping, twin critics).
- Efficiency (replay buffers, asynchronous runs).
11. Concluding Remarks
- Q-learning (and DQN) forms the foundation for many discrete-action RL approaches.
- Actor-Critic methods extend naturally to continuous actions and can reduce variance with a learned critic.
- DDPG introduced a deterministic actor with an off-policy, replay-buffer approach, later refined by TD3 to address overestimation.
- PPO simplified stable on-policy learning with a clipped objective.
- SAC combined twin critics with maximum entropy to encourage robust exploration.
- A3C leveraged asynchronous CPU processes to speed up training without replay buffers.