popgym.baselines.models.embeddings
Module Contents
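Classes
- PhaserEncoding
- ScaledPositionalEncoding
- SimpleScaledEncoding
- Time2Vec
- Base2Embedding
Functions
- positional_encoding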
- popgym.baselines.models.embeddings.positional_encoding(max_len, d_model)
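This page does not document what positional_encoding returns. The sketch below assumes it builds the standard sinusoidal table from Vaswani et al. ("Attention Is All You Need"), with shape [max_len, d_model]. The name sinusoidal_table is hypothetical. The even-d_model check is a simplification made by this sketch, not a documented constraint of the library.

```python
import math
import torch


def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    """Return a [max_len, d_model] sinusoidal position table (hypothetical helper)."""
    assert d_model % 2 == 0, "this sketch pairs sin/cos columns, so d_model must be even"
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # [max_len, 1]
    # Inverse frequencies 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1
    inv_freq = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )  # [d_model / 2]
    table = torch.zeros(max_len, d_model)
    table[:, 0::2] = torch.sin(position * inv_freq)  # even columns
    table[:, 1::2] = torch.cos(position * inv_freq)  # odd columns
    return table


pe = sinusoidal_table(max_len=128, d_model=16)
print(pe.shape)  # torch.Size([128, 16])
```

Indexing the table with a tensor of step indices, as in pe[idx], returns one encoding row per step.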
- class popgym.baselines.models.embeddings.PhaserEncoding(max_len: int, d_model: int, num_dims: int, max_period: int = 1024, dtype: torch.dtype = torch.double, decay_bias: float = -0.5, freq_scale: float = 1.0, fudge_factor: float = 0.025, soft_clamp: bool = False)
  Bases: torch.nn.Module
  - psi(t_minus_i)
  - recurrent_update2(x, state)
  - recurrent_update(x, state)
  - forward(x, state)
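The method names suggest that PhaserEncoding can be evaluated in two ways: as a weighted sum over past steps (psi(t_minus_i)) or step by step (recurrent_update). The page itself does not confirm this. The sketch below illustrates that kind of equivalence for a generic decaying complex-exponential kernel psi(k) = gamma^k, where gamma = exp(decay + i*omega). Several parts are assumptions chosen only to loosely echo the decay_bias, max_period and dtype parameters, and none of them is the library's actual algorithm: the kernel itself, the helper names (make_gamma, psi, recurrent_scan, direct_sum), and the frequencies omega = 2*pi/period.

```python
import math
import torch


def make_gamma(decay: float, omega: torch.Tensor) -> torch.Tensor:
    # Complex per-dimension multiplier gamma = exp(decay + i * omega); |gamma| = exp(decay).
    return torch.exp(torch.complex(torch.full_like(omega, decay), omega))


def psi(k: int, gamma: torch.Tensor) -> torch.Tensor:
    # Hypothetical kernel: weight applied to an input that lies k steps in the past.
    return gamma ** k


def recurrent_scan(x: torch.Tensor, decay: float, omega: torch.Tensor) -> torch.Tensor:
    # h_t = gamma * h_{t-1} + x_t, computed one step at a time with constant memory per step.
    gamma = make_gamma(decay, omega)
    state = torch.zeros(x.shape[1], dtype=gamma.dtype)
    outputs = []
    for t in range(x.shape[0]):
        state = gamma * state + x[t]
        outputs.append(state)
    return torch.stack(outputs)


def direct_sum(x: torch.Tensor, decay: float, omega: torch.Tensor) -> torch.Tensor:
    # y_t = sum_{i <= t} psi(t - i) * x_i, evaluated naively in O(T^2) as a reference.
    gamma = make_gamma(decay, omega)
    rows = [sum(psi(t - i, gamma) * x[i] for i in range(t + 1)) for t in range(x.shape[0])]
    return torch.stack(rows)


torch.manual_seed(0)
x = torch.randn(6, 4, dtype=torch.double)  # [T, D] real inputs
periods = torch.tensor([4.0, 16.0, 64.0, 1024.0], dtype=torch.double)  # hypothetical periods
omega = 2 * math.pi / periods
h_rec = recurrent_scan(x, -0.5, omega)
h_dir = direct_sum(x, -0.5, omega)
print((h_rec - h_dir).abs().max().item() < 1e-9)  # True
```

The recurrent form needs only the previous state at each step, which suits online rollouts. The direct sum serves here only as a correctness check.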
- class popgym.baselines.models.embeddings.ScaledPositionalEncoding(num_embeddings, embedding_dim, **kwargs)
  Bases: torch.nn.Embedding
  - forward(idx)
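ScaledPositionalEncoding subclasses torch.nn.Embedding and accepts **kwargs, so it presumably forwards extra keyword arguments (such as padding_idx) to the embedding constructor. The sketch below assumes that the "scaling" is a single learnable gain applied to the looked-up vectors. The class name LearnedScaledEmbedding and the gain itself are both hypothetical.

```python
import torch
from torch import nn


class LearnedScaledEmbedding(nn.Embedding):
    """Hypothetical sketch: an nn.Embedding whose output is multiplied by a learnable gain."""

    def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
        # Extra keyword arguments (e.g. padding_idx) go straight to nn.Embedding.
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.scale = nn.Parameter(torch.ones(1))  # hypothetical learnable gain

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: integer positions of any shape -> [*idx.shape, embedding_dim]
        return super().forward(idx) * self.scale


emb = LearnedScaledEmbedding(1024, 32)
positions = torch.arange(10).unsqueeze(0)  # one sequence of 10 steps
print(emb(positions).shape)  # torch.Size([1, 10, 32])
```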
- class popgym.baselines.models.embeddings.SimpleScaledEncoding(max_len, d_model)
  Bases: torch.nn.Module
  - forward(idx)
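The page gives no formula for SimpleScaledEncoding. Purely for illustration, the sketch below assumes one simple, parameter-free reading of the name: divide each index by max_len and repeat the result across d_model features. NormalizedPositionEncoding is a hypothetical name.

```python
import torch
from torch import nn


class NormalizedPositionEncoding(nn.Module):
    """Hypothetical sketch: encode step idx as idx / max_len, repeated over d_model features."""

    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # Map indices 0 .. max_len - 1 into [0, 1), then broadcast to [*idx.shape, d_model].
        scaled = idx.to(torch.float32) / self.max_len
        return scaled.unsqueeze(-1).expand(*idx.shape, self.d_model)


enc = NormalizedPositionEncoding(max_len=100, d_model=8)
out = enc(torch.tensor([0, 50, 99]))
print(out.shape)  # torch.Size([3, 8])
```

Because expand returns a broadcast view, the only new memory this sketch allocates is the one float per index held in scaled.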
- class popgym.baselines.models.embeddings.Time2Vec(max_len, d_model)
  Bases: torch.nn.Module
  - forward(idx)
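Time2Vec (Kazemi et al., 2019) maps a scalar time tau to d_model features. The first is a linear feature omega_0 * tau + phi_0, and the remaining d_model - 1 are periodic features sin(omega_i * tau + phi_i), where omega and phi are learnable. The sketch below is a minimal PyTorch version of that published formulation. This page does not document how the library uses max_len, so dividing indices by max_len is an assumption, as is the name Time2VecSketch.

```python
import torch
from torch import nn


class Time2VecSketch(nn.Module):
    """Time2Vec: one linear feature plus d_model - 1 sine features of a scalar time."""

    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        self.max_len = max_len                     # assumption: only used to normalize indices
        self.linear = nn.Linear(1, 1)              # omega_0 * tau + phi_0
        self.periodic = nn.Linear(1, d_model - 1)  # omega_i * tau + phi_i for i >= 1

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx of any shape -> [*idx.shape, d_model]
        tau = (idx.to(torch.float32) / self.max_len).unsqueeze(-1)
        return torch.cat([self.linear(tau), torch.sin(self.periodic(tau))], dim=-1)


t2v = Time2VecSketch(max_len=1024, d_model=16)
feats = t2v(torch.arange(5))
print(feats.shape)  # torch.Size([5, 16])
```

The linear feature lets the model represent non-periodic trends in time. The sine features capture periodic patterns with learned frequencies.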
- class popgym.baselines.models.embeddings.Base2Embedding(max_len, d_model)
  Bases: torch.nn.Module
  - forward(idx)
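The name Base2Embedding hints at a base-2 (binary) representation of the index, but the page does not say how the embedding is built. The sketch below assumes one plausible design. It writes idx in binary using enough bits to cover max_len positions, maps each bit to {-1, +1}, and projects the result linearly to d_model. BinaryPositionEmbedding and its bits helper are hypothetical names.

```python
import math
import torch
from torch import nn


class BinaryPositionEmbedding(nn.Module):
    """Hypothetical sketch: binary digits of idx, mapped to {-1, +1}, projected to d_model."""

    def __init__(self, max_len: int, d_model: int):
        super().__init__()
        # Enough bits to represent every index 0 .. max_len - 1 (at least one bit).
        self.num_bits = max(1, math.ceil(math.log2(max_len)))
        self.register_buffer("powers", 2 ** torch.arange(self.num_bits))
        self.proj = nn.Linear(self.num_bits, d_model)

    def bits(self, idx: torch.Tensor) -> torch.Tensor:
        # Bit k of each index is (idx // 2**k) % 2, least-significant bit first.
        return (idx.unsqueeze(-1) // self.powers) % 2

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        signed = self.bits(idx).to(torch.float32) * 2 - 1  # {0, 1} -> {-1, +1}
        return self.proj(signed)


emb = BinaryPositionEmbedding(max_len=1024, d_model=16)
print(emb.bits(torch.tensor([5])).tolist())  # [[1, 0, 1, 0, 0, 0, 0, 0, 0, 0]]
print(emb(torch.arange(6).view(2, 3)).shape)  # torch.Size([2, 3, 16])
```

With this approach the number of input features grows with log2(max_len) rather than with max_len.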