popgym.baselines.models.fwp

Module Contents

Classes

FWPBlock

The building block in the fast weight transformers paper. This is a form of linear transformer.

class popgym.baselines.models.fwp.FWPBlock(input_size, hidden_size, aggregator='sum', sum_normalization=True, feed_forward=True, residual=True)

Bases: torch.nn.Module

The building block in the fast weight transformers paper. This is a form of linear transformer.

Note that this is a “simplified” version of FWP without the delta update rule and DPFP. We use ReLU instead of DPFP.
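The simplified mechanism described above can be sketched in a few lines of PyTorch. This is an illustrative sketch only, not the popgym implementation: the function name fast_weight_sketch, the projection matrices w_k, w_q, w_v, the eps constant, and the [B, D, D] state layout are assumptions made for the example. It assumes the 'sum' aggregator means a running sum of value/key outer products, and that sum normalization divides the ReLU-mapped keys and queries by the sum of their components.

```python
import torch


def fast_weight_sketch(x, w_k, w_q, w_v, state=None,
                       sum_normalization=True, eps=1e-6):
    """Simplified fast-weight update (no delta rule, ReLU instead of DPFP).

    x:      [B, T, F] input sequence
    w_*:    [F, D] hypothetical key/query/value projections
    state:  [B, D, D] fast-weight matrix carried over from earlier steps
            (layout chosen for this sketch, not taken from popgym)
    Returns y of shape [B, T, D] and the final state of shape [B, D, D].
    """
    # ReLU plays the role of the feature map applied to keys and queries.
    k = torch.relu(x @ w_k)
    q = torch.relu(x @ w_q)
    v = x @ w_v

    if sum_normalization:
        # Divide keys and queries by the sum of their components.
        k = k / (k.sum(dim=-1, keepdim=True) + eps)
        q = q / (q.sum(dim=-1, keepdim=True) + eps)

    # One outer product v_t k_t^T per time step: [B, T, D, D].
    # Materializing this tensor is simple but memory-hungry for long T.
    updates = torch.einsum("bti,btj->btij", v, k)

    # Sum aggregation: the fast weights at step t are the running sum
    # of all updates up to and including t.
    fast_weights = updates.cumsum(dim=1)
    if state is not None:
        fast_weights = fast_weights + state.unsqueeze(1)

    # Read out by applying the fast weights to the query at each step.
    y = torch.einsum("btij,btj->bti", fast_weights, q)
    return y, fast_weights[:, -1]


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 5, 3)                              # [B, T, F]
    w_k, w_q, w_v = (torch.randn(3, 4) for _ in range(3))  # [F, D]
    y, final_state = fast_weight_sketch(x, w_k, w_q, w_v)
    print(y.shape, final_state.shape)
```

Because the state is a plain running sum, processing a sequence in chunks and passing the final state of one chunk into the next should give the same outputs as processing the whole sequence at once, which is what makes this kind of block usable as a recurrent memory.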

Inputs:

input_size: Size of input feature dim
hidden_size: Size of key/query/value space
aggregator: Which type of aggregation to use
sum_normalization: Whether to use the sum normalization described in the paper
feed_forward: Whether to apply a perceptron to the output
residual: Whether to apply a residual connection from input to output

forward(x: torch.Tensor, state: torch.Tensor) -> torch.Tensor

Inputs:

x: [B, T, F]
state: [B, 1, F]

Outputs:

y: [B, T, D]
state: [B, T, F]