# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
# Module-level logger named after this module; used by `randn_tensor` to
# report when sampling falls back to the CPU.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[
        Union[List["torch.Generator"], "torch.Generator"]
    ] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """
    Create a tensor of standard-normal samples on `device` with `dtype`.

    When a list of generators is passed, entry ``i`` of the batch (dimension
    0) is drawn from ``generator[i]``, so each batch entry can be seeded
    individually. If the generator lives on the CPU while another device is
    requested, the samples are drawn on the CPU and then moved to `device`.

    Args:
        shape: Shape of the returned tensor (tuple or list); ``shape[0]`` is
            the batch size.
        generator: A single generator, or a list with one generator per
            batch entry. ``None`` uses the global random state.
        device: Target device (a ``torch.device`` or a device string).
            ``None`` means the result ends up on the CPU.
        dtype: Desired dtype, forwarded to ``torch.randn``.
        layout: Desired layout; defaults to ``torch.strided``.

    Returns:
        A tensor of shape `shape` located on `device`.

    Raises:
        ValueError: If a CUDA generator is passed while a tensor on a
            different device type is requested.
    """
    # Device on which sampling happens; defaults to the requested device.
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    # Normalise to a torch.device so that `.type` is available even when the
    # caller passed a device string.
    device = (
        torch.device(device) if device is not None else torch.device("cpu")
    )

    if generator is not None:
        first_generator = (
            generator[0] if isinstance(generator, list) else generator
        )
        gen_device_type = first_generator.device.type
        if gen_device_type != device.type:
            if gen_device_type == "cpu":
                rand_device = "cpu"
                # Compare device *types*: a torch.device never equals the
                # plain string "mps", so the message would always be logged.
                if device.type != "mps":
                    logger.info(
                        f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                        f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                        f" slighly speed up this function by passing a generator that was created on the {device} device."
                    )
            elif gen_device_type == "cuda":
                raise ValueError(
                    f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
                )

    if isinstance(generator, list):
        # One sample per generator, concatenated along the batch dimension.
        # tuple(...) keeps this working when `shape` is given as a list.
        single_shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(
                single_shape,
                generator=generator[i],
                device=rand_device,
                dtype=dtype,
                layout=layout,
            )
            for i in range(batch_size)
        ]
        return torch.cat(latents, dim=0).to(device)

    return torch.randn(
        shape,
        generator=generator,
        device=rand_device,
        dtype=dtype,
        layout=layout,
    ).to(device)
def _emulate_chunk2_dim1(x):
c = x.shape[1] // 2
return x[:, 0:c, ...], x[:, c:, ...]
class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance, parameterised by one tensor.

    `parameters` is split in half along dimension 1: the first half is the
    mean and the second half is the log-variance, which is clamped to
    ``[-30, 20]`` before the standard deviation and variance are derived.
    """

    def __init__(self, parameters, deterministic=False):
        """
        Args:
            parameters: Tensor holding mean and log-variance stacked along
                dimension 1.
            deterministic: If true, standard deviation and variance are set
                to zero, so `sample` returns the mean and `kl` / `nll`
                return a single zero.
        """
        self.parameters = parameters
        self.mean, self.logvar = _emulate_chunk2_dim1(parameters)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean,
                device=self.parameters.device,
                dtype=self.parameters.dtype,
            )

    def _zero(self):
        """Return a one-element zero on the parameters' device and dtype."""
        # torch.Tensor([0.0]) would always be a default-dtype CPU tensor and
        # so could not be combined with parameters living on another device.
        return torch.zeros(
            1, device=self.parameters.device, dtype=self.parameters.dtype
        )

    def sample(
        self, noise=None, generator: Optional[torch.Generator] = None
    ) -> torch.Tensor:
        """Draw ``mean + std * noise``.

        Args:
            noise: Optional pre-drawn noise; when ``None`` it is drawn with
                `randn_tensor` in the shape of the mean.
            generator: Optional generator used when noise is drawn here.
        """
        # Make sure the noise is on the same device as the parameters and
        # has the same dtype.
        if noise is None:
            noise = randn_tensor(
                self.mean.shape,
                generator=generator,
                device=self.parameters.device,
                dtype=self.parameters.dtype,
            )
        return self.mean + self.std * noise

    def kl(self, other=None, dims=(1, 2, 3)):
        """KL divergence to `other`, or to a standard normal when omitted.

        Args:
            other: Another distribution exposing `mean`, `var` and `logvar`.
            dims: Dimensions summed over; the default reduces every
                non-batch dimension of a 4-D tensor.

        Returns:
            The divergence summed over `dims`, or a one-element zero tensor
            when this distribution is deterministic.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=dims,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=dims,
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Negative log-likelihood of `sample`, summed over `dims`.

        Returns a one-element zero tensor when this distribution is
        deterministic.
        """
        if self.deterministic:
            return self._zero()
        # Plain Python float so the sum below is always a tensor operation.
        logtwopi = float(np.log(2.0 * np.pi))
        return 0.5 * torch.sum(
            logtwopi
            + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean