modelzoo.common.pytorch.model_utils.weight_initializers.trunc_normal_

modelzoo.common.pytorch.model_utils.weight_initializers.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)

Fills the input Tensor in place with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\), with values outside \([a, b]\) redrawn until they fall within the bounds. The method used for generating the random values works best when \(a \leq \text{mean} \leq b\).
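To make the "redraw until within bounds" description concrete, here is a minimal standard-library sketch of rejection sampling from the same truncated distribution. The function name `trunc_normal_rejection` and its `n` and `seed` parameters are hypothetical; this illustrates the target distribution only and is not necessarily how `trunc_normal_` generates its values.

```python
import random


def trunc_normal_rejection(n, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=None):
    """Hypothetical helper: draw n samples from N(mean, std^2) restricted to [a, b].

    Illustration only: any draw outside [a, b] is discarded and redrawn.
    This yields the distribution described above, but is not claimed to be
    the algorithm trunc_normal_ uses internally.
    """
    if std <= 0:
        raise ValueError("std must be positive")
    if a >= b:
        raise ValueError("a must be less than b")
    rng = random.Random(seed)
    samples = []
    while len(samples) < n:
        x = rng.gauss(mean, std)  # untruncated normal draw
        if a <= x <= b:  # keep only draws inside the cutoffs
            samples.append(x)
    return samples


values = trunc_normal_rejection(1000, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=0)
print(min(values) >= -2.0, max(values) <= 2.0)  # True True
```

Naive rejection like this slows down sharply when \([a, b]\) sits far out in a tail of the distribution, because most draws are discarded; keeping \(a \leq \text{mean} \leq b\) avoids that case.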

Parameters
  • tensor (torch.Tensor) – the n-dimensional torch.Tensor to fill in place

  • mean (float) – the mean of the normal distribution. Defaults to 0.0

  • std (float) – the standard deviation of the normal distribution. Defaults to 1.0

  • a (float) – the minimum cutoff value. Defaults to -2.0

  • b (float) – the maximum cutoff value. Defaults to 2.0
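Since a and b are described as cutoff values rather than multiples of std, the sketch below treats them as absolute bounds on the sampled values. It uses the normal CDF to compute how much probability mass of the untruncated \(\mathcal{N}(\text{mean}, \text{std}^2)\) lies outside \([a, b]\). The helpers `normal_cdf` and `mass_outside` are hypothetical and exist only for illustration.

```python
import math


def normal_cdf(x, mean=0.0, std=1.0):
    """Hypothetical helper: CDF of N(mean, std^2) evaluated at x."""
    return 0.5 * (1.0 + math.erf((x - mean) / (std * math.sqrt(2.0))))


def mass_outside(mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Hypothetical helper: probability that an untruncated draw falls outside [a, b]."""
    return normal_cdf(a, mean, std) + (1.0 - normal_cdf(b, mean, std))


# Defaults: the cutoffs sit 2 standard deviations from the mean.
print(round(mass_outside(0.0, 1.0, -2.0, 2.0), 4))  # 0.0455

# With a small std, the default cutoffs are 100 standard deviations away,
# so effectively nothing is truncated.
print(mass_outside(0.0, 0.02, -2.0, 2.0))  # 0.0

# Scaling the cutoffs with std restores truncation at 2 standard deviations.
print(round(mass_outside(0.0, 0.02, -0.04, 0.04), 4))  # 0.0455
```

In other words, if the intent is to truncate at a fixed number of standard deviations, a and b should be scaled along with std.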

Examples

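>>> import torch
>>> from modelzoo.common.pytorch.model_utils.weight_initializers import trunc_normal_
>>> w = torch.empty(3, 3)
>>> trunc_normal_(w)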