Source code for modelzoo.vision.pytorch.dit.layers.vae.utils

# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[
        Union[List["torch.Generator"], "torch.Generator"]
    ] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """Create a standard-normal random tensor on `device` with `dtype`.

    When a list of generators is passed, one generator is used per batch
    element (index along dim 0), so each sample can be seeded individually.
    If a CPU generator is passed while a non-CPU `device` is requested, the
    tensor is sampled on CPU and then moved to `device`.

    Args:
        shape: Shape of the tensor to create; `shape[0]` is the batch size.
            May be a tuple or a list.
        generator: A single generator, a list with one generator per batch
            element, or `None` to use the global RNG.
        device: Target device of the returned tensor. Defaults to CPU for
            the final placement when `None`.
        dtype: Data type forwarded to `torch.randn`.
        layout: Tensor layout; defaults to `torch.strided`.

    Returns:
        A tensor of shape `shape` located on `device`.

    Raises:
        ValueError: If a CUDA generator is passed but `device` is of a
            different type.
    """
    # Device on which the tensor is sampled defaults to the requested device
    # (which may be None, letting torch choose its default).
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")
    if not isinstance(device, torch.device):
        # Accept strings such as "cuda" so that `.type` is always available.
        device = torch.device(device)

    if generator is not None:
        gen_device_type = (
            generator.device.type
            if not isinstance(generator, list)
            else generator[0].device.type
        )
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            # Compare the device *type*: a torch.device never equals a str,
            # so `device != "mps"` would always be true.
            if device.type != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slighly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}."
            )

    if isinstance(generator, list):
        # Sample each batch element separately with its own generator.
        # `tuple(...)` keeps this working when `shape` is given as a list.
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(
                shape,
                generator=generator[i],
                device=rand_device,
                dtype=dtype,
                layout=layout,
            )
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(
            shape,
            generator=generator,
            device=rand_device,
            dtype=dtype,
            layout=layout,
        ).to(device)

    return latents
def _emulate_chunk2_dim1(x): c = x.shape[1] // 2 return x[:, 0:c, ...], x[:, c:, ...]
class DiagonalGaussianDistribution:
    """Gaussian with diagonal covariance parameterised by mean and log-variance.

    `parameters` is split in half along dimension 1: the first half is the
    mean and the second half the log-variance (clamped to [-30, 20]).
    In deterministic mode the standard deviation and variance are zero, so
    sampling returns the mean.
    """

    def __init__(self, parameters, deterministic=False):
        """Build the distribution from concatenated mean/log-variance.

        Args:
            parameters: Tensor whose dim 1 holds the mean followed by the
                log-variance.
            deterministic: If True, `std` and `var` are set to zeros.
        """
        self.parameters = parameters
        self.mean, self.logvar = _emulate_chunk2_dim1(parameters)
        # Clamp to keep exp() of the log-variance in a numerically safe range.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean,
                device=self.parameters.device,
                dtype=self.parameters.dtype,
            )
        else:
            self.std = torch.exp(0.5 * self.logvar)
            self.var = torch.exp(self.logvar)

    def _zero(self):
        """Return a one-element zero tensor matching `parameters`' device/dtype."""
        return torch.zeros(
            1, device=self.parameters.device, dtype=self.parameters.dtype
        )

    def sample(
        self, noise=None, generator: Optional[torch.Generator] = None
    ) -> torch.FloatTensor:
        """Draw `mean + std * noise`.

        Args:
            noise: Optional pre-sampled noise with the shape of `mean`. When
                None, noise is drawn with `randn_tensor`.
            generator: Optional generator forwarded to `randn_tensor`.

        Returns:
            The sampled tensor.
        """
        if noise is None:
            # Make sure the noise is on the same device as the parameters
            # and has the same dtype.
            noise = randn_tensor(
                self.mean.shape,
                generator=generator,
                device=self.parameters.device,
                dtype=self.parameters.dtype,
            )
        return self.mean + self.std * noise

    def kl(self, other=None):
        """Return the KL divergence to `other`, summed over dims 1, 2 and 3.

        Args:
            other: Another distribution exposing `mean`, `var` and `logvar`.
                When None, the divergence to a standard normal is computed.

        Returns:
            A tensor with one value per batch element, or a one-element zero
            tensor in deterministic mode.
        """
        if self.deterministic:
            return self._zero()
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3],
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var
            + self.var / other.var
            - 1.0
            - self.logvar
            + other.logvar,
            dim=[1, 2, 3],
        )

    def nll(self, sample, dims=(1, 2, 3)):
        """Return the negative log-likelihood of `sample`, summed over `dims`.

        Args:
            sample: Tensor to evaluate; must broadcast against `mean`.
            dims: Dimensions to sum over.

        Returns:
            The summed negative log-likelihood, or a one-element zero tensor
            in deterministic mode (where the variance is zero).
        """
        if self.deterministic:
            return self._zero()
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi
            + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self):
        """Return the mode of the distribution, which is its mean."""
        return self.mean