ethansmith2000
diffusion
May 11th, 2024
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim=10, max_length=1000):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_length = max_length

        # Precompute the positional encodings once (in log space) and register them
        # as a buffer so they move with the module across devices.
        self.register_buffer("positional_encodings", self._get_positional_encodings())

    def _get_positional_encodings(self):
        pe = torch.zeros(self.max_length, self.embedding_dim)
        position = torch.arange(0, self.max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2).float() * (-math.log(10000.0) / self.embedding_dim)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe

    def forward(self, time):
        # Look up the precomputed embeddings for the given (batch of) timesteps
        return self.positional_encodings[time, :]

# The score prediction network (i.e. same role as the UNet in image diffusion models)
class ScoreNetwork(nn.Module):
    """
    A simple feed-forward MLP.

    Takes as input the data point and an embedding encoding the timestep.
    """

    def __init__(self, data_dim=2, time_dim=2, dim=128, total_timesteps=1000):
        super().__init__()
        self.data_dim = data_dim
        self.time_dim = time_dim
        # Sinusoidal embedding of the timestep
        self.positional_embedding = SinusoidalPositionalEmbedding(
            time_dim,
            total_timesteps
        )

        # act = nn.SiLU
        act = nn.LeakyReLU

        self.network = nn.Sequential(
            nn.Linear(data_dim + time_dim, dim),
            act(),
            nn.Linear(dim, dim),
            act(),
            nn.Linear(dim, data_dim),
        )

    def forward(self, x, time):
        """
        Forward pass: embed the timestep, concatenate it with the data, and predict the noise.
        """
        time_embedding = self.positional_embedding(time)
        x = self.network(torch.cat([x, time_embedding], dim=-1))
        return x

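# Quick shape check (illustrative only, not part of the original paste): with data_dim=2 and
# time_dim=2, ScoreNetwork maps a (batch, 2) sample plus a (batch,) tensor of integer timesteps
# to a (batch, 2) noise prediction, e.g.
#     net = ScoreNetwork(data_dim=2, time_dim=2)
#     out = net(torch.randn(4, 2), torch.randint(0, 1000, (4,)))   # out.shape == (4, 2)
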
# The diffusion model: noise schedule + score network + sampling loop
class DiffusionModel(nn.Module):

    def __init__(
            self,
            total_timesteps=1000,
            data_dim=2,
            time_dim=10,
            beta_start=0.0001,
            beta_end=0.02,
            offset=4.0,
            spread=1.0,
        ):
        super().__init__()

        self.total_timesteps = total_timesteps
        # Score network that predicts the added noise
        self.score_network = ScoreNetwork(
            data_dim=data_dim,
            time_dim=time_dim
        )
        # The toy data is N(0, spread^2) with the first coordinate shifted by +/- offset
        self.offset = offset
        self.spread = spread

        # Noise schedule: betas (linear schedule), alphas = 1 - betas,
        # and the cumulative products alpha_bar_t = prod_{s<=t} alpha_s
        self.betas = torch.linspace(beta_start, beta_end, total_timesteps)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.)

        # required for self.add_noise
        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5

        # required for reconstructing x0 from the noise prediction
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(
            1 / self.alphas_cumprod - 1)

        # required for the posterior mean of q(x_{t-1} | x_t, x_0)
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)

    def get_variance(self, t):
        # Posterior variance: beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        if t == 0:
            return 0

        variance = self.betas[t] * (1. - self.alphas_cumprod_prev[t]) / (1. - self.alphas_cumprod[t])
        variance = variance.clip(1e-20)
        return variance

    def predict_noise(self, x, t):
        """
        Predict the noise present in a noisy sample x at timestep t.
        """
        return self.score_network(x, t)

    def step(self, model_output, timestep, x_t):
        t = timestep
        # Reconstruct the original sample: x0_hat = sqrt(1/alpha_bar_t) * x_t - sqrt(1/alpha_bar_t - 1) * eps_hat
        s1 = self.sqrt_inv_alphas_cumprod[t].reshape(-1, 1)
        s2 = self.sqrt_inv_alphas_cumprod_minus_one[t].reshape(-1, 1)
        pred_original_sample = s1 * x_t - s2 * model_output
        # Posterior mean of the previous sample: mu = coef1 * x0_hat + coef2 * x_t
        s1 = self.posterior_mean_coef1[t].reshape(-1, 1)
        s2 = self.posterior_mean_coef2[t].reshape(-1, 1)
        pred_prev_sample = s1 * pred_original_sample + s2 * x_t
        # Add noise back to the sample (no noise on the final step, t == 0)
        variance = 0
        if t > 0:
            noise = torch.randn_like(model_output)
            variance = (self.get_variance(t) ** 0.5) * noise
        # if t > 0:
        #     noise = torch.randn_like(model_output)
        #     variance = (self.get_variance(t) ** 0.5) * model_output

        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(self, x_start, x_noise, timesteps):
        # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        s1 = self.sqrt_alphas_cumprod[timesteps]
        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps]

        s1 = s1.reshape(-1, 1)
        s2 = s2.reshape(-1, 1)

        return s1 * x_start + s2 * x_noise

    def sample(self, num_samples=1000, device='cpu'):
        # Start from the "noise" distribution used during training:
        # N(0, spread^2) with the first coordinate shifted by -offset
        orig_sample = torch.randn(num_samples, 2) * self.spread
        orig_sample[:, 0] -= self.offset

        sample = orig_sample.clone()
        # Walk the full schedule backwards, from t = T-1 down to t = 0
        timesteps = list(range(self.total_timesteps))[::-1]
        intermediate_values = torch.empty(num_samples, self.total_timesteps, 2)

        for i, t in enumerate(timesteps):
            t = torch.from_numpy(np.repeat(t, num_samples)).long()
            with torch.no_grad():
                residual = self.predict_noise(sample, t)
            sample = self.step(residual, t[0], sample)
            intermediate_values[:, i, :] = sample

        return orig_sample, sample, intermediate_values

# Training loop
def train(model, num_iterations=1000, batch_size=32, learning_rate=1e-4, device='cpu'):
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98))
    loss_fn = nn.MSELoss()
    losses = []
    for i in tqdm(range(num_iterations)):
        optimizer.zero_grad()
        # Toy "clean" data: N(0, spread^2) with the first coordinate shifted by +offset
        noise_free_data = torch.randn(batch_size, 2).to(device) * model.spread
        noise_free_data[:, 0] += model.offset

        # Sample timesteps uniformly at random
        time_steps = torch.randint(0, model.total_timesteps, (noise_free_data.shape[0],))
        # "Noise" drawn from the same distribution that sample() starts from
        # (N(0, spread^2) with the first coordinate shifted by -offset), not a standard normal
        noise = torch.randn(noise_free_data.shape) * model.spread
        noise[:, 0] -= model.offset

        noisy_data = model.add_noise(noise_free_data, noise, time_steps)
        predicted_noise = model.predict_noise(noisy_data, time_steps)
        loss = loss_fn(predicted_noise, noise)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if i % 10000 == 0:
            # Periodically generate samples and plot them against the true data
            orig_sample, samples, _ = model.sample(num_samples=500, device=device)
            samples = samples.detach().cpu().numpy()
            plt.figure()
            # The true data distribution
            data = torch.randn(1000, 2).cpu().numpy() * model.spread
            data[:, 0] += model.offset

            plt.scatter(orig_sample[:, 0], orig_sample[:, 1], label='Noise dist', alpha=0.5)
            plt.scatter(data[:, 0], data[:, 1], label='True Data', alpha=0.5)
            plt.scatter(samples[:, 0], samples[:, 1], label='Samples', alpha=0.5)
            plt.legend()
            plt.savefig(f'samples_{i}.png')
            plt.close()

    # Plot a moving average of the losses (window capped by the number of iterations)
    window = max(1, min(5000, len(losses) // 10))
    smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
    plt.figure()
    plt.plot(smoothed)
    plt.savefig('losses.png')
    plt.close()

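# The paste defines train() but never calls it. Below is a minimal, hypothetical usage sketch
# (not part of the original): the seed, iteration count, and hyperparameters are assumptions.
# With these settings, sample scatter plots are saved every 10,000 iterations and a smoothed
# loss curve is saved at the end.
if __name__ == "__main__":
    torch.manual_seed(0)
    diffusion = DiffusionModel(total_timesteps=1000, data_dim=2, time_dim=10)
    train(diffusion, num_iterations=50_000, batch_size=32, learning_rate=1e-4, device='cpu')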