popgym.baselines.models.linear_attention
Module Contents
Classes
- class popgym.baselines.models.linear_attention.LinearAttentionBlock(input_size: int, hidden_size: int, S_aggregator: str = 'sum', Z_aggregator: str = 'sum', feed_forward=True, residual=True)
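Bases: torch.nn.Module

The building block from the linear transformer paper "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (Katharopoulos et al., 2020). It is a form of linear transformer that can be run recurrently: forward takes and returns a fixed-size state.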
- Inputs:
  - input_size: Size of the input feature dimension.
  - hidden_size: Size of the key/query/value space.
  - S_aggregator: Which type of aggregation to use for the numerator (S term).
  - Z_aggregator: Which type of aggregation to use for the denominator (Z term).
  - feed_forward: Whether to apply a perceptron to the output.
  - residual: Whether to apply a residual connection from input to output.
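For intuition about the S and Z terms, here is a minimal NumPy sketch of causal linear attention in the style of Katharopoulos et al. It is not popgym's implementation: the function names are hypothetical, and it assumes the elu(x) + 1 feature map phi from that paper and the 'sum' aggregator. With sum aggregation, S_t = S_{t-1} + phi(k_t) v_t^T and z_t = z_{t-1} + phi(k_t), and the output is y_t = phi(q_t)^T S_t / phi(q_t)^T z_t. Stepping through time (the RNN view) and taking cumulative sums over the sequence (the parallel view) should therefore give the same result.

```python
import numpy as np


def phi(x: np.ndarray) -> np.ndarray:
    # Assumed feature map: elu(x) + 1, which is always positive.
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def linear_attention_recurrent(q, k, v, eps=1e-5):
    """Hypothetical helper: step through time, updating the S and z sums."""
    T = q.shape[0]
    S = np.zeros((k.shape[1], v.shape[1]))  # numerator state (S term)
    z = np.zeros(k.shape[1])                # denominator state (Z term)
    out = np.empty((T, v.shape[1]))
    for t in range(T):
        S = S + np.outer(phi(k[t]), v[t])   # S_t = S_{t-1} + phi(k_t) v_t^T
        z = z + phi(k[t])                   # z_t = z_{t-1} + phi(k_t)
        qt = phi(q[t])
        out[t] = (qt @ S) / (qt @ z + eps)
    return out


def linear_attention_parallel(q, k, v, eps=1e-5):
    """Hypothetical helper: same result via causal cumulative sums over time."""
    K, Q = phi(k), phi(q)
    S = np.cumsum(np.einsum("ti,tj->tij", K, v), axis=0)  # [T, Dk, Dv]
    z = np.cumsum(K, axis=0)                               # [T, Dk]
    num = np.einsum("ti,tij->tj", Q, S)                    # [T, Dv]
    den = np.einsum("ti,ti->t", Q, z)[:, None] + eps       # [T, 1]
    return num / den


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(linear_attention_recurrent(q, k, v),
                  linear_attention_parallel(q, k, v)))  # True
```

The recurrent form needs only a constant-size state per step, which is what makes the block usable as a memory model; the parallel form is the usual choice for training over whole sequences.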
- forward(x: torch.Tensor, state: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]
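- Input:
  - x: [B, T, F]
  - state: Tuple[[B, 1, D, D], [B, 1, D]], the numerator (S) and denominator (Z) terms
  - Shapes use B = batch size, T = sequence length, F = input_size, D = hidden_size.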
- Output:
  - y: [B, T, D]
  - state: Tuple[[B, 1, D, D], [B, 1, D]], the updated S and Z terms, with the same shapes as the input state
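The shape contract of forward can be illustrated with a NumPy sketch. `block_forward` and the projection matrices `Wq`, `Wk`, `Wv` are hypothetical; the sketch assumes the elu(x) + 1 feature map and 'sum' aggregation, and leaves out the feed_forward and residual paths and any normalization the real block may apply. It checks that processing a sequence in two chunks, passing the returned state into the second call, reproduces the output of a single call over the whole sequence.

```python
import numpy as np


def phi(x):
    # Assumed feature map: elu(x) + 1 (positive, so the denominator stays > 0).
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def block_forward(x, state, Wq, Wk, Wv, eps=1e-5):
    """Hypothetical sketch. x: [B, T, F]; state: [S: [B, 1, D, D], Z: [B, 1, D]]."""
    S0, Z0 = state
    Q, K, V = phi(x @ Wq), phi(x @ Wk), x @ Wv                    # each [B, T, D]
    # 'sum' aggregation: running sum over time, offset by the carried state
    S = S0 + np.cumsum(np.einsum("bti,btj->btij", K, V), axis=1)  # [B, T, D, D]
    Z = Z0 + np.cumsum(K, axis=1)                                 # [B, T, D]
    num = np.einsum("bti,btij->btj", Q, S)                        # [B, T, D]
    den = np.einsum("bti,bti->bt", Q, Z)[..., None] + eps         # [B, T, 1]
    y = num / den                                                 # [B, T, D]
    # Keep only the last timestep as the recurrent state
    return y, [S[:, -1:], Z[:, -1:]]


B, T, F, D = 2, 8, 3, 5
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.standard_normal((F, D)) for _ in range(3))
x = rng.standard_normal((B, T, F))
zero_state = [np.zeros((B, 1, D, D)), np.zeros((B, 1, D))]

y_full, state_full = block_forward(x, zero_state, Wq, Wk, Wv)
y_a, state_a = block_forward(x[:, :5], zero_state, Wq, Wk, Wv)  # first chunk
y_b, state_b = block_forward(x[:, 5:], state_a, Wq, Wk, Wv)     # carry state
print(y_full.shape)  # (2, 8, 5)
print(np.allclose(y_full, np.concatenate([y_a, y_b], axis=1)))  # True
```

Because the returned state keeps only the last timestep ([B, 1, ...]), its size does not grow with sequence length.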