modelzoo.vision.pytorch.dit.samplers.SamplerBase.SamplerBase
- class modelzoo.vision.pytorch.dit.samplers.SamplerBase.SamplerBase
Bases: abc.ABC
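Because the class derives from `abc.ABC`, a subclass cannot be instantiated until it overrides every abstract method. The following is a minimal sketch that uses a hypothetical stand-in, `SamplerBaseSketch`, which declares only the two abstract methods documented on this page (the real class may declare more). `IncompleteSampler` and `StubSampler` are also hypothetical, and their method bodies are placeholders.

```python
from abc import ABC, abstractmethod
from typing import List, Optional


class SamplerBaseSketch(ABC):
    """Hypothetical stand-in for SamplerBase: only the two abstract
    methods documented on this page are declared here."""

    @abstractmethod
    def set_timesteps(self, num_diffusion_steps: int,
                      num_inference_steps: Optional[str],
                      custom_timesteps: Optional[List[int]]) -> None:
        """Compute the timesteps used during sampling."""

    @abstractmethod
    def previous_timestep(self, timestep: int) -> int:
        """Return the timestep that precedes `timestep`."""


class IncompleteSampler(SamplerBaseSketch):
    # Overrides only one abstract method, so the class is still abstract.
    def previous_timestep(self, timestep: int) -> int:
        return -1


class StubSampler(SamplerBaseSketch):
    # Placeholder bodies: just enough to satisfy the abstract interface.
    def set_timesteps(self, num_diffusion_steps, num_inference_steps, custom_timesteps):
        self.timesteps = list(custom_timesteps or [])

    def previous_timestep(self, timestep):
        return -1


try:
    IncompleteSampler()
except TypeError as err:
    print("TypeError:", err)  # set_timesteps is still abstract

sampler = StubSampler()  # succeeds: every abstract method is overridden
```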
Methods
Returns the previous timestep based on the current timestep.
Computes the timesteps to be used during sampling.
Predicts the sample at the previous timestep by reversing the SDE.
- abstract previous_timestep(timestep)
Returns the previous timestep based on the current timestep. The result depends on the timesteps computed by self.set_timesteps, so set_timesteps must be called first.
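A section-strided or user-supplied schedule is generally not evenly spaced, so the previous timestep cannot be computed as `timestep` minus a fixed stride and has to be looked up in the schedule instead. The following minimal sketch assumes the schedule is stored in descending (sampling) order and uses -1 to mark the end of sampling. The standalone `previous_timestep` helper is hypothetical and is not this method's actual implementation.

```python
from typing import Sequence


def previous_timestep(timesteps: Sequence[int], timestep: int) -> int:
    """Hypothetical helper: return the timestep that follows `timestep`
    in the sampling order, or -1 once the final step is reached.

    Assumes `timesteps` is ordered from most noisy to least noisy.
    """
    idx = list(timesteps).index(timestep)  # raises ValueError if absent
    if idx + 1 < len(timesteps):
        return timesteps[idx + 1]
    return -1  # assumed sentinel: no earlier timestep remains


schedule = [999, 749, 499, 249, 0]
print(previous_timestep(schedule, 999))  # 749
print(previous_timestep(schedule, 0))    # -1
```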
- abstract set_timesteps(num_diffusion_steps, num_inference_steps, custom_timesteps)
Computes the timesteps to be used during sampling.
- Parameters
num_diffusion_steps (int) – Total number of diffusion steps the model was trained with.
num_inference_steps (str) – String of comma-separated integers giving the step count for each section. For example, if num_diffusion_steps is 300 and num_inference_steps is "10,15,20", the 300 timesteps are split into three sections of 100: the first section is strided down to 10 timesteps, the second to 15, and the last to 20. Pass either custom_timesteps or num_inference_steps, but not both.
custom_timesteps (List[int]) – User-specified list of timesteps to use during sampling.
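Below is a minimal sketch of the section-wise striding described above, modeled on `space_timesteps` from OpenAI's improved-diffusion repository. The function names `strided_timesteps` and `resolve_timesteps` are hypothetical, and so are several behaviors: the fractional-stride rounding, the way the remainder is spread when the step count does not divide evenly, the rejection of passing neither argument, and the descending return order. None of these is confirmed as this class's behavior.

```python
from typing import List, Optional


def strided_timesteps(num_diffusion_steps: int, num_inference_steps: str) -> List[int]:
    """Split [0, num_diffusion_steps) into equal sections and keep an evenly
    strided subset of each one, e.g. "10,15,20" -> 10, 15 and 20 steps."""
    section_counts = [int(x) for x in num_inference_steps.split(",")]
    size_per, extra = divmod(num_diffusion_steps, len(section_counts))
    start, steps = 0, []
    for i, count in enumerate(section_counts):
        # The first `extra` sections absorb the remainder, one step each.
        size = size_per + (1 if i < extra else 0)
        if count > size:
            raise ValueError(f"cannot pick {count} steps from a section of {size}")
        # Fractional stride so each section's first and last steps are both kept.
        stride = (size - 1) / (count - 1) if count > 1 else 1
        steps.extend(start + round(k * stride) for k in range(count))
        start += size
    return sorted(set(steps))


def resolve_timesteps(num_diffusion_steps: int,
                      num_inference_steps: Optional[str] = None,
                      custom_timesteps: Optional[List[int]] = None) -> List[int]:
    """Return the sampling schedule, most noisy timestep first (assumed order)."""
    # Both arguments together are disallowed; rejecting neither-given is an assumption.
    if (num_inference_steps is None) == (custom_timesteps is None):
        raise ValueError("pass exactly one of num_inference_steps or custom_timesteps")
    if custom_timesteps is not None:
        steps = sorted(set(custom_timesteps))
    else:
        steps = strided_timesteps(num_diffusion_steps, num_inference_steps)
    return steps[::-1]


print(strided_timesteps(300, "10,15,20")[:10])  # [0, 11, 22, 33, 44, 55, 66, 77, 88, 99]
```

With 300 steps and "10,15,20", each section spans 100 timesteps and the sketch keeps 10 + 15 + 20 = 45 timesteps in total. A fractional stride keeps both ends of every section at the cost of slightly uneven gaps within the section.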