# NOTE: recovered from a Pastebin paste; site boilerplate removed.
import itertools
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
class SinusoidalPositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal timestep embeddings.

    Precomputes a (max_length, embedding_dim) table of interleaved
    sin/cos features (as in "Attention Is All You Need") and looks up
    rows by integer timestep in the forward pass.

    Args:
        embedding_dim: size of each embedding vector (assumed even —
            odd values would make the sin/cos column slices mismatch).
        max_length: number of timesteps the table covers.
    """

    def __init__(self, embedding_dim=10, max_length=1000):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.max_length = max_length
        # BUGFIX: register as a buffer instead of a plain attribute so the
        # table follows the module on .to(device)/.cuda(). persistent=False
        # keeps it out of state_dict, matching the original checkpoint format.
        self.register_buffer(
            "positional_encodings",
            self._get_positional_encodings(),
            persistent=False,
        )

    def _get_positional_encodings(self):
        """Build the (max_length, embedding_dim) sin/cos table once."""
        pe = torch.zeros(self.max_length, self.embedding_dim)
        position = torch.arange(0, self.max_length, dtype=torch.float).unsqueeze(1)
        # Frequencies decay geometrically from 1 down to ~1/10000 across dims.
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2).float()
            * (-math.log(10000.0) / self.embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims
        return pe

    def forward(self, time):
        """Return the embedding row(s) for integer timestep(s) `time`."""
        return self.positional_encodings[time, :]
# Make the score prediction network (i.e. same role as the UNet)
class ScoreNetwork(nn.Module):
    """Noise-prediction network (plays the role a UNet would in image DDPMs).

    A small feed-forward MLP: it receives a data point concatenated with a
    sinusoidal time embedding and outputs a tensor with the same
    dimensionality as the data (the predicted noise).
    """

    def __init__(self, data_dim=2, time_dim=2, dim=128, total_timesteps=1000):
        super().__init__()
        self.data_dim = data_dim
        self.time_dim = time_dim
        # Encode the integer timestep via a fixed sinusoidal embedding.
        self.positional_embedding = SinusoidalPositionalEmbedding(
            time_dim, total_timesteps
        )
        activation = nn.LeakyReLU
        self.network = nn.Sequential(
            nn.Linear(data_dim + time_dim, dim),
            activation(),
            nn.Linear(dim, dim),
            activation(),
            nn.Linear(dim, data_dim),
        )

    def forward(self, x, time):
        """Predict the noise for samples `x` at integer timesteps `time`."""
        t_emb = self.positional_embedding(time)
        return self.network(torch.cat([x, t_emb], dim=-1))
# Make a Diffusion model object
class DiffusionModel(nn.Module):
    """DDPM-style diffusion model over 2-D points with a linear beta schedule.

    NOTE(review): the sampling-time "noise" distribution is a Gaussian whose
    dim-0 mean is shifted by -offset and scaled by `spread` (not a standard
    normal) — this mirrors how `train` constructs its noise targets.
    """

    def __init__(
        self,
        total_timesteps=1000,
        data_dim=2,
        time_dim=10,
        beta_start=0.0001,
        beta_end=0.02,
        offset=4.0,
        spread=1.0,
    ):
        super(DiffusionModel, self).__init__()
        self.total_timesteps = total_timesteps
        # Noise predictor (same role as the UNet in image diffusion models).
        self.score_network = ScoreNetwork(data_dim=data_dim, time_dim=time_dim)
        self.offset = offset
        self.spread = spread
        # Linear beta schedule.
        self.betas = torch.linspace(beta_start, beta_end, total_timesteps)
        self.alphas = 1 - self.betas
        # BUGFIX: the original computed torch.cumprod(self.alphas) twice
        # (once as `cumulative_alphas`, once as `alphas_cumprod`); compute it
        # once and alias so both public attribute names remain available.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.cumulative_alphas = self.alphas_cumprod
        self.cumulative_betas = 1 - self.alphas_cumprod
        # alpha-bar_{t-1}, with alpha-bar_{-1} defined as 1.
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.)
        # Required by add_noise (forward process q(x_t | x_0)).
        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5
        # Required to reconstruct x_0 from x_t and the predicted noise.
        self.sqrt_inv_alphas_cumprod = torch.sqrt(1 / self.alphas_cumprod)
        self.sqrt_inv_alphas_cumprod_minus_one = torch.sqrt(
            1 / self.alphas_cumprod - 1)
        # Coefficients of the posterior mean of q(x_{t-1} | x_t, x_0).
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1. - self.alphas_cumprod)

    def get_variance(self, t):
        """Posterior variance of q(x_{t-1} | x_t, x_0) at timestep t (0 at t=0)."""
        if t == 0:
            return 0
        variance = self.betas[t] * (1. - self.alphas_cumprod_prev[t]) / (1. - self.alphas_cumprod[t])
        # Clamp away from zero to keep the sqrt in `step` well-behaved.
        variance = variance.clip(1e-20)
        return variance

    def predict_noise(self, x, t):
        """Predict the noise present in noisy sample `x` at timestep `t`."""
        return self.score_network(x, t)

    def step(self, model_output, timestep, x_t):
        """One reverse step: sample x_{t-1} from x_t and the predicted noise."""
        t = timestep
        # Reconstruct the original sample x_0.
        s1 = self.sqrt_inv_alphas_cumprod[t].reshape(-1, 1)
        s2 = self.sqrt_inv_alphas_cumprod_minus_one[t].reshape(-1, 1)
        pred_original_sample = s1 * x_t - s2 * model_output
        # Posterior mean of x_{t-1}.
        s1 = self.posterior_mean_coef1[t].reshape(-1, 1)
        s2 = self.posterior_mean_coef2[t].reshape(-1, 1)
        pred_prev_sample = s1 * pred_original_sample + s2 * x_t
        # Add posterior noise for every step except the final one (t == 0).
        variance = 0
        if t > 0:
            noise = torch.randn_like(model_output)
            variance = (self.get_variance(t) ** 0.5) * noise
        return pred_prev_sample + variance

    def add_noise(self, x_start, x_noise, timesteps):
        """Forward process: produce x_t from x_0 and noise at `timesteps`."""
        s1 = self.sqrt_alphas_cumprod[timesteps].reshape(-1, 1)
        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps].reshape(-1, 1)
        return s1 * x_start + s2 * x_noise

    def sample(self, num_samples=1000, num_timesteps=1000, device='cpu'):
        """Run the full reverse process from the shifted noise distribution.

        Returns (initial noise, final samples, per-step intermediates).
        BUGFIX: the intermediate buffer was sized by `num_timesteps` while the
        loop always runs `self.total_timesteps` steps — an IndexError (or
        garbage rows) whenever the two disagreed; size it by the number of
        steps actually taken.
        NOTE(review): `device` is currently unused; sampling runs on CPU tensors.
        """
        orig_sample = torch.randn(num_samples, 2) * self.spread
        orig_sample[:, 0] -= self.offset
        sample = orig_sample.clone()
        timesteps = list(range(self.total_timesteps))[::-1]
        intermediate_values = torch.empty(num_samples, self.total_timesteps, 2)
        for i, t in enumerate(timesteps):
            t = torch.full((num_samples,), t, dtype=torch.long)
            with torch.no_grad():
                residual = self.predict_noise(sample, t)
            sample = self.step(residual, t[0], sample)
            intermediate_values[:, i, :] = sample
        return orig_sample, sample, intermediate_values
# Train the model
def train(model, num_iterations=1000, batch_size=32, learning_rate=1e-4, device='cpu'):
    """Train a DiffusionModel on synthetic Gaussian data and save diagnostic plots.

    The "clean" data is a Gaussian shifted by +offset on dim 0; the noise
    targets are drawn from a Gaussian shifted by -offset on dim 0 (matching
    the sampling-time noise distribution in `model.sample`).

    NOTE(review): assumes a `plots/` directory exists for the loss figure.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98))
    loss_fn = nn.MSELoss()
    losses = []
    for i in tqdm(range(num_iterations)):
        optimizer.zero_grad()
        # Clean data: Gaussian shifted by +offset along dim 0.
        noise_free_data = torch.randn(batch_size, 2).to(device) * model.spread
        noise_free_data[:, 0] += model.offset
        # Randomly sample time steps uniformly, one per example.
        time_steps = torch.randint(0, model.total_timesteps, (noise_free_data.shape[0],))
        # Noise targets: Gaussian shifted by -offset along dim 0.
        noise = torch.randn(noise_free_data.shape) * model.spread
        noise[:, 0] -= model.offset
        noisy_data = model.add_noise(noise_free_data, noise, time_steps)
        predicted_noise = model.predict_noise(noisy_data, time_steps)
        loss = loss_fn(predicted_noise, noise)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        if i % 10000 == 0:
            # Periodically generate samples and plot them against true data.
            orig_sample, samples, _ = model.sample(num_samples=500, num_timesteps=1000, device=device)
            samples = samples.detach().cpu().numpy()
            plt.figure()
            data = torch.randn(1000, 2).cpu().numpy() * model.spread
            data[:, 0] += model.offset
            plt.scatter(orig_sample[:, 0], orig_sample[:, 1], label='Noise dist', alpha=0.5)
            plt.scatter(data[:, 0], data[:, 1], label='True Data', alpha=0.5)
            plt.scatter(samples[:, 0], samples[:, 1], label='Samples', alpha=0.5)
            plt.legend()
            plt.savefig(f'samples_{i}.png')
    # Plot the smoothed loss curve after training.
    # BUGFIX: the original overwrote `losses` (a list) with the result of
    # np.convolve — which would break subsequent losses.append — and a fixed
    # window of 5000 yields an empty array whenever num_iterations < 5000.
    # Smooth into a fresh variable with a window capped by the data length.
    plt.figure()
    window = min(5000, len(losses))
    if window > 0:
        smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
        plt.plot(smoothed)
    plt.savefig('plots/losses.png')
# (Pastebin footer boilerplate removed.)