# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch
import torch.nn as nn

from modelzoo.common.pytorch.layers.FeedForwardNetwork import FeedForwardNetwork
from modelzoo.vision.pytorch.dit.layers.GaussianDiffusion import index


class TimestepEmbeddingLayer(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A timestep `t` is mapped to a row of a precomputed sinusoidal
    frequency table and then projected to `hidden_size` by a
    two-layer MLP.
    """
    def __init__(
        self,
        num_diffusion_steps,
        hidden_size,
        frequency_embedding_size=256,
        nonlinearity="silu",
        kernel_initializer: str = "xavier_uniform",
        bias_initializer: str = "zeros",
    ):
        """
        Args:
            num_diffusion_steps (int): Total number of diffusion steps,
                i.e. the number of rows in the sinusoidal lookup table.
            hidden_size (int): Output dimension of the embedding MLP.
            frequency_embedding_size (int): Dimension of the intermediate
                sinusoidal frequency embedding.
            nonlinearity (str): Activation of the first MLP layer.
            kernel_initializer (str): Initializer for the MLP weights.
            bias_initializer (str): Initializer for the MLP biases.
        """
        super().__init__()
        # Precomputed (num_diffusion_steps, frequency_embedding_size)
        # sinusoidal lookup table, indexed by timestep in forward().
        self.timestep_embedding = self.create_timestep_embedding(
            seq_len=num_diffusion_steps, dim=frequency_embedding_size
        )
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        # Two-layer MLP: frequency_embedding_size -> hidden_size (with
        # `nonlinearity`) -> hidden_size (no final activation).
        self.ffn = FeedForwardNetwork(
            input_unit=frequency_embedding_size,
            layers_units=[hidden_size, hidden_size],
            layers_activation=[nonlinearity, None],
            use_bias=True,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
        )
        # Initialize weights and biases
        self.__reset_parameters()

    def reset_parameters(self):
        self.__reset_parameters()

    def __reset_parameters(self):
        self.ffn.reset_parameters()

    @staticmethod
    def create_timestep_embedding(seq_len, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Slightly different from `EmbeddingLayer.create_fix_pos_embedding`.
        Returns a `(seq_len, dim)` table wrapped in a non-trainable
        `nn.Parameter`; row `t` is the embedding of timestep `t`.
        """
        # Reference: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        position = torch.arange(seq_len, dtype=torch.float32)
        half = dim // 2
        # Geometric progression of frequencies from 1 down to ~1/max_period.
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        )
        # Outer product: (seq_len, 1) * (1, half) -> (seq_len, half).
        args = position[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Zero-pad the last column when `dim` is odd.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        # Frozen so it follows the module's device/dtype but receives
        # no gradient updates.
        return torch.nn.Parameter(embedding, requires_grad=False)

    def forward(self, t):
        """
        Args:
            t: 1-D tensor of integer timestep indices, shape `(batch,)`.

        Returns:
            Timestep embeddings of shape `(batch, hidden_size)`.
        """
        # Look up the sinusoidal row for each timestep, then project
        # through the MLP.
        t_freq = index(self.timestep_embedding, t)
        t_emb = self.ffn(t_freq)
        return t_emb
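

if __name__ == "__main__":
    # Minimal usage sketch, assuming the modelzoo package is importable.
    # The sizes below are illustrative, not taken from this file.
    layer = TimestepEmbeddingLayer(
        num_diffusion_steps=1000,
        hidden_size=1152,  # hypothetical model width
    )
    t = torch.randint(0, 1000, (8,))  # batch of 8 random timesteps
    print(layer(t).shape)  # expected: torch.Size([8, 1152])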