Advertisement
UF6

Neural Network Differential Equations

UF6
Oct 20th, 2024
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 1.21 KB | Source Code | 0 0
  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. import matplotlib.pyplot as plt
  5.  
  6. # Neural network model
  7. class Net(nn.Module):
  8.     def __init__(self):
  9.         super().__init__()
  10.         self.fc = nn.Sequential(
  11.             nn.Linear(1, 50),
  12.             nn.Tanh(),
  13.             nn.Linear(50, 1)
  14.         )
  15.  
  16.     def forward(self, x):
  17.         return self.fc(x)
  18.  
  19. # Differential equation dy/dx = -y (loss function)
  20. def ode_loss(x, model):
  21.     y = model(x)
  22.     dy_dx = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(x), create_graph=True)[0]
  23.     return torch.mean((dy_dx + y)**2)
  24.  
  25. x_train = torch.linspace(0, 2, 100).view(-1, 1).requires_grad_(True)
  26. model = Net()
  27. optimizer = optim.Adam(model.parameters(), lr=0.01)
  28.  
  29. for epoch in range(2000):
  30.     optimizer.zero_grad()
  31.     loss = ode_loss(x_train, model)
  32.     loss.backward()
  33.     optimizer.step()
  34.     if epoch % 100 == 0:
  35.         print(f'Epoch {epoch}, Loss: {loss.item()}')
  36.  
  37. # Plot result
  38. with torch.no_grad():
  39.     x_test = torch.linspace(0, 2, 100).view(-1, 1)
  40.     y_pred = model(x_test)
  41.     plt.plot(x_test, y_pred, label="NN solution")
  42.     plt.plot(x_test, torch.exp(-x_test), label="Exact solution")
  43.     plt.legend()
  44.     plt.show()
  45.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement