popgym.baselines.models.fwp
Module Contents
Classes
- class popgym.baselines.models.fwp.FWPBlock(input_size, hidden_size, aggregator='sum', sum_normalization=True, feed_forward=True, residual=True)
Bases: torch.nn.Module
The building block in the fast weight transformers paper. This is a form of linear transformer.
Note that this is a “simplified” version of FWP without the delta update rule and DPFP. We use ReLU instead of DPFP.
- Inputs:
  input_size: Size of the input feature dimension
  hidden_size: Size of the key/query/value space
  aggregator: Which type of aggregation to use
  sum_normalization: Whether to use the sum normalization described in the paper
  feed_forward: Whether to apply a perceptron to the output
  residual: Whether to apply a residual connection from input to output
- forward(x: torch.Tensor, state: torch.Tensor) -> torch.Tensor
- Inputs:
  x: [B, T, F]
  state: [B, 1, F]
- Outputs:
  y: [B, T, D]
  state: [B, T, F]
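The simplified update described above (ReLU feature map, purely additive fast weights, optional sum normalization) can be sketched as follows. This is an illustrative reconstruction, not the FWPBlock source: the function name `simplified_fwp`, the projection matrices `w_k`, `w_q`, `w_v`, the `eps` constant, and the `[B, D, D]` state layout are all assumptions made for the example, and the sketch covers only the `sum` aggregator without the feed-forward or residual paths.

```python
import torch


def simplified_fwp(x, w_k, w_q, w_v, state=None, sum_normalization=True, eps=1e-6):
    """Sum-aggregated fast-weight (linear attention) pass over a sequence.

    x:             [B, T, F] inputs
    w_k, w_q, w_v: [F, D] projection matrices (hypothetical names)
    state:         [B, D, D] fast-weight matrix from earlier steps, or None
                   (assumed layout; the real module may store it differently)
    Returns y with shape [B, T, D] and the final fast-weight matrix [B, D, D].
    """
    # ReLU stands in for the DPFP feature map on keys and queries.
    k = torch.relu(x @ w_k)
    q = torch.relu(x @ w_q)
    v = x @ w_v
    if sum_normalization:
        # Divide keys and queries by the sum of their components.
        k = k / (k.sum(dim=-1, keepdim=True) + eps)
        q = q / (q.sum(dim=-1, keepdim=True) + eps)
    # One outer product v_t k_t^T per timestep: [B, T, D, D].
    updates = torch.einsum("bti,btj->btij", v, k)
    # Additive update, no delta rule: W_t = W_0 + sum over s <= t of v_s k_s^T.
    fast_weights = updates.cumsum(dim=1)
    if state is not None:
        fast_weights = fast_weights + state.unsqueeze(1)
    # Read out with the query: y_t = W_t q_t.
    y = torch.einsum("btij,btj->bti", fast_weights, q)
    return y, fast_weights[:, -1]
```

Because the update is a running sum, processing a sequence in chunks while passing the returned state forward should give the same outputs as processing it in one call, which is what makes a recurrent state argument workable for this kind of block.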