Guest User

Untitled

a guest
Sep 30th, 2025
28
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.36 KB | None | 0 0
  1. import torch
  2. from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps)
  3. from .fm_solvers_unipc import FlowUniPCMultistepScheduler
  4. from .basic_flowmatch import FlowMatchScheduler
  5. from .flowmatch_pusa import FlowMatchSchedulerPusa
  6. from .flowmatch_res_multistep import FlowMatchSchedulerResMultistep
  7. from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
  8. from diffusers.schedulers import FlowMatchEulerDiscreteScheduler, DEISMultistepScheduler
  9. from .fm_solvers_euler import EulerScheduler
  10. from ...utils import log
  11.  
  12. scheduler_list = [
  13. "unipc", "unipc/beta",
  14. "dpm++", "dpm++/beta",
  15. "dpm++_sde", "dpm++_sde/beta",
  16. "euler", "euler/beta",
  17. "deis",
  18. "lcm", "lcm/beta",
  19. "res_multistep",
  20. "flowmatch_causvid",
  21. "flowmatch_distill",
  22. "flowmatch_pusa",
  23. "lightning_euler", "lightning_euler/beta", "lightning_euler/beta57",
  24. "multitalk"
  25. ]
  26.  
  27. def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False):
  28. timesteps = None
  29. if 'unipc' in scheduler:
  30. sample_scheduler = FlowUniPCMultistepScheduler(shift=shift)
  31. if sigmas is None:
  32. sample_scheduler.set_timesteps(steps, device=device, shift=shift, use_beta_sigmas=('beta' in scheduler))
  33. else:
  34. sample_scheduler.sigmas = sigmas.to(device)
  35. sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device)
  36. sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps)
  37.  
  38. elif scheduler in ['euler/beta', 'euler']:
  39. sample_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift, use_beta_sigmas=(scheduler == 'euler/beta'))
  40. if flowedit_args: #seems to work better
  41. timesteps, _ = retrieve_timesteps(sample_scheduler, device=device, sigmas=get_sampling_sigmas(steps, shift))
  42. else:
  43. sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas[:-1].tolist() if sigmas is not None else None)
  44. elif 'dpm' in scheduler:
  45. if 'sde' in scheduler:
  46. algorithm_type = "sde-dpmsolver++"
  47. else:
  48. algorithm_type = "dpmsolver++"
  49. sample_scheduler = FlowDPMSolverMultistepScheduler(shift=shift, algorithm_type=algorithm_type)
  50. if sigmas is None:
  51. sample_scheduler.set_timesteps(steps, device=device, use_beta_sigmas=('beta' in scheduler))
  52. else:
  53. sample_scheduler.sigmas = sigmas.to(device)
  54. sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device)
  55. sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps)
  56. elif scheduler == 'deis':
  57. sample_scheduler = DEISMultistepScheduler(use_flow_sigmas=True, prediction_type="flow_prediction", flow_shift=shift)
  58. sample_scheduler.set_timesteps(steps, device=device)
  59. sample_scheduler.sigmas[-1] = 1e-6
  60. elif 'lcm' in scheduler:
  61. sample_scheduler = FlowMatchLCMScheduler(shift=shift, use_beta_sigmas=(scheduler == 'lcm/beta'))
  62. sample_scheduler.set_timesteps(steps, device=device, sigmas=sigmas[:-1].tolist() if sigmas is not None else None)
  63. elif 'flowmatch_causvid' in scheduler:
  64. if sigmas is not None:
  65. raise NotImplementedError("This scheduler does not support custom sigmas")
  66. if transformer_dim == 5120:
  67. denoising_list = [999, 934, 862, 756, 603, 410, 250, 140, 74]
  68. else:
  69. if steps != 4:
  70. raise ValueError("CausVid 1.3B schedule is only for 4 steps")
  71. denoising_list = [1000, 750, 500, 250]
  72. sample_scheduler = FlowMatchScheduler(num_inference_steps=steps, shift=shift, sigma_min=0, extra_one_step=True)
  73. sample_scheduler.timesteps = torch.tensor(denoising_list)[:steps].to(device)
  74. sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)])
  75. elif 'flowmatch_distill' in scheduler:
  76. if sigmas is not None:
  77. raise NotImplementedError("This scheduler does not support custom sigmas")
  78. sample_scheduler = FlowMatchScheduler(
  79. shift=shift, sigma_min=0.0, extra_one_step=True
  80. )
  81. sample_scheduler.set_timesteps(1000, training=True)
  82.  
  83. denoising_step_list = torch.tensor([999, 750, 500, 250] , dtype=torch.long)
  84. temp_timesteps = torch.cat((sample_scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
  85. denoising_step_list = temp_timesteps[1000 - denoising_step_list]
  86. #print("denoising_step_list: ", denoising_step_list)
  87.  
  88. if steps != 4:
  89. raise ValueError("This scheduler is only for 4 steps")
  90.  
  91. sample_scheduler.timesteps = denoising_step_list[:steps].clone().detach().to(device)
  92. sample_scheduler.sigmas = torch.cat([sample_scheduler.timesteps / 1000, torch.tensor([0.0], device=device)])
  93. elif 'flowmatch_pusa' in scheduler:
  94. sample_scheduler = FlowMatchSchedulerPusa(shift=shift, sigma_min=0.0, extra_one_step=True)
  95. sample_scheduler.set_timesteps(steps+1, denoising_strength=denoise_strength, shift=shift,
  96. sigmas=sigmas[:-1].tolist() if sigmas is not None else None)
  97. elif 'lightning_euler' in scheduler:
  98. if sigmas is not None:
  99. raise NotImplementedError("This scheduler does not support custom sigmas")
  100.  
  101. # Resolve Beta shaping + (alpha, beta) for 'beta57'
  102. use_beta = (scheduler == 'lightning_euler/beta') or (scheduler == 'lightning_euler/beta57')
  103. alpha_val, beta_val = (0.6, 0.6)
  104. if scheduler == 'lightning_euler/beta57':
  105. alpha_val, beta_val = (0.5, 0.7)
  106.  
  107. sample_scheduler = EulerScheduler(
  108. num_train_timesteps=1000,
  109. shift=shift,
  110. device=device,
  111. use_beta_sigmas=use_beta,
  112. alpha=alpha_val,
  113. beta=beta_val,
  114. )
  115.  
  116. sample_scheduler.set_timesteps(num_inference_steps=steps, device=device)
  117. timesteps = sample_scheduler.timesteps[:-1].clone()
  118. elif scheduler == 'res_multistep':
  119. sample_scheduler = FlowMatchSchedulerResMultistep(shift=shift)
  120. sample_scheduler.set_timesteps(steps, denoising_strength=denoise_strength, sigmas=sigmas[:-1].tolist() if sigmas is not None else None)
  121.  
  122. if timesteps is None:
  123. timesteps = sample_scheduler.timesteps
  124.  
  125. steps = len(timesteps)
  126. if (isinstance(start_step, int) and end_step != -1 and start_step >= end_step) or (not isinstance(start_step, int) and start_step != -1 and end_step >= start_step):
  127. raise ValueError("start_step must be less than end_step")
  128. if denoise_strength < 1.0:
  129. if start_step != 0:
  130. raise ValueError("start_step must be 0 when denoise_strength is used")
  131. start_step = steps - int(steps * denoise_strength) - 1
  132.  
  133. # Determine start and end indices for slicing
  134. start_idx = 0
  135. end_idx = len(timesteps) - 1
  136.  
  137. if log_timesteps:
  138. log.info(f"------- Scheduler info -------")
  139. log.info(f"Total timesteps: {timesteps}")
  140.  
  141. if isinstance(start_step, float):
  142. idxs = (sample_scheduler.sigmas <= start_step).nonzero(as_tuple=True)[0]
  143. if len(idxs) > 0:
  144. start_idx = idxs[0].item()
  145. elif isinstance(start_step, int):
  146. if start_step > 0:
  147. start_idx = start_step
  148.  
  149. if isinstance(end_step, float):
  150. idxs = (sample_scheduler.sigmas >= end_step).nonzero(as_tuple=True)[0]
  151. if len(idxs) > 0:
  152. end_idx = idxs[-1].item()
  153. elif isinstance(end_step, int):
  154. if end_step != -1:
  155. end_idx = end_step - 1
  156.  
  157. # Slice timesteps and sigmas once, based on indices
  158. timesteps = timesteps[start_idx:end_idx+1]
  159. sample_scheduler.full_sigmas = sample_scheduler.sigmas.clone()
  160. sample_scheduler.sigmas = sample_scheduler.sigmas[start_idx:start_idx+len(timesteps)+1] # always one longer
  161.  
  162. if log_timesteps:
  163. log.info(f"Using timesteps: {timesteps}")
  164. log.info(f"Using sigmas: {sample_scheduler.sigmas}")
  165. log.info(f"------------------------------")
  166.  
  167. if hasattr(sample_scheduler, 'timesteps'):
  168. sample_scheduler.timesteps = timesteps
  169.  
  170. return sample_scheduler, timesteps, start_idx, end_idx
Advertisement
Add Comment
Please, Sign In to add comment