Custom Network Architectures¶
This tutorial shows how to implement custom actors and critics for the current TensorDict-capable ApexRL stack.
Overview¶
Base classes:
ContinuousActorfor continuous policiesDiscreteActorfor discrete policiesCriticfor value networksContinuousQNetworkfor SAC-styleQ(s, a)critics
What Your Network Receives¶
The current repository version supports structured observations.
If your environment returns:
{
"obs": {
"image": image,
"vector": vector,
},
"privileged_obs": {
"state": state,
"context": context,
},
}
then:
PPO actor receives
{"image": ..., "vector": ...}PPO asymmetric critic receives
{"state": ..., "context": ...}SAC actor receives
obsSAC critics receive
privileged_obswhen present, otherwise the actor observation
Continuous Actor Example¶
import torch
import torch.nn as nn
from apexrl.models.base import ContinuousActor
class MultiModalContinuousActor(ContinuousActor):
def __init__(self, obs_space, action_space, cfg=None):
super().__init__(obs_space, action_space, cfg)
cfg = cfg or {}
image_shape = obs_space["image"].shape
vector_dim = obs_space["vector"].shape[0]
hidden_dim = cfg.get("hidden_dim", 256)
self.image_encoder = nn.Sequential(
nn.Conv2d(image_shape[0], 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(16, 32, 3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten(),
)
with torch.no_grad():
dummy = torch.zeros(1, *image_shape)
image_dim = self.image_encoder(dummy).shape[-1]
self.vector_encoder = nn.Sequential(
nn.Linear(vector_dim, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
)
self.backbone = nn.Sequential(
nn.Linear(image_dim + 64, hidden_dim),
nn.ReLU(),
)
self.mean_head = nn.Linear(hidden_dim, self.action_dim)
self.log_std = nn.Parameter(torch.zeros(self.action_dim))
def forward(self, obs):
image_feat = self.image_encoder(obs["image"])
vector_feat = self.vector_encoder(obs["vector"])
fused = torch.cat([image_feat, vector_feat], dim=-1)
return self.mean_head(self.backbone(fused))
def get_action_dist(self, obs):
mean = self.forward(obs)
std = torch.exp(self.log_std).expand_as(mean)
return torch.distributions.Normal(mean, std)
Discrete Actor Example¶
from apexrl.models.base import DiscreteActor
class MultiModalDiscreteActor(DiscreteActor):
def __init__(self, obs_space, action_space, cfg=None):
super().__init__(obs_space, action_space, cfg)
image_shape = obs_space["image"].shape
vector_dim = obs_space["vector"].shape[0]
self.image_encoder = nn.Sequential(
nn.Conv2d(image_shape[0], 16, 3, stride=2, padding=1),
nn.ReLU(),
nn.Flatten(),
)
with torch.no_grad():
dummy = torch.zeros(1, *image_shape)
image_dim = self.image_encoder(dummy).shape[-1]
self.head = nn.Sequential(
nn.Linear(image_dim + vector_dim, 256),
nn.ReLU(),
nn.Linear(256, self.num_actions),
)
def forward(self, obs):
image_feat = self.image_encoder(obs["image"])
return self.head(torch.cat([image_feat, obs["vector"]], dim=-1))
def get_action_dist(self, obs):
logits = self.forward(obs)
return torch.distributions.Categorical(logits=logits)
Critic Example¶
For a privileged critic:
import torch
import torch.nn as nn
from apexrl.models.base import Critic
class PrivilegedCritic(Critic):
def __init__(self, obs_space, cfg=None):
super().__init__(obs_space, cfg)
state_dim = obs_space["state"].shape[0]
context_dim = obs_space["context"].shape[0]
self.network = nn.Sequential(
nn.Linear(state_dim + context_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, 1),
)
def forward(self, obs):
x = torch.cat([obs["state"], obs["context"]], dim=-1)
return self.network(x).squeeze(-1)
def get_value(self, obs):
return self.forward(obs)
Using Custom Networks¶
from apexrl.agent.on_policy_runner import OnPolicyRunner
from apexrl.algorithms.ppo import PPOConfig
runner = OnPolicyRunner(
env=env,
cfg=PPOConfig(use_asymmetric=True, device="cpu"),
actor_class=MultiModalDiscreteActor,
critic_class=PrivilegedCritic,
actor_cfg={"hidden_dim": 256},
log_dir="./logs",
)
Recurrent PPO Networks¶
RecurrentPPO uses the same actor_class / critic_class injection
pattern. Custom recurrent networks should follow the RecurrentPPO actor and
critic interfaces. ApexRL includes GRUActor, GRUDiscreteActor and
GRUCritic as reference implementations.
from apexrl.agent import OnPolicyRunner
from apexrl.algorithms.ppo import RecurrentPPOConfig
from apexrl.models import GRUCritic, GRUDiscreteActor
runner = OnPolicyRunner(
env=env,
algorithm="recurrent_ppo",
cfg=RecurrentPPOConfig(
num_steps=64,
sequence_length=16,
recurrent_minibatch_size=256,
),
actor_class=GRUDiscreteActor,
critic_class=GRUCritic,
actor_cfg={"hidden_dims": [128], "rnn_hidden_size": 128},
critic_cfg={"hidden_dims": [128], "rnn_hidden_size": 128},
)
Multi-Agent Custom Networks¶
MAPPO, IPPO and HAPPO use the same custom network pattern as PPO: pass ApexRL actor and critic classes, plus optional configuration dictionaries. The network classes can contain arbitrary PyTorch modules internally while inheriting ApexRL’s actor or critic base interfaces.
from apexrl.models.base import Critic, DiscreteActor
from apexrl.multiagent import HAPPO, HAPPOConfig, IPPO, IPPOConfig, MAPPO, MAPPOConfig
class EntityDiscreteActor(DiscreteActor):
def __init__(self, obs_space, action_space, cfg=None):
super().__init__(obs_space, action_space, cfg)
# Build attention, graph, CNN or MLP blocks here.
def forward(self, obs):
...
def get_action_dist(self, obs):
logits = self.forward(obs)
return torch.distributions.Categorical(logits=logits)
class EntityCritic(Critic):
def __init__(self, obs_space, cfg=None):
super().__init__(obs_space, cfg)
# Build the value network here.
def forward(self, obs):
...
def get_value(self, obs):
return self.forward(obs)
mappo_agent = MAPPO(
env=multiagent_env,
cfg=MAPPOConfig(
centralized_critic=True,
share_actor=True,
share_critic=True,
),
actor_class=EntityDiscreteActor,
critic_class=EntityCritic,
actor_cfg={"hidden_dim": 256},
critic_cfg={"hidden_dim": 512},
)
ippo_agent = IPPO(
env=multiagent_env,
cfg=IPPOConfig(share_actor=True, share_critic=True),
actor_class=EntityDiscreteActor,
critic_class=EntityCritic,
)
happo_agent = HAPPO(
env=multiagent_env,
cfg=HAPPOConfig(centralized_critic=True, share_actor=False),
actor_class=EntityDiscreteActor,
critic_class=EntityCritic,
)
For MAPPO with centralized_critic=True, the critic receives
env.state_space and env.get_state() outputs. For IPPO, or MAPPO with
centralized_critic=False, each critic receives that agent’s local
observation space and local observations. Actors always receive per-agent local
observations.
Parameter sharing is controlled by the multi-agent config:
share_actor=Truecreates one actor instance and reuses it for all agents. This requires identical observation and action spaces across agents.HAPPO uses
share_actor=Falseso each agent has a separate policy for its sequential update.share_critic=Truecreates one critic instance and reuses it for all agents. With decentralized critics, this requires identical observation spaces; with centralized critics, all critics consume the shared state space.Set either flag to
Falsewhen agents need separate model parameters.
You can also pass already constructed ApexRL actor/critic objects through the
models dictionary. This is useful for heterogeneous agents or for manually
sharing selected modules:
models = {
"agent_0": {"policy": actor_0, "value": critic_0},
"agent_1": {"policy": actor_1, "value": critic_1},
}
agent = MAPPO(
possible_agents=["agent_0", "agent_1"],
observation_spaces=observation_spaces,
action_spaces=action_spaces,
state_space=state_space,
models=models,
cfg=MAPPOConfig(centralized_critic=True),
)
For more complex multi-agent networks, prefer representing entities with fixed
structured observations, such as spaces.Dict entries, padded tensors and
masks. Attention, DeepSets, GNN and transformer-style encoders can then be
implemented inside the actor or critic without changing MAPPO/IPPO/HAPPO.
Best Practices¶
Keep the observation structure explicit instead of flattening everything immediately.
Let the environment expose actor and critic groups as
obsandprivileged_obs.Use
DiscreteActorfor discrete PPO andContinuousActorfor continuous PPO.For SAC, keep custom critics in
ContinuousQNetworkform, i.e.forward(obs, actions).Prefer explicit branch encoders for image + vector inputs instead of a single flat MLP.
For MAPPO/IPPO/HAPPO, use local per-agent observations in actors and choose centralized or decentralized critic observations through the algorithm config.