modelzoo.common.pytorch.model_utils.weight_initializers.trunc_normal_
- modelzoo.common.pytorch.model_utils.weight_initializers.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)
Fills the input tensor in place with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\), with values outside \([a, b]\) redrawn until they fall within the bounds. The method used to generate the random values works best when \(a \leq \text{mean} \leq b\).
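The sampling rule above can be illustrated with a minimal pure-Python sketch that uses only the standard library (no PyTorch) and returns plain lists instead of filling a tensor. The function names `trunc_normal_rejection` and `trunc_normal_inverse_cdf` are hypothetical, and neither is claimed to be how this module is implemented. The first follows the "redraw until within bounds" description literally. The second is a common alternative that draws a uniform value between the CDF at the two cutoffs and maps it back through the inverse normal CDF, which targets the same distribution without a retry loop (PyTorch's own `torch.nn.init.trunc_normal_` takes this inverse-CDF route). Both assume `a < b`.

```python
import math
import random
from statistics import NormalDist


def trunc_normal_rejection(n, mean=0.0, std=1.0, a=-2.0, b=2.0, rng=None):
    """Draw n values from N(mean, std^2), redrawing any value outside [a, b]."""
    rng = rng if rng is not None else random.Random()
    samples = []
    while len(samples) < n:
        x = rng.gauss(mean, std)
        # Accept the draw only if it lies inside the cutoffs; otherwise try again.
        if a <= x <= b:
            samples.append(x)
    return samples


def trunc_normal_inverse_cdf(n, mean=0.0, std=1.0, a=-2.0, b=2.0, rng=None):
    """Draw n values from the same truncated normal by inverting the normal CDF."""
    rng = rng if rng is not None else random.Random()
    dist = NormalDist(mean, std)
    # [lo, hi] is the range of CDF values that corresponds to [a, b].
    lo, hi = dist.cdf(a), dist.cdf(b)
    # inv_cdf rejects p <= 0 and p >= 1, so keep draws strictly inside (0, 1).
    p_min, p_max = math.nextafter(0.0, 1.0), math.nextafter(1.0, 0.0)
    samples = []
    for _ in range(n):
        u = min(max(rng.uniform(lo, hi), p_min), p_max)
        x = dist.inv_cdf(u)
        # Clamp to guard against floating-point round-off at the edges.
        samples.append(min(max(x, a), b))
    return samples


# Usage: seeded generators make the draws reproducible.
xs = trunc_normal_rejection(5, rng=random.Random(0))
ys = trunc_normal_inverse_cdf(5, mean=1.0, std=0.5, a=0.0, b=1.5, rng=random.Random(1))
```

When \([a, b]\) lies far out in one tail, the rejection loop accepts only rarely and the inverse-CDF version works with CDF values very close to 0 or 1, where floating-point precision runs out. Both effects show why \(a \leq \text{mean} \leq b\) is the comfortable regime.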
- Parameters
tensor (torch.Tensor) – an n-dimensional torch.Tensor
mean (float) – the mean of the normal distribution. Defaults to 0.0
std (float) – the standard deviation of the normal distribution. Defaults to 1.0
a (float) – the minimum cutoff value. Defaults to -2.0
b (float) – the maximum cutoff value. Defaults to 2.0
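Because `a` and `b` are absolute cutoff values rather than multiples of `std`, the defaults correspond to truncation at two standard deviations only when `mean=0.0` and `std=1.0`. The sketch below uses Python's `statistics.NormalDist` and a hypothetical helper `mass_outside` to compute how much probability mass a pair of cutoffs actually removes. With a small `std` such as 0.02, the default bounds of -2.0 and 2.0 sit 100 standard deviations from the mean and remove essentially nothing, so cutoffs like `mean - 2 * std` and `mean + 2 * std` would have to be passed explicitly if two-sigma truncation is intended.

```python
from statistics import NormalDist


def mass_outside(mean, std, a, b):
    """Probability that an untruncated N(mean, std^2) draw falls outside [a, b]."""
    dist = NormalDist(mean, std)
    return dist.cdf(a) + (1.0 - dist.cdf(b))


# Default cutoffs with std=1.0: about 4.55% of draws would need redrawing.
print(round(mass_outside(0.0, 1.0, -2.0, 2.0), 4))  # 0.0455

# Same default cutoffs with a small std: effectively no truncation at all.
print(mass_outside(0.0, 0.02, -2.0, 2.0))  # 0.0

# Cutoffs scaled by std restore two-sigma truncation.
std = 0.02
print(round(mass_outside(0.0, std, -2 * std, 2 * std), 4))  # 0.0455
```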
Examples
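>>> import torch
>>> from modelzoo.common.pytorch.model_utils.weight_initializers import trunc_normal_
>>> w = torch.empty(3, 3)
>>> trunc_normal_(w)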