popgym.baselines.models.s4d

Standalone version of Structured (Sequence) State Space (S4) model.

Module Contents

Classes

DropoutNd

OptimModule

Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters

SSKernelNPLR

Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)

SSKernelDiag

Version using (complex) diagonal state matrix (S4D)

SSKernel

Wrapper around SSKernel parameterizations.

S4

Functions

get_logger(→ logging.Logger)

Initializes multi-GPU-friendly python logger.

Activation([activation, dim])

LinearActivation(d_input, d_output[, bias, ...])

Returns a linear nn.Module with control over axes order, initialization, and activation

power(L, A[, v])

Compute A^L and the scan sum_i A^i v_i

transition(measure, N)

A, B transition matrices for different measures

rank_correction(measure, N[, rank, dtype])

Return low-rank matrix L such that A + L is normal

nplr(measure, N[, rank, dtype, diagonalize_precision])

Return w, p, q, V, B such that (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V

dplr(scaling, N[, rank, H, dtype, real_scale, ...])

ssm(measure, N, R, H, **ssm_args)

Dispatcher to create single SSM initialization

combination(measures, N, R, S, **ssm_args)

Attributes

contract

contract_expression

log

Cauchy and Vandermonde kernels

has_cauchy_extension

has_pykeops

combinations

popgym.baselines.models.s4d.contract
popgym.baselines.models.s4d.contract_expression
popgym.baselines.models.s4d.get_logger(name=__name__, level=logging.INFO) → logging.Logger

Initializes multi-GPU-friendly python logger.

popgym.baselines.models.s4d.log

Cauchy and Vandermonde kernels

popgym.baselines.models.s4d.has_cauchy_extension = True
popgym.baselines.models.s4d.has_pykeops = True
popgym.baselines.models.s4d.Activation(activation=None, dim=-1)
popgym.baselines.models.s4d.LinearActivation(d_input, d_output, bias=True, transposed=False, activation=None, activate=False, **kwargs)

Returns a linear nn.Module with control over axes order, initialization, and activation

class popgym.baselines.models.s4d.DropoutNd(p: float = 0.5, tie=True, transposed=True)

Bases: torch.nn.Module

forward(X)

X: (batch, dim, lengths…)
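The signature exposes a `tie` flag that this page does not describe. The following is a minimal sketch, assuming that `tie=True` draws one keep/drop decision per (batch, dim) pair and shares it across all length positions. `tied_dropout` is a hypothetical helper operating on nested lists, not a function of this module.

```python
import random


def tied_dropout(X, p=0.5, rng=None):
    """Hypothetical helper. X is a nested list of shape (batch, dim, length)."""
    rng = rng or random.Random()
    scale = 1.0 / (1.0 - p)  # inverted-dropout rescaling keeps the expected value
    out = []
    for sample in X:
        rows = []
        for feature in sample:
            # One decision per (batch, dim) pair, reused for every length position.
            keep = rng.random() < 1.0 - p
            rows.append([x * scale if keep else 0.0 for x in feature])
        out.append(rows)
    return out


X = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]  # (batch=1, dim=2, length=3)
Y = tied_dropout(X, p=0.5, rng=random.Random(0))
```

Under this assumption each feature row of `Y` is either dropped entirely or kept entirely, never partially masked.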

popgym.baselines.models.s4d.power(L, A, v=None)

Compute A^L and the scan sum_i A^i v_i

A: (…, N, N)
v: (…, N, L)
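To make the two returned quantities concrete, here is a plain-Python sketch on nested lists. It assumes that v_i denotes column i of v and that the scan pairs it with A^i. The helper names (`matrix_power`, `scan_sum`) are hypothetical, and the loop is a naive reference, not the routine used by this module.

```python
def matmul(A, B):
    """Product of two matrices given as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def identity(n):
    return [[1 if i == j else 0 for j in range(n)] for i in range(n)]


def matrix_power(A, L):
    """A^L by repeated squaring (O(log L) matrix products)."""
    result, base = identity(len(A)), A
    while L > 0:
        if L % 2 == 1:
            result = matmul(result, base)
        base = matmul(base, base)
        L //= 2
    return result


def scan_sum(A, v):
    """sum_i A^i v_i, where v_i is column i of v (shape N x L)."""
    n, L = len(v), len(v[0])
    total = [0] * n
    power_i = identity(n)  # holds A^i, starting at A^0
    for i in range(L):
        col = [v[r][i] for r in range(n)]
        term = [sum(power_i[r][k] * col[k] for k in range(n)) for r in range(n)]
        total = [t + x for t, x in zip(total, term)]
        power_i = matmul(A, power_i)
    return total


A = [[1, 1], [0, 1]]
v = [[1, 0, 0], [0, 1, 1]]  # N = 2, L = 3
A_L = matrix_power(A, 3)    # [[1, 3], [0, 1]]
s = scan_sum(A, v)          # [4, 2]
```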

popgym.baselines.models.s4d.transition(measure, N)

A, B transition matrices for different measures

popgym.baselines.models.s4d.rank_correction(measure, N, rank=1, dtype=torch.float)

Return low-rank matrix L such that A + L is normal
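As an illustration of what "A + L is normal" means, the sketch below builds the HiPPO-LegS matrices in plain Python and checks the property numerically. It assumes that `measure='legs'` refers to HiPPO-LegS with A[n][k] = -sqrt((2n+1)(2k+1)) for n > k and A[n][n] = -(n+1), and that the rank-1 correction is P P^T with P[n] = sqrt(n + 1/2). All helper names are hypothetical and none of this calls the module itself.

```python
import math


def legs_transition(N):
    """Assumed HiPPO-LegS (A, B); A is lower triangular."""
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n][k] = -math.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n][k] = -(n + 1.0)
    B = [math.sqrt(2 * n + 1) for n in range(N)]
    return A, B


def legs_rank_correction(N):
    """Assumed rank-1 factor P, so that the correction is P P^T."""
    return [math.sqrt(n + 0.5) for n in range(N)]


def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]


def transpose(M):
    return [list(row) for row in zip(*M)]


def is_normal(M, tol=1e-9):
    """A real matrix is normal when M M^T equals M^T M."""
    Mt = transpose(M)
    left, right = matmul(M, Mt), matmul(Mt, M)
    return all(abs(l - r) <= tol
               for lrow, rrow in zip(left, right) for l, r in zip(lrow, rrow))


N = 4
A, B = legs_transition(N)
P = legs_rank_correction(N)
A_plus_L = [[A[n][k] + P[n] * P[k] for k in range(N)] for n in range(N)]
```

With these definitions `A_plus_L` equals -1/2 times the identity plus a skew-symmetric matrix, which is why it is normal even though `A` alone is not.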

popgym.baselines.models.s4d.nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True)

Return w, p, q, V, B such that (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V, i.e. A = V [w - p q^*] V^*, B = V B

popgym.baselines.models.s4d.dplr(scaling, N, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False)
popgym.baselines.models.s4d.ssm(measure, N, R, H, **ssm_args)

Dispatcher to create single SSM initialization

N: state size
R: rank (for DPLR parameterization)
H: number of independent SSM copies

popgym.baselines.models.s4d.combinations
popgym.baselines.models.s4d.combination(measures, N, R, S, **ssm_args)
class popgym.baselines.models.s4d.OptimModule

Bases: torch.nn.Module

Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters

register(name, tensor, lr=None)

Register a tensor with a configurable learning rate and 0 weight decay

class popgym.baselines.models.s4d.SSKernelNPLR(w, P, B, C, log_dt, L=None, lr=None, verbose=False, keops=False, real_type='exp', real_tolerance=0.001, bandlimit=None)

Bases: OptimModule

Stores a representation of and computes the SSKernel function K_L(A^dt, B^dt, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR)

forward(state=None, rate=1.0, L=None)

state: (B, H, N) initial state
rate: sampling rate factor
L: target length

returns:
(C, H, L) convolution kernel (generally C=1)
(B, H, L) output from initial state

default_state(*batch_shape)
step(u, state)

Must have called self._setup_step() and created state with self.default_state() before calling this

class popgym.baselines.models.s4d.SSKernelDiag(A, B, C, log_dt, L=None, disc='bilinear', real_type='exp', lr=None, bandlimit=None)

Bases: OptimModule

Version using (complex) diagonal state matrix (S4D)

forward(L, state=None, rate=1.0, u=None)

state: (B, H, N) initial state
rate: sampling rate factor
L: target length

returns:
(C, H, L) convolution kernel (generally C=1)
(B, H, L) output from initial state

default_state(*batch_shape)
step(u, state)
forward_state(u, state)
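The relationship between the convolution kernel returned by forward() and the recurrence used by step() can be sketched for a single diagonal SSM in plain Python. This is an illustration under stated assumptions, not this class's code: it uses the standard bilinear transform (matching the default `disc='bilinear'`), omits the head and channel dimensions, and skips any conjugate-pair or real-part handling. All function names are hypothetical.

```python
def discretize_bilinear(a, b, dt):
    """Bilinear (Tustin) discretization of one diagonal mode."""
    denom = 1 - dt * a / 2
    return (1 + dt * a / 2) / denom, dt * b / denom


def diag_kernel(A, B, C, dt, L):
    """K[l] = sum_n C_n * Abar_n**l * Bbar_n for l = 0..L-1."""
    disc = [discretize_bilinear(a, b, dt) for a, b in zip(A, B)]
    return [sum(c * ab ** l * bb for c, (ab, bb) in zip(C, disc))
            for l in range(L)]


def causal_conv(K, u):
    """y_k = sum_{j <= k} K[k - j] * u_j."""
    return [sum(K[k - j] * u[j] for j in range(k + 1)) for k in range(len(u))]


def recurrent(A, B, C, dt, u):
    """x_k = Abar * x_{k-1} + Bbar * u_k, y_k = C . x_k, zero initial state."""
    disc = [discretize_bilinear(a, b, dt) for a, b in zip(A, B)]
    x = [0j] * len(A)
    ys = []
    for u_k in u:
        x = [ab * x_n + bb * u_k for (ab, bb), x_n in zip(disc, x)]
        ys.append(sum(c * x_n for c, x_n in zip(C, x)))
    return ys


A = [-0.5 + 1j, -1.0 - 2j]      # diagonal of the state matrix (N = 2)
B = [1.0, 1.0]
C = [0.3 + 0.1j, -0.2 + 0.4j]
u = [1.0, 0.5, -0.25, 2.0]

K = diag_kernel(A, B, C, dt=0.1, L=len(u))
y_conv = causal_conv(K, u)
y_rec = recurrent(A, B, C, dt=0.1, u=u)
max_diff = max(abs(p - q) for p, q in zip(y_conv, y_rec))
```

Both paths produce the same output up to floating-point error, which is the reason a kernel can be used for parallel training while a step function serves one-step-at-a-time inference.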
class popgym.baselines.models.s4d.SSKernel(H, N=64, L=None, measure='legs', rank=1, channels=1, dt_min=0.001, dt_max=0.1, deterministic=False, lr=None, mode='nplr', n_ssm=None, verbose=False, measure_args={}, **kernel_args)

Bases: torch.nn.Module

Wrapper around SSKernel parameterizations.

The SSKernel is expected to support the interface: forward(), default_state(), _setup_step(), step().

forward(state=None, L=None, rate=None)
forward_state(u, state)

Forward the state through a sequence, i.e. compute the state after passing a chunk through the SSM

state: (B, H, N)
u: (B, H, L)

Returns: (B, H, N)
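What "the state after passing a chunk" amounts to can be sketched for a single scalar mode. The example assumes an already discretized recurrence x_{k+1} = a x_k + b u_k; the names `a`, `b`, `forward_state_loop` and `forward_state_closed` are hypothetical and the real method operates on batched tensors.

```python
def forward_state_loop(a, b, u, x0):
    """Apply the recurrence once per input in the chunk."""
    x = x0
    for u_k in u:
        x = a * x + b * u_k
    return x


def forward_state_closed(a, b, u, x0):
    """Equivalent closed form: a^L x0 + sum_i a^(L-1-i) b u_i."""
    L = len(u)
    return a ** L * x0 + sum(a ** (L - 1 - i) * b * u_i
                             for i, u_i in enumerate(u))


a, b = 0.5, 2.0
u = [1.0, 2.0, 3.0]
x_loop = forward_state_loop(a, b, u, x0=4.0)      # 9.0
x_closed = forward_state_closed(a, b, u, x0=4.0)  # 9.0
```

The closed form shows why a matrix power together with a scan over the inputs is enough to advance the state by a whole chunk at once.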

step(u, state, **kwargs)
default_state(*args, **kwargs)
class popgym.baselines.models.s4d.S4(d_model, d_state=64, l_max=None, channels=1, bidirectional=False, activation='gelu', postact='glu', hyper_act=None, dropout=0.0, tie_dropout=False, bottleneck=None, gate=None, transposed=True, verbose=False, **kernel_args)

Bases: torch.nn.Module

property d_output
forward(u, state=None, rate=1.0, lengths=None, **kwargs)

u: (B H L) if self.transposed else (B L H)
state: (H N), never needed unless you know what you’re doing

Returns: same shape as u

setup_step(**kwargs)
step(u, state)

Step one time step as a recurrent model. Intended to be used during validation.

u: (B H)
state: (B H N)
Returns: output (B H), state (B H N)

default_state(*batch_shape, device=None)