Buffers
ApexRL provides efficient buffer implementations for storing and processing training data.
Overview
Available buffer types:
RolloutBuffer - On-policy data storage (PPO)
ReplayBuffer - Off-policy data storage (DQN, SAC)
DistillationBuffer - Policy distillation data (planned, not yet implemented)
RolloutBuffer
The RolloutBuffer stores trajectories collected during environment interaction for on-policy algorithms like PPO.
Key Features
Efficient tensor storage on GPU
Support for multi-dimensional observations
Support for structured TensorDict / nested dict observations
Support for both scalar discrete and multi-dimensional continuous actions
Generalized Advantage Estimation (GAE)
Privileged observations for asymmetric actor-critic
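This page does not describe how ApexRL stores nested dict observations internally. The sketch below is a hypothetical illustration of one way to pre-allocate and fill (num_steps, num_envs, *shape) tensors for a nested observation spec using plain PyTorch dicts. The helper names allocate_nested and write_nested, and the spec format (a dict mapping keys to shape tuples or nested dicts), are assumptions and not ApexRL API.

```python
import torch


def allocate_nested(spec, num_steps, num_envs, device="cpu"):
    """Hypothetical helper: pre-allocate zeroed storage for a nested observation spec.

    `spec` maps each key to either a shape tuple or another (nested) spec dict.
    """
    storage = {}
    for key, value in spec.items():
        if isinstance(value, dict):
            storage[key] = allocate_nested(value, num_steps, num_envs, device)
        else:
            storage[key] = torch.zeros(num_steps, num_envs, *value, device=device)
    return storage


def write_nested(storage, obs, step):
    """Hypothetical helper: copy one step of (possibly nested) observations in place."""
    for key, value in obs.items():
        if isinstance(value, dict):
            write_nested(storage[key], value, step)
        else:
            storage[key][step].copy_(value)


# Example: a flat proprioceptive branch plus a nested camera branch
spec = {"policy": (48,), "camera": {"depth": (1, 16, 16)}}
storage = allocate_nested(spec, num_steps=24, num_envs=8)
obs = {"policy": torch.ones(8, 48), "camera": {"depth": torch.ones(8, 1, 16, 16)}}
write_nested(storage, obs, step=0)
print(storage["camera"]["depth"].shape)  # torch.Size([24, 8, 1, 16, 16])
```

Allocating everything once up front and writing with in-place copy_ avoids per-step allocations, which is the same motivation as the flat tensor layout described under Memory Layout.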
Basic Usage
import torch
from apexrl.buffer.rollout_buffer import RolloutBuffer

buffer = RolloutBuffer(
    num_envs=4096,
    num_steps=24,
    obs_shape=(48,),
    action_shape=(12,),
    action_dtype=torch.float32,
    device="cuda",
    num_privileged_obs=0,
)

# Collect data. `env`, `actor`, `critic`, and the current `obs`
# (shape (num_envs, *obs_shape)) are assumed to be set up already.
for step in range(24):
    with torch.no_grad():
        actions, log_probs = actor.act(obs)
        values = critic.get_value(obs)
    next_obs, rewards, dones, extras = env.step(actions)
    buffer.add(
        observations=obs,
        privileged_observations=None,
        actions=actions,
        rewards=rewards,
        dones=dones.float(),
        values=values,
        log_probs=log_probs,
    )
    obs = next_obs

# Compute advantages, bootstrapping from the value of the final observation
with torch.no_grad():
    last_values = critic.get_value(obs)
buffer.compute_returns_and_advantages(
    last_values=last_values,
    gamma=0.99,
    gae_lambda=0.95,
)

# Get training data
data = buffer.get_all_data()
API Reference
Data Flow
The rollout data flow:
Environment Step → Store Transition → GAE Computation → Training
                          ↓
                   ┌──────────────┐
                   │ observations │
                   │ actions      │
                   │ rewards      │
                   │ dones        │
                   │ values       │
                   │ log_probs    │
                   └──────────────┘
                          ↓
                   ┌──────────────┐
                   │ advantages   │
                   │ returns      │
                   └──────────────┘
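Memory Layout
Stored tensors have shape (num_steps, num_envs, ...). With the configuration from Basic Usage (num_steps=24, num_envs=4096, obs_shape=(48,), action_shape=(12,)):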
# Observations: (num_steps, num_envs, *obs_shape)
self.observations # Shape: (24, 4096, 48)
# Continuous actions: (num_steps, num_envs, *action_shape)
self.actions # Shape: (24, 4096, 12)
# Scalars: (num_steps, num_envs)
self.rewards # Shape: (24, 4096)
self.dones # Shape: (24, 4096)
self.values # Shape: (24, 4096)
self.log_probs # Shape: (24, 4096)
self.advantages # Shape: (24, 4096)
self.returns # Shape: (24, 4096)
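PPO typically trains on these tensors after merging the first two dimensions into one batch of num_steps * num_envs transitions and splitting it into shuffled mini-batches. The format returned by get_all_data() is not documented on this page, so this sketch assumes a plain dict of tensors laid out as above; iterate_minibatches and the dict keys are hypothetical, not part of ApexRL.

```python
import torch


def iterate_minibatches(data, num_minibatches):
    """Hypothetical helper: flatten (num_steps, num_envs, ...) tensors and yield shuffled mini-batches."""
    # Merge the step and env dims: (T, N, *rest) -> (T * N, *rest)
    flat = {key: value.reshape(-1, *value.shape[2:]) for key, value in data.items()}
    first = next(iter(flat.values()))
    total = first.shape[0]
    if total % num_minibatches != 0:
        raise ValueError(f"{total} transitions do not split evenly into {num_minibatches} mini-batches")
    batch_size = total // num_minibatches
    # One random permutation per epoch; every transition appears exactly once
    perm = torch.randperm(total, device=first.device)
    for start in range(0, total, batch_size):
        idx = perm[start:start + batch_size]
        yield {key: value[idx] for key, value in flat.items()}


num_steps, num_envs = 24, 8
data = {
    "observations": torch.randn(num_steps, num_envs, 48),
    "advantages": torch.randn(num_steps, num_envs),
}
batches = list(iterate_minibatches(data, num_minibatches=4))
print(len(batches), batches[0]["observations"].shape)  # 4 torch.Size([48, 48])
```

The divisibility check reflects the Batch Size advice under Best Practices: 24 * 8 = 192 transitions split into 4 mini-batches of 48.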
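GAE Computation
Generalized Advantage Estimation is computed backwards in time, starting from the last stored step:
\[
\hat{A}_t = \delta_t + \gamma \lambda (1 - d_t)\, \hat{A}_{t+1}, \qquad \hat{A}_{T} = 0,
\]
where \(\delta_t = r_t + \gamma (1 - d_t) V(s_{t+1}) - V(s_t)\) is the TD error, \(d_t\) is the stored done flag, \(T\) is num_steps, and \(V(s_T)\) is last_values. Returns are then \(R_t = \hat{A}_t + V(s_t)\).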
def compute_returns_and_advantages(self, last_values, gamma=0.99, gae_lambda=0.95):
    # last_values: V(s_T) for the observation after the final step, shape (num_envs,)
    advantages = torch.zeros_like(self.rewards)
    last_gae = torch.zeros(self.num_envs, device=self.device)
    # Walk backwards so each step can reuse the advantage of the step after it
    for t in reversed(range(self.num_steps)):
        if t == self.num_steps - 1:
            next_values = last_values
        else:
            next_values = self.values[t + 1]
        # (1 - dones[t]) stops bootstrapping across episode boundaries
        delta = self.rewards[t] + gamma * next_values * (1 - self.dones[t]) - self.values[t]
        last_gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * last_gae
        advantages[t] = last_gae
    self.advantages = advantages
    # Returns serve as the value targets for the critic
    self.returns = advantages + self.values
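To check the recursion by hand, here is a standalone version of the same loop written as a plain function over (num_steps, num_envs) tensors (the name gae is illustrative only), run on a two-step, one-environment rollout with gamma = lambda = 0.5 so the arithmetic stays exact. By hand: at t = 1, delta = 1 + 0.5 * 2 - 1 = 1, so A_1 = 1; at t = 0, delta = 1 + 0.5 * 1 - 1 = 0.5, so A_0 = 0.5 + 0.25 * 1 = 0.75.

```python
import torch


def gae(rewards, values, dones, last_values, gamma=0.99, gae_lambda=0.95):
    """Backward GAE over (num_steps, num_envs) tensors; returns (advantages, returns)."""
    num_steps = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(last_values)
    for t in reversed(range(num_steps)):
        next_values = last_values if t == num_steps - 1 else values[t + 1]
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_values * not_done - values[t]
        last_gae = delta + gamma * gae_lambda * not_done * last_gae
        advantages[t] = last_gae
    return advantages, advantages + values


rewards = torch.tensor([[1.0], [1.0]])
values = torch.tensor([[1.0], [1.0]])
last_values = torch.tensor([2.0])

adv, ret = gae(rewards, values, torch.zeros(2, 1), last_values, gamma=0.5, gae_lambda=0.5)
print(adv.squeeze(-1).tolist(), ret.squeeze(-1).tolist())  # [0.75, 1.0] [1.75, 2.0]

# An episode ending at t=0 cuts the bootstrap from t=1
adv_done, _ = gae(rewards, values, torch.tensor([[1.0], [0.0]]), last_values, gamma=0.5, gae_lambda=0.5)
print(adv_done.squeeze(-1).tolist())  # [0.0, 1.0]
```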
Timeout bootstrapping is handled before transitions are written into the buffer,
so dones in the stored rollout reflect the bootstrap mask used by PPO.
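This page does not show how ApexRL performs that timeout step. One common approach, sketched here purely as an assumption, folds a value estimate for the truncated state into the reward while leaving the done flag set, so the (1 - d_t) mask in the GAE recursion still stops bootstrapping from the post-reset observation. The helper bootstrap_timeouts and its bootstrap_values argument are hypothetical.

```python
import torch


def bootstrap_timeouts(rewards, bootstrap_values, time_outs, gamma=0.99):
    """Hypothetical helper: add gamma * V(s_trunc) to rewards of timed-out transitions.

    bootstrap_values should estimate the value of the truncated state. Vectorized
    envs often auto-reset, so some implementations approximate it with V(s_t).
    """
    return rewards + gamma * bootstrap_values * time_outs.float()


rewards = torch.tensor([1.0, 1.0])
bootstrap_values = torch.tensor([2.0, 2.0])
time_outs = torch.tensor([False, True])  # env 1 hit its time limit, env 0 did not

shaped = bootstrap_timeouts(rewards, bootstrap_values, time_outs, gamma=0.5).tolist()
print(shaped)  # [1.0, 2.0]
```

With the done flag still set for env 1, its TD error becomes r + gamma * V(s_trunc) - V(s_t), which is the value a non-terminal step would have had.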
ReplayBuffer
For off-policy algorithms such as DQN and SAC:
from apexrl.buffer.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(
    capacity=1_000_000,
    obs_shape=(4,),
    action_shape=(),
    device="cuda",
)

# Store transition
buffer.add(obs, action, reward, next_obs, done)

# Sample batch
batch = buffer.sample(batch_size=256)
For discrete-action DQN, action_shape=() stores scalar action indices.
For SAC and other continuous-control off-policy algorithms, set
action_shape to the vector action shape so replay stores full actions.
Replay can also store a separate critic observation branch; the current SAC implementation uses it to keep actor observations and privileged critic observations separate inside replay.
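To make the replay description concrete, the following is a hypothetical fixed-capacity ring buffer in plain PyTorch, not ApexRL's ReplayBuffer. It assumes batched add() calls from vectorized environments, an explicit action_dtype (for example torch.long for scalar DQN indices), and an optional critic observation branch mirroring the SAC use described above. The class name and every argument beyond obs, action, reward, next_obs, and done are assumptions.

```python
import torch


class RingReplayBuffer:
    """Hypothetical FIFO replay buffer; the oldest transitions are overwritten when full."""

    def __init__(self, capacity, obs_shape, action_shape, action_dtype=torch.float32,
                 critic_obs_shape=None, device="cpu"):
        self.capacity = capacity
        self.device = device
        self.obs = torch.zeros(capacity, *obs_shape, device=device)
        self.next_obs = torch.zeros(capacity, *obs_shape, device=device)
        # action_shape=() gives one scalar per transition (e.g. a DQN action index)
        self.actions = torch.zeros(capacity, *action_shape, dtype=action_dtype, device=device)
        self.rewards = torch.zeros(capacity, device=device)
        self.dones = torch.zeros(capacity, device=device)
        # Optional separate branch for privileged critic observations
        self.critic_obs = self.next_critic_obs = None
        if critic_obs_shape is not None:
            self.critic_obs = torch.zeros(capacity, *critic_obs_shape, device=device)
            self.next_critic_obs = torch.zeros(capacity, *critic_obs_shape, device=device)
        self.ptr = 0   # next write position
        self.size = 0  # number of valid transitions

    def add(self, obs, action, reward, next_obs, done, critic_obs=None, next_critic_obs=None):
        """Insert N transitions (leading dim N <= capacity), wrapping around at the end."""
        n = obs.shape[0]
        idx = (self.ptr + torch.arange(n, device=self.device)) % self.capacity
        self.obs[idx] = obs
        self.actions[idx] = action.to(self.actions.dtype)
        self.rewards[idx] = reward.float()
        self.next_obs[idx] = next_obs
        self.dones[idx] = done.float()
        if self.critic_obs is not None:
            self.critic_obs[idx] = critic_obs
            self.next_critic_obs[idx] = next_critic_obs
        self.ptr = (self.ptr + n) % self.capacity
        self.size = min(self.size + n, self.capacity)

    def sample(self, batch_size):
        """Sample transitions uniformly, with replacement, from the filled part."""
        idx = torch.randint(0, self.size, (batch_size,), device=self.device)
        batch = {
            "obs": self.obs[idx],
            "actions": self.actions[idx],
            "rewards": self.rewards[idx],
            "next_obs": self.next_obs[idx],
            "dones": self.dones[idx],
        }
        if self.critic_obs is not None:
            batch["critic_obs"] = self.critic_obs[idx]
            batch["next_critic_obs"] = self.next_critic_obs[idx]
        return batch


buf = RingReplayBuffer(capacity=5, obs_shape=(4,), action_shape=(), action_dtype=torch.long)
for fill in (0.0, 2.0):  # two batches of 3 transitions; the second wraps around
    buf.add(torch.full((3, 4), fill), torch.tensor([0, 1, 2]), torch.ones(3),
            torch.full((3, 4), fill), torch.zeros(3, dtype=torch.bool))
batch = buf.sample(batch_size=8)
print(buf.size, buf.ptr, batch["obs"].shape, batch["actions"].dtype)
# 5 1 torch.Size([8, 4]) torch.int64
```

Computing write indices modulo capacity lets a whole batch from many environments be written in one indexed assignment, even when it wraps past the end of the storage.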
API Reference
DistillationBuffer
For policy distillation and imitation learning.
Note
DistillationBuffer is still planned and is not implemented in the
runtime package yet. This section documents the intended scope only.
API Reference
See apexrl.buffer.distillation_buffer module for the current module status.
Best Practices
Pre-allocate: Buffers pre-allocate memory for efficiency
Device Placement: Keep buffers on the same device as models
Clear Buffers: Call clear() between rollouts
Batch Size: Ensure the mini-batch size evenly divides the total number of transitions (num_steps * num_envs)
GAE Lambda: Typical values are 0.9-0.95
See Also
apexrl.buffer package - Full API reference
Algorithms - Algorithm implementations