popgym.baselines.models.embeddings

Module Contents

Classes

PhaserEncoding

ScaledPositionalEncoding

SimpleScaledEncoding

Time2Vec

Base2Embedding

PerceptronEmbedding

SoftExponentialEmbedding

Functions

positional_encoding(max_len, d_model)

popgym.baselines.models.embeddings.positional_encoding(max_len, d_model)

class popgym.baselines.models.embeddings.PhaserEncoding(max_len: int, d_model: int, num_dims: int, max_period: int = 1024, dtype: torch.dtype = torch.double, decay_bias: float = -0.5, freq_scale: float = 1.0, fudge_factor: float = 0.025, soft_clamp: bool = False)

Bases: torch.nn.Module

psi(t_minus_i)
recurrent_update2(x, state)
recurrent_update(x, state)
forward(x, state)

class popgym.baselines.models.embeddings.ScaledPositionalEncoding(num_embeddings, embedding_dim, **kwargs)

Bases: torch.nn.Embedding

forward(idx)

class popgym.baselines.models.embeddings.SimpleScaledEncoding(max_len, d_model)

Bases: torch.nn.Module

forward(idx)

class popgym.baselines.models.embeddings.Time2Vec(max_len, d_model)

Bases: torch.nn.Module

forward(idx)

class popgym.baselines.models.embeddings.Base2Embedding(max_len, d_model)

Bases: torch.nn.Module

forward(idx)

class popgym.baselines.models.embeddings.PerceptronEmbedding(max_len, d_model)

Bases: torch.nn.Module

forward(idx)

class popgym.baselines.models.embeddings.SoftExponentialEmbedding(max_len, d_model)

Bases: torch.nn.Module

forward(idx)