popgym.baselines.models.aggregations

Module Contents

Classes

Aggregation

Aggregates (x_k … x_t, s_k) into s_t

SumAggregation

Aggregates (x_k … x_t, s_k) into s_t

MaxAggregation

Aggregates (x_k … x_t, s_k) into s_t

Functions

get_aggregator(name) → torch.nn.Module

popgym.baselines.models.aggregations.get_aggregator(name: str) → torch.nn.Module
class popgym.baselines.models.aggregations.Aggregation(*args, **kwargs)

Bases: torch.nn.Module

Abstract base class. Aggregates the inputs x_k … x_t and the previous state s_k into the new state s_t. Subclasses implement forward.

abstract forward(x: torch.Tensor, memory: torch.Tensor) → torch.Tensor
class popgym.baselines.models.aggregations.SumAggregation(*args, **kwargs)

Bases: Aggregation

Aggregates (x_k … x_t, s_k) into s_t by summation

forward(x: torch.Tensor, memory: torch.Tensor) → torch.Tensor
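
To make the (x_k … x_t, s_k) → s_t contract concrete, here is a minimal sketch of a sum-style aggregation. It is not popgym's implementation: the class name RunningSum, the tensor layout ([batch, time, features] for x, [batch, 1, features] for memory), and the choice to return every intermediate state are all assumptions made for illustration.

```python
import torch
from torch import nn


class RunningSum(nn.Module):
    """Hypothetical stand-in for a sum aggregation (not popgym's class)."""

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # Assumed shapes: x is [batch, time, features],
        # memory (the carried-in state s_k) is [batch, 1, features].
        # Cumulative sum over the time axis, offset by s_k, so the last
        # time step holds s_t = s_k + x_k + ... + x_t.
        return memory + torch.cumsum(x, dim=1)


x = torch.tensor([[[1.0], [2.0], [3.0]]])  # x_k ... x_t for one sequence
memory = torch.tensor([[[10.0]]])          # s_k
states = RunningSum()(x, memory)           # one state per time step
```

Under these assumptions, the final slice states[:, -1:] can be passed as memory to the next call, which is what lets the aggregation resume across sequence chunks.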
class popgym.baselines.models.aggregations.MaxAggregation(*args, **kwargs)

Bases: Aggregation

Aggregates (x_k … x_t, s_k) into s_t by taking the maximum

forward(x: torch.Tensor, memory: torch.Tensor) → torch.Tensor
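
A max-style aggregation can be sketched the same way. This is an illustration under stated assumptions, not popgym's code: RunningMax is a hypothetical name, x is assumed to be [batch, time, features], and memory is assumed to be [batch, 1, features].

```python
import torch
from torch import nn


class RunningMax(nn.Module):
    """Hypothetical stand-in for a max aggregation (not popgym's class)."""

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        # Running maximum over the time axis (dim=1); cummax returns
        # a (values, indices) pair, and only the values are needed here.
        running = torch.cummax(x, dim=1).values
        # Fold in the carried-in state s_k; memory broadcasts over time.
        return torch.maximum(running, memory)


x = torch.tensor([[[1.0], [5.0], [2.0]]])  # x_k ... x_t for one sequence
memory = torch.tensor([[[3.0]]])           # s_k
states = RunningMax()(x, memory)           # one state per time step
```

Because the maximum is associative, splitting a sequence into chunks and carrying the last state forward gives the same result as processing it in one call.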