import torch
from torch import nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(1)  # reproducible

# Hyperparameters
TIME_STEP = 10   # rnn time step
INPUT_SIZE = 1   # rnn input size
LR = 0.02        # learning rate


class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.hidden_size = 32

        self.rnn = nn.RNN(
            input_size=INPUT_SIZE,
            hidden_size=self.hidden_size,  # rnn hidden units
            num_layers=1,                  # number of rnn layers
            batch_first=True,              # input & output tensors have batch size as the first dimension, e.g. (batch, time_step, input_size)
            nonlinearity="tanh",
        )
        # self.out = nn.Linear(64, 1)

        self.layer_sizes = [self.hidden_size, 32, 32, 1]
        D_in, H1, H2, D_out = self.layer_sizes

        # print(D_in, H1, D_out)

        self.linear1 = nn.Linear(D_in, H1)
        self.activation1 = nn.Tanh()
        self.linear2 = nn.Linear(H1, H2)
        self.activation2 = nn.Tanh()
        self.linear3 = nn.Linear(H2, D_out)
        self.activation3 = nn.Tanh()

    def out(self, x):
        x = self.linear1(x)
        x = self.activation1(x)
        # print(x.size())

        x = self.linear2(x)
        x = self.activation2(x)
        # print(x.size())

        x = self.linear3(x)
        x = self.activation3(x)
        # print(x.size())

        return x

    def forward(self, x, h_state):
        # x       (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out   (batch, time_step, hidden_size)
        r_out, h_state = self.rnn(x, h_state)

        print("x", x.shape)
        print("h_state", h_state.shape)
        print("r_out", r_out.shape)

        # fold batch and time into one dimension so the Linear head
        # sees a (batch * time_step, hidden_size) matrix
        r_out = r_out.view(-1, self.hidden_size)

        y = self.out(r_out)

        print("output", y.shape)

        return y, h_state


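# Optional shape sanity check: check_shapes() is a hypothetical helper added
# for illustration; nothing below calls it. With batch_first=True and a single
# window of TIME_STEP points, r_out comes out as (1, TIME_STEP, 32), is folded
# to (TIME_STEP, 32), and the MLP head maps it to one scalar per time step.
def check_shapes():
    model = RNN()
    x = torch.zeros(1, TIME_STEP, INPUT_SIZE)    # (batch, time_step, input_size)
    y, h = model(x, None)                        # h_state=None -> zero initial state
    assert y.shape == (TIME_STEP, 1)             # one prediction per time step
    assert h.shape == (1, 1, model.hidden_size)  # (num_layers, batch, hidden_size)

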
def sample(x):
    # target curve; sin(x)/2 stays in [-0.5, 0.5], inside the range of the
    # Tanh output activation, so the network can represent it
    y = np.sin(x) / 2
    return y


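# How the data windows are shaped (a sketch, using the constants defined above):
# each window spans 0.8*pi of the curve, sampled at TIME_STEP points, and
# x_np[np.newaxis, :, np.newaxis] lifts a (TIME_STEP,) vector to the
# (batch=1, time_step, input_size=1) layout that batch_first=True expects:
#
#   steps = np.linspace(0, np.pi * 0.8, TIME_STEP, dtype=np.float32)  # (10,)
#   x = torch.from_numpy(steps[np.newaxis, :, np.newaxis])            # (1, 10, 1)

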
def train_realtime(model, loss_fn, optimizer, n_steps=40):
    h_state = None  # for initial hidden state

    for step in range(n_steps):
        start, end = step * np.pi * 0.8, (step + 1) * np.pi * 0.8  # time range
        steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)

        print(start, end)
        print(steps)

        x_np = steps  # float32 so torch.from_numpy() produces a FloatTensor
        y_np = sample(x_np)

        print("y", y_np.shape)
        # print(x_np[np.newaxis, :, np.newaxis])

        x = Variable(torch.from_numpy(x_np[np.newaxis, :, np.newaxis]))  # shape (batch, time_step, input_size)
        y = Variable(torch.from_numpy(y_np[np.newaxis, :, np.newaxis]))

        prediction, h_state = model(x, h_state)  # rnn output

        # !! next step is important !!
        h_state = Variable(h_state.data)  # repack the hidden state, break the connection from the last iteration

        loss = loss_fn(prediction, y)  # loss
        optimizer.zero_grad()          # clear gradients for this training step
        loss.backward()                # backpropagation, compute gradients
        optimizer.step()               # apply gradients

        # plotting
        plt.plot(steps, y_np.flatten(), 'r-')
        plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
        plt.draw()
        plt.pause(0.05)

    plt.ioff()
    plt.show()


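# Why the "repack the hidden state" line above matters: carrying h_state into
# the next window without breaking it out of the autograd graph would chain
# every window's graph together, and the next loss.backward() would typically
# fail with a "Trying to backward through the graph a second time" error.
# Repacking truncates backpropagation-through-time at the window boundary; in
# PyTorch 0.4+ the idiomatic spelling is h_state.detach() rather than
# Variable(h_state.data).

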
def test(model, n_steps=40):
    # mirrors train_realtime, minus the loss and optimizer update
    h_state = None  # for initial hidden state

    for step in range(n_steps):
        start, end = step * np.pi * 0.8, (step + 1) * np.pi * 0.8  # time range
        steps = np.linspace(start, end, TIME_STEP, dtype=np.float32)

        print(start, end)
        print(steps)

        x_np = steps  # float32 so torch.from_numpy() produces a FloatTensor
        y_np = sample(x_np)

        print("y", y_np.shape)
        # print(x_np[np.newaxis, :, np.newaxis])

        x = Variable(torch.from_numpy(x_np[np.newaxis, :, np.newaxis]))  # shape (batch, time_step, input_size)
        y = Variable(torch.from_numpy(y_np[np.newaxis, :, np.newaxis]))

        prediction, h_state = model(x, h_state)  # rnn output

        # !! next step is important !!
        h_state = Variable(h_state.data)  # repack the hidden state, break the connection from the last iteration

        # plotting
        plt.plot(steps, y_np.flatten(), 'r-')
        plt.plot(steps, prediction.data.numpy().flatten(), 'b-')
        plt.draw()
        plt.pause(0.05)

    plt.ioff()
    plt.show()


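# Inference note: test() still builds the autograd graph on every step. In
# PyTorch 0.4+, wrapping the forward pass in torch.no_grad() skips that
# bookkeeping during evaluation, e.g.
#
#   with torch.no_grad():
#       prediction, h_state = model(x, h_state)

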
if __name__ == "__main__":
    rnn = RNN()
    print(rnn)

    optimizer = torch.optim.Adam(rnn.parameters(), lr=LR)  # optimize all rnn parameters
    loss_fn = nn.MSELoss()

    plt.figure(1, figsize=(12, 5))
    plt.ion()  # continuously plot

    train_realtime(model=rnn,
                   loss_fn=loss_fn,
                   optimizer=optimizer)

    test(model=rnn)
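
# Expected behaviour when run (based on the plotting calls above): a live
# matplotlib window in which the red curve is the sin(x)/2 target and the blue
# curve is the network's prediction, which should track the target more
# closely as the 40 training windows go by.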