popgym.baselines.ray_models.base_model

Module Contents

Classes

BaseModel

The base memory model that all other memory models should inherit from.

Functions

weight_init(m)

popgym.baselines.ray_models.base_model.weight_init(m)
class popgym.baselines.ray_models.base_model.BaseModel(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: ray.rllib.utils.typing.ModelConfigDict, name: str, **custom_model_kwargs)

Bases: ray.rllib.models.torch.torch_modelv2.TorchModelV2, torch.nn.Module

The base memory model that all other memory models should inherit from.

You will need to override initial_state and forward_memory, and set your model defaults using MODEL_CONFIG.

BASE_CONFIG
MODEL_CONFIG: Dict[str, Any]
setup_embedding(max_seq_len: int)

Set up the positional embedding module.

Args:

max_seq_len (int): Maximum sequence length

abstract initial_state() -> List[ray.rllib.utils.typing.TensorType]

Return the initial states for your memory model.

The states returned here should NOT include the batch dimension. RLlib prepends the batch dimension, depending on the number of episodes/rollouts in the batch.

Returns:

List of tensors denoting the t=0 recurrent state
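
As an illustration of the no-batch-dimension rule, here is a minimal sketch of an initial_state for a model with a single hidden vector. HIDDEN_SIZE and prepend_batch are hypothetical names used only for this example; prepend_batch merely imitates what the caller is described as doing and is not RLlib code.

```python
from typing import List

import torch

HIDDEN_SIZE = 8  # hypothetical hidden size, chosen for illustration


def initial_state() -> List[torch.Tensor]:
    # One tensor per recurrent state, WITHOUT a leading batch dimension.
    return [torch.zeros(HIDDEN_SIZE)]


def prepend_batch(states: List[torch.Tensor], batch_size: int) -> List[torch.Tensor]:
    # Hypothetical stand-in for the caller: add a batch dimension in front
    # of every state tensor.
    return [s.unsqueeze(0).expand(batch_size, *s.shape) for s in states]


states = initial_state()
batched = prepend_batch(states, batch_size=4)
print(tuple(states[0].shape), tuple(batched[0].shape))
```

The unbatched state has shape [HIDDEN_SIZE]; after batching it becomes [B, HIDDEN_SIZE], which is the [B, …] layout that forward_memory receives.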

get_initial_state() -> List[ray.rllib.utils.typing.TensorType]
value_function() -> ray.rllib.utils.typing.TensorType
preprocess_obs(obs, seq_lens)

Preprocess the incoming observation of shape [?, Feature] into [Batch, Time, Feature]
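
The reshape described above can be sketched as follows. This assumes the flat input holds B padded sequences of equal length T laid out back to back, so that T can be recovered as rows // B; to_batch_time is a hypothetical helper, not the method's actual implementation.

```python
import torch


def to_batch_time(flat_obs: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # seq_lens has one entry per sequence, so its length is the batch size B.
    batch_size = seq_lens.shape[0]
    # Assumption: every sequence is padded to the same length T.
    max_time = flat_obs.shape[0] // batch_size
    # [B * T, Feature] -> [B, T, Feature]
    return flat_obs.reshape(batch_size, max_time, *flat_obs.shape[1:])


flat = torch.arange(24, dtype=torch.float32).reshape(12, 2)  # [B * T, F] = [12, 2]
seq_lens = torch.tensor([6, 4])                              # B = 2
obs = to_batch_time(flat, seq_lens)
print(tuple(obs.shape))
```

Note that seq_lens is only used here to find the batch size; the padded length comes from the input itself.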

abstract forward_memory(z: ray.rllib.utils.typing.TensorType, state: List[ray.rllib.utils.typing.TensorType], t_starts: ray.rllib.utils.typing.TensorType, seq_lens: ray.rllib.utils.typing.TensorType) -> Tuple[ray.rllib.utils.typing.TensorType, List[ray.rllib.utils.typing.TensorType]]

Forward for your custom memory model.

Args:

z: Preprocessed features of shape [B, T, F], with padding along the time dimension

state: Recurrent states of shape [B, …]

t_starts: Tensor of size [B] denoting the length of the rollout so far. Note that in some cases RLlib might chunk a long rollout into multiple forward passes. This tracks the length throughout all forward passes.

seq_lens: Tensor of size [B] denoting the number of non-padding elements in z. E.g. seq_lens == [100, 60], z.shape[1] == 128 means the first 100 time steps of the first batch element are valid and the first 60 time steps of the second batch element are valid. The rest are padding.

Returns:

(output, state), where output is [B, T, D] and state is [B, …]. Note that the padding must be present in the output. The state must be exactly the same shape as the input state.
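
To make the contract concrete, here is a minimal sketch of a forward_memory built on a single-layer GRU. GRUMemory is a hypothetical standalone class (it does not inherit from BaseModel, so that the example runs without RLlib), and it assumes every entry of seq_lens is at least 1. The output keeps its padded time steps, and the returned state is read at the last valid step of each sequence so that padding cannot leak into it.

```python
from typing import List, Tuple

import torch
from torch import nn


class GRUMemory(nn.Module):
    """Hypothetical example; not part of popgym."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)

    def forward_memory(
        self,
        z: torch.Tensor,            # [B, T, F], padded along T
        state: List[torch.Tensor],  # [[B, H]]
        t_starts: torch.Tensor,     # [B], unused by this simple model
        seq_lens: torch.Tensor,     # [B], number of valid steps per row
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        h0 = state[0].unsqueeze(0)       # nn.GRU expects [1, B, H]
        output, _ = self.gru(z, h0)      # [B, T, H], padding steps included
        # For a single-layer GRU the per-step output is the hidden state,
        # so the state after the last valid step is output[b, seq_lens[b] - 1].
        last = (seq_lens.long() - 1).clamp(min=0)
        rows = torch.arange(z.shape[0])
        return output, [output[rows, last]]  # state is [B, H] again


torch.manual_seed(0)
model = GRUMemory(input_size=3, hidden_size=5)
z = torch.randn(2, 8, 3)
state = [torch.zeros(2, 5)]
t_starts = torch.zeros(2, dtype=torch.long)
seq_lens = torch.tensor([8, 4])
output, new_state = model.forward_memory(z, state, t_starts, seq_lens)
print(tuple(output.shape), tuple(new_state[0].shape))
```

Reading the state at index seq_lens - 1, rather than taking the GRU's final hidden state, is a design choice of this sketch: the final hidden state would also have processed the padded steps.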

apply_embedding(x, t_starts)

Applies temporal embeddings to x

Args:

x: Tensor of shape [B, T, F]

t_starts: Tensor of shape [B] denoting the time offset of the first element in x

Returns:

Tensor of shape [B, T, F] with the temporal embeddings
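
One way such a temporal embedding could work is sketched below. The source does not say how the embedding is combined with x, so this example assumes a learned lookup table that is added to the features; apply_embedding_sketch and table are hypothetical names, and clamping indices to the table size is also an assumption.

```python
import torch
from torch import nn


def apply_embedding_sketch(
    x: torch.Tensor,         # [B, T, F]
    t_starts: torch.Tensor,  # [B], time offset of the first element in x
    table: nn.Embedding,     # hypothetical table of shape [max_seq_len, F]
) -> torch.Tensor:
    num_steps = x.shape[1]
    # Absolute time index of every element: offset + position within the chunk.
    steps = t_starts.long().unsqueeze(1) + torch.arange(num_steps).unsqueeze(0)  # [B, T]
    # Assumption: indices beyond the table reuse the last embedding.
    steps = steps.clamp(max=table.num_embeddings - 1)
    # Assumption: the embedding is added, which keeps the [B, T, F] shape.
    return x + table(steps)


table = nn.Embedding(16, 4)  # max_seq_len = 16, F = 4
x = torch.zeros(2, 3, 4)
t_starts = torch.tensor([0, 5])
embedded = apply_embedding_sketch(x, t_starts, table)
print(tuple(embedded.shape))
```

Because the second sequence starts at offset 5, its first element receives the embedding for time step 5 rather than step 0, which is what keeps the embedding consistent when a long rollout is split across several forward passes.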

forward(input_dict: Dict[str, ray.rllib.utils.typing.TensorType], state: List[ray.rllib.utils.typing.TensorType], seq_lens: ray.rllib.utils.typing.TensorType) -> Tuple[ray.rllib.utils.typing.TensorType, List[ray.rllib.utils.typing.TensorType]]