Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def get_noise_noisy_latents_and_timesteps(
    args, noise_scheduler, latents: torch.FloatTensor
):
    """Generate noise, noisy latents, and per-sample timesteps for diffusion training.

    When ``args.fixed_noise`` is set, the noise is derived deterministically from
    the latents' content: the raw float32 bytes of the tensor are hashed with
    SHA-256 and the hash seeds a per-call ``torch.Generator``, so identical
    latents always receive identical noise. Otherwise the noise is freshly random.

    Args:
        args: Training configuration namespace; reads ``fixed_noise``,
            ``noise_offset``, ``noise_offset_random_strength``,
            ``adaptive_noise_scale``, ``multires_noise_iterations``,
            ``multires_noise_discount``, ``min_timestep``, ``max_timestep``,
            ``ip_noise_gamma`` and ``ip_noise_gamma_random_strength``.
        noise_scheduler: Scheduler providing ``config.num_train_timesteps``
            and ``add_noise(latents, noise, timesteps)``.
        latents: Latent batch; the batch size is taken from ``latents.shape[0]``.

    Returns:
        Tuple ``(noise, noisy_latents, timesteps)``.
    """
    if args.fixed_noise:
        # Convert to float32 before .numpy() (NOTE(review): presumably because
        # half/bfloat16 tensors cannot be viewed as numpy arrays — confirm).
        latents_bytes = latents.detach().to(torch.float32).cpu().numpy().tobytes()
        # Fold the SHA-256 digest into the 32-bit range accepted as a seed.
        hash_int = int(hashlib.sha256(latents_bytes).hexdigest(), 16) % (2**32)
        g = torch.Generator(device=latents.device).manual_seed(hash_int)
        noise = torch.randn(latents.shape, generator=g, device=latents.device)
    else:
        noise = torch.randn_like(latents)
    if args.noise_offset:
        if args.noise_offset_random_strength:
            # Random per-call strength drawn from [0, noise_offset).
            noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
        else:
            noise_offset = args.noise_offset
        noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
    if args.multires_noise_iterations:
        # Multi-resolution (pyramid) noise replaces the plain Gaussian noise.
        noise = custom_train_functions.pyramid_noise_like(
            noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
        )
    # One timestep per batch element, bounded by the configured min/max
    # (defaults: full scheduler range).
    b_size = latents.shape[0]
    min_timestep = 0 if args.min_timestep is None else args.min_timestep
    max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
    timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
    if args.ip_noise_gamma:
        # Input perturbation: extra Gaussian noise (scaled by ip_noise_gamma)
        # is added to the noise fed to add_noise, but the returned `noise`
        # target stays unperturbed.
        if args.ip_noise_gamma_random_strength:
            strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
        else:
            strength = args.ip_noise_gamma
        noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
    else:
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    return noise, noisy_latents, timesteps
Advertisement
Add Comment
Please, Sign In to add comment