popgym.baselines.models.linear_attention

Module Contents

Classes

Phi

LinearAttentionBlock

The building block from the "Linear Transformers are Secretly RNNs" paper. This is a form of linear transformer.

class popgym.baselines.models.linear_attention.Phi

Bases: torch.nn.Module

forward(x)
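This page documents Phi only by its forward(x) signature. In linear transformers, φ is a positive feature map applied to queries and keys so that attention can be computed without a softmax. The sketch below assumes the elu(x) + 1 map that is common in linear transformers; this page does not say so, and PhiSketch is a hypothetical name, not part of popgym.

```python
import torch
import torch.nn.functional as F
from torch import nn


class PhiSketch(nn.Module):
    """Hypothetical stand-in for Phi, assuming the elu(x) + 1 feature map."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # elu(x) > -1 for every finite x, so elu(x) + 1 is positive in exact
        # arithmetic; this keeps the attention normalizer away from zero
        return F.elu(x) + 1
```

For non-negative inputs the map reduces to x + 1; for negative inputs it decays smoothly towards 0 instead of clipping, as a ReLU would.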
class popgym.baselines.models.linear_attention.LinearAttentionBlock(input_size: int, hidden_size: int, S_aggregator: str = 'sum', Z_aggregator: str = 'sum', feed_forward=True, residual=True)

Bases: torch.nn.Module

The building block from the "Linear Transformers are Secretly RNNs" paper. This is a form of linear transformer.

Inputs:

input_size: Size of the input feature dimension
hidden_size: Size of the key/query/value space
S_aggregator: Which type of aggregation to use for the numerator (S term)
Z_aggregator: Which type of aggregation to use for the denominator (Z term)
feed_forward: Whether to apply a perceptron to the output
residual: Whether to apply a residual connection from input to output
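The S and Z terms can be read through the standard causal linear-attention recurrence. Assuming that the default 'sum' aggregation is a plain running sum and that φ is the feature map implemented by Phi, step t would update the numerator S_t (a D × D matrix) and the denominator Z_t (a D-vector) and produce the output as follows. This is an interpretation of the parameter names, not a statement of popgym's exact code.

```latex
% Assumes 'sum' aggregation for both S_aggregator and Z_aggregator
S_t = S_{t-1} + \phi(k_t)\, v_t^{\top}, \qquad
Z_t = Z_{t-1} + \phi(k_t), \qquad
y_t = \frac{\phi(q_t)^{\top} S_t}{\phi(q_t)^{\top} Z_t}
```

Since S_t and Z_t summarize the whole prefix, each step costs O(D^2) regardless of sequence length, which is the sense in which a linear transformer behaves like an RNN.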

forward(x: torch.Tensor, state: List[torch.Tensor]) → Tuple[torch.Tensor, List[torch.Tensor]]
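Input:

x: [B, T, F]
state: Tuple[[B, 1, D, D], [B, 1, D]]

Output:

y: [B, T, D]
state: Tuple[[B, 1, D, D], [B, 1, D]]

Here B is the batch size, T the sequence length, F = input_size and D = hidden_size. The two state tensors hold the running S (numerator) and Z (denominator) terms.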
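The forward contract above (state in, state out, with a singleton time axis on the state) can be illustrated with a minimal sketch of causal linear attention using 'sum' aggregation. This is not popgym's implementation: the helper linear_attention_scan and the elu(x) + 1 feature map are assumptions, and the sketch starts from already-projected q, k, v of width D rather than from x, skipping the feed-forward and residual options.

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F


def phi(x: torch.Tensor) -> torch.Tensor:
    # Assumed feature map: elu(x) + 1 keeps features positive
    return F.elu(x) + 1


def linear_attention_scan(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, state: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Hypothetical causal linear attention with 'sum' aggregation.

    q, k, v: [B, T, D]
    state:   [S: [B, 1, D, D], Z: [B, 1, D]]
    returns: y: [B, T, D] and the updated [S, Z] with the same shapes as state
    """
    S0, Z0 = state
    fq, fk = phi(q), phi(k)
    # Per-step outer products phi(k_t) v_t^T: [B, T, D, D]
    kv = torch.einsum("btd,bte->btde", fk, v)
    # Running sums over time, offset by the carried-in state (broadcast over T)
    S = S0 + kv.cumsum(dim=1)  # [B, T, D, D]
    Z = Z0 + fk.cumsum(dim=1)  # [B, T, D]
    num = torch.einsum("btd,btde->bte", fq, S)  # phi(q_t)^T S_t -> [B, T, D]
    den = torch.einsum("btd,btd->bt", fq, Z).unsqueeze(-1)  # phi(q_t)^T Z_t -> [B, T, 1]
    y = num / den
    # Only the last time step is kept as the new recurrent state
    return y, [S[:, -1:], Z[:, -1:]]
```

Because the returned state holds only the final S and Z, feeding a sequence in two chunks while carrying the state yields the same outputs as feeding it in one call. That property is what lets the block serve as a recurrent memory over an episode.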