Guest User

Untitled

a guest
Apr 23rd, 2025
57
0
270 days
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.41 KB | None | 0 0
  1. import torch
  2.  
  3. sigmoid = torch.nn.Sigmoid()
  4. tanh = torch.nn.Tanh()
  5.  
  6. W_fx,W_fh,b_f =-1, 2,-3
  7. W_ix,W_ih,b_i = 1,-2, 3
  8. W_gx,W_gh,b_g =-1, 2,-3
  9. W_ox,W_oh,b_o = 1,-2, 3
  10. W_y,b_y = 1,2
  11.  
  12. h_bar,c_bar = 1,1
  13. x0,x1 = 1,2
  14.  
  15. class SimpleLSTM(torch.nn.Module):
  16.    
  17.     def __init__(self):
  18.         super().__init__()
  19.  
  20.         self.W_fx = torch.nn.Parameter(torch.Tensor([W_fx]))
  21.         self.W_fh = torch.nn.Parameter(torch.Tensor([W_fh]))
  22.         self.b_f = torch.nn.Parameter(torch.Tensor([b_f]))
  23.  
  24.         self.W_ix = torch.nn.Parameter(torch.Tensor([W_ix]))
  25.         self.W_ih = torch.nn.Parameter(torch.Tensor([W_ih]))
  26.         self.b_i = torch.nn.Parameter(torch.Tensor([b_i]))
  27.  
  28.         self.W_gx = torch.nn.Parameter(torch.Tensor([W_gx]))
  29.         self.W_gh = torch.nn.Parameter(torch.Tensor([W_gh]))
  30.         self.b_g = torch.nn.Parameter(torch.Tensor([b_g]))
  31.  
  32.         self.W_ox = torch.nn.Parameter(torch.Tensor([W_ox]))
  33.         self.W_oh = torch.nn.Parameter(torch.Tensor([W_oh]))
  34.         self.b_o = torch.nn.Parameter(torch.Tensor([b_o]))
  35.  
  36.         self.W_y = torch.nn.Parameter(torch.Tensor([W_y]))
  37.         self.b_y = torch.nn.Parameter(torch.Tensor([b_y]))
  38.  
  39.     def forward(self, x, prev_h, prev_c):
  40.  
  41.         z_f = self.W_fx*x + self.W_fh*prev_h + self.b_f
  42.         forget_gate = sigmoid(z_f)
  43.  
  44.         z_i = self.W_ix*x + self.W_ih*prev_h + self.b_i
  45.         input_gate = sigmoid(z_i)
  46.  
  47.         z_g = self.W_gx*x + self.W_gh*prev_h + self.b_g
  48.         cell_gate = tanh(z_g)
  49.  
  50.         z_o = self.W_ox*x + self.W_oh*prev_h + self.b_o
  51.         output_gate = sigmoid(z_o)
  52.  
  53.         c = forget_gate * prev_c + input_gate * cell_gate
  54.         h = tanh(c) * output_gate
  55.  
  56.         y = self.W_y*h + self.b_y
  57.  
  58.         return y,h,c
  59.  
  60. mse = torch.nn.MSELoss()
  61. lstm = SimpleLSTM()
  62.  
  63. print("------- t0 --------")
  64. y0,h0,c0 = lstm(x0,h_bar,c_bar)
  65. loss0 = mse(torch.Tensor([5]),y0)
  66. print("y,h,c: ",y0.item()," -",h0.item()," -",c0.item())
  67. print("loss: ",loss0.item())
  68.  
  69. print("------- t1 --------")
  70. y1,h1,c1 = lstm(x1,h0,c0)
  71. loss1 = mse(torch.Tensor([5]),y1)
  72. print("y,h,c: ",y1.item()," -",h1.item()," -",c1.item())
  73. print("loss: ",loss1.item())
  74.  
  75. print()
  76. print("---------------")
  77. loss = loss0 + loss1
  78. print("total loss: ", loss)
  79. loss.backward()
  80. print("grad wrt recurrent w in forget gate:")
  81. print(lstm.W_fh.grad)
  82.  
  83. print()
  84. print("########################      manual version     #######################")
  85. print()
  86.  
  87. def tanh_prime(x):
  88.     return 1-(tanh(x)**2)
  89.  
  90. def sigmoid_prime(x):
  91.     return sigmoid(x)*(1-sigmoid(x))
  92.  
  93. x,h,c = x0,h_bar,c_bar
  94.  
  95. print("------- t0 --------")
  96. zF0 = torch.Tensor([W_fx*x + W_fh*h + b_f])
  97. F0 = sigmoid(zF0)
  98. zI0 = torch.Tensor([W_ix*x + W_ih*h + b_i])
  99. I0 = sigmoid(zI0)
  100. zG0 = torch.Tensor([W_gx*x + W_gh*h + b_g])
  101. G0 = tanh(zG0)
  102. zO0 = torch.Tensor([W_ox*x + W_oh*h + b_o])
  103. O0 = sigmoid(zO0)
  104.  
  105. c0 = F0*c + I0*G0
  106. h0 = tanh(c0)*O0
  107. y0 = W_y*h0 + b_y
  108.  
  109. loss0 = (torch.Tensor([5])-y0)**2
  110. print("y,h,c: ",y0.item()," -",h0.item()," -",c0.item())
  111. print("loss: ",loss0.item())
  112.  
  113. x,h,c = x1,h0,c0
  114.  
  115. print("------- t1 --------")
  116. zF1 = torch.Tensor([W_fx*x + W_fh*h + b_f])
  117. F1 = sigmoid(zF1)
  118. zI1 = torch.Tensor([W_ix*x + W_ih*h + b_i])
  119. I1 = sigmoid(zI1)
  120. zG1 = torch.Tensor([W_gx*x + W_gh*h + b_g])
  121. G1 = tanh(zG1)
  122. zO1 = torch.Tensor([W_ox*x + W_oh*h + b_o])
  123. O1 = sigmoid(zO1)
  124.  
  125. c1 = F1*c + I1*G1
  126. h1 = tanh(c1)*O1
  127. y1 = W_y*h1 + b_y
  128.  
  129. loss1 = (torch.Tensor([5])-y1)**2
  130. print("y,h,c: ",y1.item()," -",h1.item()," -",c1.item())
  131. print("loss: ",loss1.item())
  132.  
  133. dL_dy0 = 2*(y0-torch.Tensor([5]))
  134. d_L_0 = dL_dy0 * W_y * O0 * tanh_prime(c0) * c_bar * sigmoid_prime(zF0) * h_bar
  135.  
  136. dL_dy1 = 2*(y1-torch.Tensor([5]))
  137. d_L_1 = dL_dy1 * W_y * ( O1 * tanh_prime(c1) * \
  138.                             ( c0 * sigmoid_prime(zF1) * (h0 + W_fh * (O0 * tanh_prime(c0) * c_bar * sigmoid_prime(zF0) * h_bar)) \
  139.                             + F1 * (c_bar * sigmoid_prime(zF0) * h_bar) \
  140.                             + G1 * sigmoid_prime(zI1) * W_ih * (O0 * tanh_prime(c0) * c_bar * sigmoid_prime(zF0) * h_bar) \
  141.                             + I1 * tanh_prime(zG1) * W_gh * (O0 * tanh_prime(c0) * c_bar * sigmoid_prime(zF0) * h_bar)) \
  142.                     + tanh(c1) * sigmoid_prime(zO1) * W_oh * (O0 * tanh_prime(c0) * c_bar * sigmoid_prime(zF0) * h_bar))
  143.  
  144. w_fh_grad = d_L_0 + d_L_1
  145.  
  146. print()
  147. print("---------------")
  148. loss = loss0 + loss1
  149. print("total loss: ", loss)
  150. print("grad wrt recurrent w in forget gate:")
  151. print(w_fh_grad)
  152.  
Advertisement
Add Comment
Please, Sign In to add comment