trunc pytorch
jokeris, Jan 5th, 2018
Python
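A PyTorch script (written against the pre-0.4 API, hence Variable) that trains a single-layer vanilla RNN with a linear readout on the series in data/mg17.csv (presumably Mackey-Glass with delay 17, going by the filename and the commented plot names). Training uses truncated backpropagation through time: the sequence is cut into fixed-length chunks, the hidden state is carried across chunk boundaries, and gradients are stopped at each boundary via detach_().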
import torch
import torch.nn as nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np

BATCH_SIZE = 1
INPUT_DIM = 1
OUTPUT_DIM = 1
DTYPE = np.float64
TIMESTEPS = 50  # truncation length for BPTT; set to np.nan to train on the full sequence


class Net(nn.Module):
    # Single vanilla-RNN stack followed by a linear hidden-to-output readout.
    def __init__(self, input_dim, hidden_dim, output_dim, hidden_layers):
        super(Net, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.hidden_layers = hidden_layers
        self.rnn = nn.RNN(input_dim, hidden_dim, hidden_layers)
        self.h2o = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, hidden):
        output, hidden = self.rnn(x, hidden)
        output = self.h2o(output)  # readout applied at every timestep
        return output, hidden

    def init_hidden(self, batch_size):
        # Fresh zero state of shape (num_layers, batch, hidden_dim), cast to match DTYPE.
        h_0 = Variable(torch.zeros(self.hidden_layers, batch_size, self.hidden_dim)).cuda()

        if DTYPE == np.float32:
            return h_0.float()
        else:
            return h_0.double()


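# Quick shape check: nn.RNN with the default batch_first=False expects input of
# shape (seq_len, batch, input_size). A minimal sketch (CPU, double precision to
# match DTYPE; not part of the training script):
#
#   net = Net(INPUT_DIM, 10, OUTPUT_DIM, 1).double()
#   x = Variable(torch.zeros(TIMESTEPS, BATCH_SIZE, INPUT_DIM).double())
#   h = Variable(torch.zeros(1, BATCH_SIZE, 10).double())
#   out, h = net(x, h)
#   # out: (TIMESTEPS, BATCH_SIZE, OUTPUT_DIM), h: (1, BATCH_SIZE, 10)
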
def weights_init(m):
    # Xavier for input weights, orthogonal for recurrent weights, zero biases.
    # (In PyTorch >= 0.4 these initializers are the in-place variants
    # xavier_uniform_, orthogonal_ and constant_.)
    if isinstance(m, nn.RNN):
        nn.init.xavier_uniform(m.weight_ih_l0.data)
        nn.init.orthogonal(m.weight_hh_l0.data)
        nn.init.constant(m.bias_ih_l0.data, 0)
        nn.init.constant(m.bias_hh_l0.data, 0)
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform(m.weight.data)
        nn.init.constant(m.bias.data, 0)


def pad_sequence(sequence, timesteps):
    # Zero-pad the sequence so its length is a multiple of `timesteps`.
    # Returns the (possibly padded) sequence and the number of valid
    # (non-padded) steps in the final chunk, which the training loop uses
    # to slice off the padding.
    remainder = sequence.shape[0] % timesteps
    if remainder == 0:
        return sequence, timesteps
    pad = timesteps - remainder
    return np.pad(sequence, ((0, pad), (0, 0)), 'constant'), remainder


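# Example: a hypothetical 4030-step sequence with timesteps=50 is padded to
# 4050 (81 chunks of 50), and the last chunk holds 30 valid steps:
#   padded, last = pad_sequence(np.zeros((4030, 1)), 50)
#   # padded.shape == (4050, 1), last == 30
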
# Column 0 of mg17.csv is the input series, column 1 the target series.
data = np.loadtxt('data/mg17.csv', delimiter=',', dtype=DTYPE)
X_data = data[:, [0]]
Y_data = data[:, [1]]
trX_data = X_data[:4000, :]
trY_data = Y_data[:4000, :]
# Validation: steps 4000-4999 as one long (1000, 1, 1) sequence.
vlX = torch.from_numpy(np.expand_dims(X_data[4000:5000, :], axis=1)).cuda()
vlY = torch.from_numpy(np.expand_dims(Y_data[4000:5000, :], axis=1)).cuda()

if np.isnan(TIMESTEPS):
    # No truncation: the whole training series is a single chunk.
    trX = torch.from_numpy(np.expand_dims(trX_data, axis=1)).cuda()
    trY = torch.from_numpy(np.expand_dims(trY_data, axis=1)).cuda()
    tr_seq_lengths = [trX.shape[0]]
else:
    # Cut the series into consecutive chunks of TIMESTEPS steps. Reshaping to
    # (-1, TIMESTEPS, dim) and then transposing keeps each chunk contiguous in
    # time; reshaping straight to (TIMESTEPS, -1, dim) would interleave steps
    # across chunks and break the hidden-state carry-over below.
    trX, _ = pad_sequence(trX_data, TIMESTEPS)
    trX = np.ascontiguousarray(np.reshape(trX, (-1, TIMESTEPS, INPUT_DIM)).transpose(1, 0, 2))
    trX = torch.from_numpy(trX).cuda()
    trY, tr_last = pad_sequence(trY_data, TIMESTEPS)
    trY = np.ascontiguousarray(np.reshape(trY, (-1, TIMESTEPS, OUTPUT_DIM)).transpose(1, 0, 2))
    trY = torch.from_numpy(trY).cuda()
    tr_seq_lengths = np.full((trX.shape[1],), TIMESTEPS, dtype=np.int64)
    tr_seq_lengths[-1] = tr_last  # valid steps in the final (possibly padded) chunk


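# With TIMESTEPS = 50 and 4000 training rows, trX and trY have shape
# (50, 80, 1): 80 consecutive 50-step chunks. The loop below feeds them in
# order, carrying the hidden state across chunk boundaries and detaching it
# so that gradients stop at each boundary (truncated BPTT).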
loss_fcn = nn.MSELoss()
for r in range(5):  # 5 independent restarts
    model = Net(INPUT_DIM, 10, OUTPUT_DIM, 1).cuda()
    if DTYPE == np.float32:
        model = model.float()
    else:
        model = model.double()
    model.apply(weights_init)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, eps=2e-16)

    epochs = 2000
    tr_loss_plt = np.zeros((epochs,))
    vl_loss_plt = np.zeros((epochs,))

    for e in range(epochs):
        hidden = model.init_hidden(BATCH_SIZE)

        model.train()
        tot_loss = 0
        for seq in range(trX.shape[1]):
            x = Variable(trX[:tr_seq_lengths[seq], [seq], :])
            y = Variable(trY[:tr_seq_lengths[seq], [seq], :])
            # Truncated BPTT: keep the hidden state's value across chunks,
            # but cut the graph so backprop stops at the chunk boundary.
            hidden.detach_()
            model.zero_grad()
            output, hidden = model(x, hidden)
            loss = loss_fcn(output, y)
            loss.backward()
            optimizer.step()
            tot_loss += loss.cpu().data
        tot_loss /= trX.shape[1]
        tr_loss_plt[e] = tot_loss

        # Validation on the full 1000-step sequence from a fresh zero state.
        model.eval()
        hidden = model.init_hidden(BATCH_SIZE)
        x = Variable(vlX)
        y = Variable(vlY)
        output, _ = model(x, hidden)
        loss = loss_fcn(output, y)
        vl_loss_plt[e] = loss.cpu().data

        # print("Epoch", e + 1, "TR:", tr_loss_plt[e], "VL:", vl_loss_plt[e])

    # plt.clf()
    # plt.plot(tr_loss_plt)
    # plt.plot(vl_loss_plt)
    # plt.xlabel("epoch")
    # plt.ylabel("loss")
    # plt.legend(["TR", "VL"])
    # plt.savefig("pytorch-mg-rnn-trunk("+str(r)+").png")
    print("TR:", tr_loss_plt[-1], "VL:", vl_loss_plt[-1])
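
# Note: this script targets the pre-0.4 API. On PyTorch >= 0.4 the Variable
# wrapper is redundant (tensors track gradients directly), and the same
# truncation step is usually written out-of-place:
#   hidden = hidden.detach()
#   optimizer.zero_grad()
#   output, hidden = model(x, hidden)
#   loss = loss_fcn(output, y)
#   loss.backward()
#   optimizer.step()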