Train DQN
This tutorial shows the standard way to train a DQN agent in ApexRL.
Overview
The recommended DQN stack in this repository is:
GymVecEnv for discrete-action Gymnasium tasks
OffPolicyRunner as the canonical training entrypoint
MLPQNetwork as the default Q-network baseline
Prerequisites
Install ApexRL from the repository root, and install Gymnasium if it is not already present:
pip install -e .
pip install gymnasium
Environment Setup
For DQN, start with a discrete-control environment such as CartPole-v1.
import gymnasium as gym
from apexrl.envs.gym_wrapper import GymVecEnv
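def make_env():
    # Each call builds a fresh, independent CartPole instance.
    return gym.make("CartPole-v1")
# Pass the factory itself (not make_env()) so GymVecEnv can build two copies.
env = GymVecEnv([make_env for _ in range(2)], device="cpu")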
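DQN chooses actions by comparing one Q-value per discrete action, which is why it needs a discrete action space such as CartPole-v1's two actions. The plain-Python sketch below shows epsilon-greedy selection over such a Q-value vector. The function names and the linear decay schedule are illustrative assumptions, not ApexRL's exploration API.

```python
import random
from typing import Sequence


def linear_epsilon(step: int, start: float = 1.0, end: float = 0.05,
                   decay_steps: int = 10_000) -> float:
    """Linearly anneal epsilon from `start` to `end` over `decay_steps` steps (hypothetical schedule)."""
    frac = min(step / decay_steps, 1.0)
    return start + frac * (end - start)


def epsilon_greedy(q_values: Sequence[float], epsilon: float,
                   rng: random.Random) -> int:
    """Return a uniformly random action with probability epsilon, else the argmax action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Greedy choice: index of the largest Q-value (the first one wins ties).
    return max(range(len(q_values)), key=lambda a: q_values[a])


rng = random.Random(0)
q = [0.3, 1.2]  # one Q-value per CartPole action (0 = push left, 1 = push right)
print(epsilon_greedy(q, epsilon=0.0, rng=rng))  # greedy: always action 1
print(linear_epsilon(5_000))  # halfway through the decay
```

Because the greedy step is an argmax over a finite list, the same idea does not carry over directly to continuous actions; that is the case SAC covers.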
Build the Runner
OffPolicyRunner creates the DQN agent from the config, fills the replay buffer, and schedules updates.
from apexrl.agent.off_policy_runner import OffPolicyRunner
from apexrl.algorithms.dqn import DQNConfig
from apexrl.models import MLPQNetwork
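cfg = DQNConfig(
    batch_size=128,              # transitions sampled per gradient update
    buffer_size=100_000,         # replay buffer capacity
    learning_starts=1_000,       # steps collected before updates begin
    target_update_interval=250,  # interval between target-network syncs
    double_dqn=True,             # use Double DQN targets
    dueling=True,                # use the dueling Q-network head
    log_interval=1_000,          # logging frequency
    save_interval=10_000,        # checkpoint frequency
)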
runner = OffPolicyRunner(
    env=env,
    cfg=cfg,
    algorithm="dqn",
    q_network_class=MLPQNetwork,
    log_dir="./logs/dqn_cartpole",
    save_dir="./checkpoints/dqn_cartpole",
)
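The replay step the runner handles can be pictured as a fixed-capacity ring buffer: `buffer_size` bounds how many transitions are kept, and each update samples `batch_size` of them uniformly. The class below is a minimal, hypothetical sketch of that idea in plain Python; it is not ApexRL's replay implementation, which may store tensors and handle the vectorized environment differently.

```python
import random
from dataclasses import dataclass
from typing import List


@dataclass
class Transition:
    obs: List[float]
    action: int
    reward: float
    next_obs: List[float]
    done: bool


class ReplayBuffer:
    """Hypothetical fixed-capacity FIFO replay buffer with uniform sampling."""

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.storage: List[Transition] = []
        self.pos = 0  # next slot to overwrite once the buffer is full
        self.rng = random.Random(seed)

    def add(self, t: Transition) -> None:
        if len(self.storage) < self.capacity:
            self.storage.append(t)
        else:
            # Buffer is full: overwrite the oldest transition.
            self.storage[self.pos] = t
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size: int) -> List[Transition]:
        # Uniform sampling with replacement.
        return [self.rng.choice(self.storage) for _ in range(batch_size)]

    def __len__(self) -> int:
        return len(self.storage)


# CartPole observations have four components.
buffer = ReplayBuffer(capacity=100_000)
for step in range(3):
    buffer.add(Transition([0.0] * 4, step % 2, 1.0, [0.0] * 4, False))
batch = buffer.sample(batch_size=2)
print(len(buffer), len(batch))
```

Overwriting the oldest entry keeps memory bounded while letting recent experience gradually replace stale data.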
Train
runner.learn(total_timesteps=50_000)
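`learning_starts` and `target_update_interval` control when learning happens during `runner.learn`. A common DQN schedule, assumed here rather than taken from ApexRL's source, works like this: collect transitions without gradient updates until `learning_starts` steps have passed, then update every `train_freq` steps, and copy the online network into the target network every `target_update_interval` steps. The hypothetical helper below only counts those events. It treats each loop iteration as one environment step and ignores the fact that `GymVecEnv` runs two environments, which an implementation may count as two timesteps per vectorized step.

```python
from dataclasses import dataclass


@dataclass
class ScheduleCounts:
    warmup_steps: int
    gradient_updates: int
    target_syncs: int


def count_schedule(total_timesteps: int, learning_starts: int,
                   target_update_interval: int, train_freq: int = 1) -> ScheduleCounts:
    """Count warmup steps, gradient updates, and target syncs for a simple DQN schedule.

    Hypothetical rules: one iteration is one environment step; after `learning_starts`
    steps, an update runs every `train_freq` steps and the target network is synced
    every `target_update_interval` steps.
    """
    warmup = updates = syncs = 0
    for step in range(1, total_timesteps + 1):
        if step <= learning_starts:
            warmup += 1  # replay is still filling; no gradient updates yet
            continue
        if step % train_freq == 0:
            updates += 1
        if step % target_update_interval == 0:
            syncs += 1
    return ScheduleCounts(warmup, updates, syncs)


# The settings used in this tutorial.
print(count_schedule(total_timesteps=50_000, learning_starts=1_000,
                     target_update_interval=250))
```

Under these assumptions, the tutorial settings give 1,000 warmup steps, 49,000 updates, and 196 target syncs.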
Evaluate and Save
stats = runner.eval(num_episodes=10)
print(f"Mean reward: {stats['eval/mean_reward']:.2f}")
runner.save_checkpoint("dqn_cartpole_final.pt")
env.close()
Complete Example
import gymnasium as gym
from apexrl.agent.off_policy_runner import OffPolicyRunner
from apexrl.algorithms.dqn import DQNConfig
from apexrl.envs.gym_wrapper import GymVecEnv
from apexrl.models import MLPQNetwork
def make_env():
    return gym.make("CartPole-v1")
env = GymVecEnv([make_env for _ in range(2)], device="cpu")
cfg = DQNConfig(
    batch_size=128,
    buffer_size=100_000,
    learning_starts=1_000,
    target_update_interval=250,
    double_dqn=True,
    dueling=True,
)
runner = OffPolicyRunner(
    env=env,
    cfg=cfg,
    algorithm="dqn",
    q_network_class=MLPQNetwork,
    log_dir="./logs/dqn_cartpole",
)
runner.learn(total_timesteps=50_000)
print(runner.eval(num_episodes=10))
runner.save_checkpoint("dqn_cartpole_final.pt")
env.close()
Notes
OffPolicyRunner is the preferred training entrypoint for DQN.
double_dqn=True is enabled by default and should usually stay on.
Set dueling=True to switch MLPQNetwork to the dueling architecture.
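Both flags change how Q-values are computed. With `double_dqn=True`, the standard Double DQN rule selects the next action with the online network but evaluates it with the target network, which reduces the overestimation of the plain max target. With `dueling=True`, a dueling head outputs a state value V(s) and advantages A(s, a) and combines them as Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'). The plain-Python sketch below shows these textbook formulas for a single transition; the function names and the discount `gamma` are illustrative assumptions, not `DQNConfig` fields or `MLPQNetwork` internals.

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    # Index of the largest entry (the first one wins ties).
    return max(range(len(values)), key=lambda i: values[i])


def dqn_target(reward: float, done: bool, next_q_target: Sequence[float],
               gamma: float = 0.99) -> float:
    """Vanilla DQN target: r + gamma * max_a Q_target(s', a)."""
    bootstrap = 0.0 if done else max(next_q_target)
    return reward + gamma * bootstrap


def double_dqn_target(reward: float, done: bool, next_q_online: Sequence[float],
                      next_q_target: Sequence[float], gamma: float = 0.99) -> float:
    """Double DQN target: the online net picks a*, the target net scores it."""
    if done:
        return reward  # no bootstrapping past a terminal state
    best_action = argmax(next_q_online)
    return reward + gamma * next_q_target[best_action]


def dueling_q(value: float, advantages: Sequence[float]) -> List[float]:
    """Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a')."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


# The online and target networks disagree about the best next action.
q_online_next = [1.0, 2.0]
q_target_next = [3.0, 0.5]
print(dqn_target(1.0, False, q_target_next, gamma=0.5))  # 2.5
print(double_dqn_target(1.0, False, q_online_next, q_target_next, gamma=0.5))  # 1.25
print(dueling_q(10.0, [1.0, 3.0]))  # [9.0, 11.0]
```

In the example, the vanilla target bootstraps from the target network's own maximum (3.0), while Double DQN uses the target network's score for the action the online network prefers (0.5), giving a smaller target.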
Next Steps
Read Train PPO for the on-policy training flow
Read Train SAC for continuous-control off-policy training
Read Algorithms for DQN-specific options
Read Runners for runner API details