popgym.baselines.ray_models.ray_framestack
Module Contents
Classes
Framestack: Concatenate sequential frames along the feature dimension, pushing them through an MLP.
- class popgym.baselines.ray_models.ray_framestack.Framestack(obs_space: gymnasium.spaces.Space, action_space: gymnasium.spaces.Space, num_outputs: int, model_config: ray.rllib.utils.typing.ModelConfigDict, name: str, **custom_model_kwargs)
Bases: popgym.baselines.ray_models.base_model.BaseModel
Concatenate sequential frames along the feature dimension, pushing them through an MLP. From:
@article{mnih_human-level_2015,
  title = {Human-level control through deep reinforcement learning},
  volume = {518},
  number = {7540},
  journal = {Nature},
  author = {Mnih, Volodymyr and Kavukcuoglu, Koray and Silver, David and Rusu, Andrei A and Veness, Joel and Bellemare, Marc G and Graves, Alex and Riedmiller, Martin and Fidjeland, Andreas K and Ostrovski, Georg and {others}},
  year = {2015},
  note = {Publisher: Nature Publishing Group},
  pages = {529--533},
}
- MODEL_CONFIG
- initial_state() → List[ray.rllib.utils.typing.TensorType]
Return the initial states for your memory model.
The shapes of the returned states should NOT include the batch dimension. RLlib prepends the batch dimension itself, sized according to the number of episodes/rollouts in the batch.
- Returns:
List of tensors denoting the recurrent state at t=0.
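To illustrate the "no batch dimension" convention, the following pure-Python sketch builds a framestack-style initial state. It assumes, without this being confirmed by popgym's source, that a framestack of horizon K stores the previous K - 1 feature frames as its recurrent state. Nested lists stand in for torch tensors, and `initial_state_sketch` and `add_batch_dim` are hypothetical helpers.

```python
from typing import List

Frame = List[float]  # one feature vector of length F


def initial_state_sketch(horizon: int, feature_size: int) -> List[List[Frame]]:
    """Hypothetical framestack initial state (assumption: K - 1 stored frames).

    Returns a list holding one state "tensor" of shape [horizon - 1, feature_size],
    filled with zeros. Note there is NO batch dimension.
    """
    zeros = [[0.0] * feature_size for _ in range(horizon - 1)]
    return [zeros]


def add_batch_dim(states: List[List[Frame]], batch_size: int) -> List[List[List[Frame]]]:
    """Imitate RLlib prepending a batch dimension: [K - 1, F] -> [B, K - 1, F]."""
    return [[[row[:] for row in s] for _ in range(batch_size)] for s in states]


states = initial_state_sketch(horizon=4, feature_size=5)
print(len(states), len(states[0]), len(states[0][0]))  # 1 3 5
batched = add_batch_dim(states, batch_size=2)
print(len(batched[0]), len(batched[0][0]))  # 2 3
```

Here `add_batch_dim` only mimics what RLlib does on the model's behalf; the model itself should return the unbatched shapes.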
- forward_memory(z: ray.rllib.utils.typing.TensorType, state: List[ray.rllib.utils.typing.TensorType], t_starts: ray.rllib.utils.typing.TensorType, seq_lens: ray.rllib.utils.typing.TensorType) → Tuple[ray.rllib.utils.typing.TensorType, List[ray.rllib.utils.typing.TensorType]]
Forward pass for your custom memory model.
- Args:
- z: Preprocessed features of shape [B, T, F], with padding along the time dimension.
- state: Recurrent states of shape [B, …].
- t_starts: Tensor of size [B] denoting the length of the rollout so far. Note that in some cases RLlib might chunk a long rollout into multiple forward passes; t_starts tracks the length across all of these forward passes.
- seq_lens: Tensor of size [B] denoting the number of non-padding elements in z along the time dimension. E.g. seq_lens == [100, 60] with z.shape[1] == 128 means that the first 100 time steps of the first batch element and the first 60 time steps of the second batch element are valid; the rest are padding.
- Returns:
- (output, state), where output has shape [B, T, D] and state has shape [B, …]. The padding must be present in the output, i.e. output keeps all T time steps of z, including the padded ones. The state must have exactly the same shape as the input state.
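The contract above (a padded output of shape [B, T, D] and a state whose shape is unchanged) can be illustrated with a minimal pure-Python sketch of frame stacking. `framestack_forward_sketch` is a hypothetical helper, not popgym's implementation. It assumes the state holds the previous K - 1 frames per batch element, uses nested lists instead of tensors, ignores t_starts, and omits the MLP that Framestack applies afterwards. Each output step concatenates the K most recent frames (oldest first) along the feature dimension, so D = K * F.

```python
from typing import List, Tuple

Frame = List[float]  # one feature vector of length F


def framestack_forward_sketch(
    z: List[List[Frame]],      # [B, T, F], padded along the time dimension
    state: List[List[Frame]],  # [B, K - 1, F], most recent frames from the previous chunk
    seq_lens: List[int],       # [B], number of valid (non-padding) steps per batch element
) -> Tuple[List[List[Frame]], List[List[Frame]]]:
    """Hypothetical frame-stacking forward pass (no MLP, t_starts ignored)."""
    outputs: List[List[Frame]] = []
    new_states: List[List[Frame]] = []
    for seq, prev, n_valid in zip(z, state, seq_lens):
        history = len(prev)  # K - 1 stored frames
        frames = prev + seq  # [K - 1 + T, F]: carried-over frames, then this chunk
        stacked = []
        for t in range(len(seq)):
            # The K most recent frames ending at step t, oldest first.
            window = frames[t : t + history + 1]
            # Concatenate along the feature dimension -> length K * F.
            stacked.append([x for frame in window for x in frame])
        outputs.append(stacked)  # all T steps kept, including padded ones
        # Keep the last K - 1 *valid* frames so padding never enters the next state.
        end = history + n_valid
        new_states.append(frames[end - history : end])
    return outputs, new_states


# Batch of two sequences, F = 1, K = 2; the second sequence has one padded step.
z = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [0.0]]]
state = [[[0.0]], [[9.0]]]
out, new_state = framestack_forward_sketch(z, state, seq_lens=[3, 2])
print(out[1])        # [[9.0, 4.0], [4.0, 5.0], [5.0, 0.0]]
print(new_state[1])  # [[5.0]]
```

Taking the new state from the last K - 1 valid frames (located via seq_lens) rather than the last K - 1 rows of z keeps padding from leaking into the next forward pass when a long rollout is chunked.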