import torch

sigmoid = torch.nn.Sigmoid()
tanh = torch.nn.Tanh()
# Scalar toy parameters: one input weight, one recurrent weight, and one bias
# per LSTM gate, plus a linear output layer and fixed initial states/inputs.
W_fx, W_fh, b_f = -1, 2, -3   # forget gate
W_ix, W_ih, b_i = 1, -2, 3    # input gate
W_gx, W_gh, b_g = -1, 2, -3   # cell (candidate) gate
W_ox, W_oh, b_o = 1, -2, 3    # output gate
W_y, b_y = 1, 2               # linear readout
h_bar, c_bar = 1, 1           # initial hidden and cell state
x0, x1 = 1, 2                 # inputs for the two time steps
class SimpleLSTM(torch.nn.Module):
    """A single-unit LSTM written out gate by gate with scalar parameters."""

    def __init__(self):
        super().__init__()
        self.W_fx = torch.nn.Parameter(torch.Tensor([W_fx]))
        self.W_fh = torch.nn.Parameter(torch.Tensor([W_fh]))
        self.b_f = torch.nn.Parameter(torch.Tensor([b_f]))
        self.W_ix = torch.nn.Parameter(torch.Tensor([W_ix]))
        self.W_ih = torch.nn.Parameter(torch.Tensor([W_ih]))
        self.b_i = torch.nn.Parameter(torch.Tensor([b_i]))
        self.W_gx = torch.nn.Parameter(torch.Tensor([W_gx]))
        self.W_gh = torch.nn.Parameter(torch.Tensor([W_gh]))
        self.b_g = torch.nn.Parameter(torch.Tensor([b_g]))
        self.W_ox = torch.nn.Parameter(torch.Tensor([W_ox]))
        self.W_oh = torch.nn.Parameter(torch.Tensor([W_oh]))
        self.b_o = torch.nn.Parameter(torch.Tensor([b_o]))
        self.W_y = torch.nn.Parameter(torch.Tensor([W_y]))
        self.b_y = torch.nn.Parameter(torch.Tensor([b_y]))

    def forward(self, x, prev_h, prev_c):
        # Standard LSTM cell equations, one scalar gate at a time.
        z_f = self.W_fx * x + self.W_fh * prev_h + self.b_f
        forget_gate = sigmoid(z_f)        # f_t: how much of prev_c to keep
        z_i = self.W_ix * x + self.W_ih * prev_h + self.b_i
        input_gate = sigmoid(z_i)         # i_t: how much new content to write
        z_g = self.W_gx * x + self.W_gh * prev_h + self.b_g
        cell_gate = tanh(z_g)             # g_t: candidate cell content
        z_o = self.W_ox * x + self.W_oh * prev_h + self.b_o
        output_gate = sigmoid(z_o)        # o_t: how much of the cell to expose
        c = forget_gate * prev_c + input_gate * cell_gate
        h = tanh(c) * output_gate
        y = self.W_y * h + self.b_y       # linear readout on top of h
        return y, h, c
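# Optional sanity check, as a sketch: the hand-rolled cell above should agree
# with torch.nn.LSTMCell loaded with the same scalar weights. LSTMCell stacks
# its gate weights in (input, forget, cell, output) row order and splits each
# bias into bias_ih + bias_hh, so the biases go into bias_ih and bias_hh is
# zeroed. The underscore-prefixed names below are illustrative, not from the
# original script.
_ref = torch.nn.LSTMCell(input_size=1, hidden_size=1)
with torch.no_grad():
    _ref.weight_ih.copy_(torch.tensor([[W_ix], [W_fx], [W_gx], [W_ox]], dtype=torch.float))
    _ref.weight_hh.copy_(torch.tensor([[W_ih], [W_fh], [W_gh], [W_oh]], dtype=torch.float))
    _ref.bias_ih.copy_(torch.tensor([b_i, b_f, b_g, b_o], dtype=torch.float))
    _ref.bias_hh.zero_()
    _h_ref, _c_ref = _ref(torch.tensor([[float(x0)]]),
                          (torch.tensor([[float(h_bar)]]), torch.tensor([[float(c_bar)]])))
# _h_ref and _c_ref should equal the h, c that SimpleLSTM returns for (x0, h_bar, c_bar).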
mse = torch.nn.MSELoss()
lstm = SimpleLSTM()

print("------- t0 --------")
y0, h0, c0 = lstm(x0, h_bar, c_bar)
loss0 = mse(y0, torch.Tensor([5]))   # target is the constant 5 at every step
print("y,h,c: ", y0.item(), " -", h0.item(), " -", c0.item())
print("loss: ", loss0.item())

print("------- t1 --------")
y1, h1, c1 = lstm(x1, h0, c0)        # feed the previous step's h and c back in
loss1 = mse(y1, torch.Tensor([5]))
print("y,h,c: ", y1.item(), " -", h1.item(), " -", c1.item())
print("loss: ", loss1.item())

print()
print("---------------")
loss = loss0 + loss1                 # summing keeps both steps in one graph
print("total loss: ", loss.item())
loss.backward()                      # autograd does backpropagation through time
print("grad wrt recurrent w in forget gate:")
print(lstm.W_fh.grad)
print()
- print("######################## manual version #######################")
- print()
- def tanh_prime(x):
- return 1-(tanh(x)**2)
- def sigmoid_prime(x):
- return sigmoid(x)*(1-sigmoid(x))
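# Optional check: the hand-written derivatives should match autograd at a
# sample point (scalar inputs are all this toy uses; 0.5 is arbitrary).
_z = torch.tensor([0.5], requires_grad=True)
(_g,) = torch.autograd.grad(tanh(_z).sum(), _z)
assert torch.allclose(_g, tanh_prime(_z)), "tanh_prime disagrees with autograd"
_z = torch.tensor([0.5], requires_grad=True)
(_g,) = torch.autograd.grad(sigmoid(_z).sum(), _z)
assert torch.allclose(_g, sigmoid_prime(_z)), "sigmoid_prime disagrees with autograd"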
x, h, c = x0, h_bar, c_bar

print("------- t0 --------")
# Recompute the t0 forward pass with plain tensors (no autograd graph needed).
zF0 = torch.Tensor([W_fx * x + W_fh * h + b_f])
F0 = sigmoid(zF0)
zI0 = torch.Tensor([W_ix * x + W_ih * h + b_i])
I0 = sigmoid(zI0)
zG0 = torch.Tensor([W_gx * x + W_gh * h + b_g])
G0 = tanh(zG0)
zO0 = torch.Tensor([W_ox * x + W_oh * h + b_o])
O0 = sigmoid(zO0)
c0 = F0 * c + I0 * G0
h0 = tanh(c0) * O0
y0 = W_y * h0 + b_y
loss0 = (torch.Tensor([5]) - y0) ** 2
print("y,h,c: ", y0.item(), " -", h0.item(), " -", c0.item())
print("loss: ", loss0.item())
# h and c are already tensors from t0, so no torch.Tensor wrapping is needed
# here (wrapping a tensor in a list would also mangle the shape).
x, h, c = x1, h0, c0

print("------- t1 --------")
zF1 = W_fx * x + W_fh * h + b_f
F1 = sigmoid(zF1)
zI1 = W_ix * x + W_ih * h + b_i
I1 = sigmoid(zI1)
zG1 = W_gx * x + W_gh * h + b_g
G1 = tanh(zG1)
zO1 = W_ox * x + W_oh * h + b_o
O1 = sigmoid(zO1)
c1 = F1 * c + I1 * G1
h1 = tanh(c1) * O1
y1 = W_y * h1 + b_y
loss1 = (torch.Tensor([5]) - y1) ** 2
print("y,h,c: ", y1.item(), " -", h1.item(), " -", c1.item())
print("loss: ", loss1.item())
# dL0/dW_fh: at t0 the only path from W_fh to the loss runs through the forget
# gate, whose recurrent input h_bar is a constant.
dL_dy0 = 2 * (y0 - torch.Tensor([5]))
dc0_dWfh = c_bar * sigmoid_prime(zF0) * h_bar   # c0 = F0*c_bar + I0*G0
dh0_dWfh = O0 * tanh_prime(c0) * dc0_dWfh       # h0 = tanh(c0)*O0
d_L_0 = dL_dy0 * W_y * dh0_dWfh

# dL1/dW_fh: at t1 W_fh acts directly in the forget gate and indirectly through
# h0 and c0 carried over from t0 (backpropagation through time).
dL_dy1 = 2 * (y1 - torch.Tensor([5]))
dF1_dWfh = sigmoid_prime(zF1) * (h0 + W_fh * dh0_dWfh)  # direct + via h0
dI1_dWfh = sigmoid_prime(zI1) * W_ih * dh0_dWfh
dG1_dWfh = tanh_prime(zG1) * W_gh * dh0_dWfh
dO1_dWfh = sigmoid_prime(zO1) * W_oh * dh0_dWfh
dc1_dWfh = dF1_dWfh * c0 + F1 * dc0_dWfh + dI1_dWfh * G1 + I1 * dG1_dWfh
dh1_dWfh = O1 * tanh_prime(c1) * dc1_dWfh + tanh(c1) * dO1_dWfh
d_L_1 = dL_dy1 * W_y * dh1_dWfh

w_fh_grad = d_L_0 + d_L_1
print()
print("---------------")
loss = loss0 + loss1
print("total loss: ", loss.item())
print("grad wrt recurrent w in forget gate:")
print(w_fh_grad)
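# Optional cross-check: the hand-derived gradient should agree with the value
# autograd computed above (same inputs, same loss; a mismatch would mean a
# slip in the chain rule).
print("manual grad matches autograd:", torch.allclose(w_fh_grad, lstm.W_fh.grad))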