# Requires module-level: import torch; import torch.nn.functional as F
def forward(self, samples):
    # Encode the target image into VAE latents and apply the Stable Diffusion
    # latent scaling factor (0.18215)
    latents = self.vae.encode(samples["tgt_image"].half()).latent_dist.sample()
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        self.noise_scheduler.config.num_train_timesteps,
        (bsz,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

    # Subject-conditioning embeddings from the reference image and subject text
    ctx_embeddings = self.forward_ctx_embeddings(
        input_image=samples["inp_image"], text_input=samples["subject_text"]
    )

    # Get the text embedding for conditioning, with the context embeddings
    # injected starting at token position _CTX_BEGIN_POS of each prompt
    input_ids = self.tokenizer(
        samples["caption"],
        padding="do_not_pad",
        truncation=True,
        max_length=self.tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(self.device)
    encoder_hidden_states = self.text_encoder(
        input_ids=input_ids,
        ctx_embeddings=ctx_embeddings,
        ctx_begin_pos=[self._CTX_BEGIN_POS] * input_ids.shape[0],
    )[0]

    # Predict the noise residual
    noise_pred = self.unet(
        noisy_latents.float(), timesteps, encoder_hidden_states
    ).sample

    # Denoising objective: MSE between predicted and true noise
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

    return {"loss": loss}
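For context, a hypothetical training-step usage of this forward pass. The dict keys match the ones read above; the batch size, image resolutions, and example strings are assumptions, and `model` is assumed to be an instance of the class that defines forward() with its vae, unet, text encoder, and tokenizer already loaded.

# Hypothetical usage sketch; optimizer setup and step omitted
import torch

samples = {
    "tgt_image": torch.randn(2, 3, 512, 512, device=model.device),  # target images (shape assumed)
    "inp_image": torch.randn(2, 3, 224, 224, device=model.device),  # subject reference images (shape assumed)
    "subject_text": ["dog", "dog"],                                  # subject strings (example values)
    "caption": ["a dog wearing sunglasses", "a dog on the beach"],   # captions (example values)
}

out = model(samples)     # returns {"loss": ...} as in forward() above
out["loss"].backward()   # standard backprop on the denoising MSE loss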