Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #Numpy for matrix math and matplotlib for plotting loss
- import numpy as np
- import matplotlib.pyplot as plt
def forward(x, w1, w2):
    """Run one forward pass of the two-layer ReLU network.

    Args:
        x: input batch, shape (BS, D_in).
        w1: first-layer weights, shape (D_in, H).
        w2: second-layer weights, shape (H, D_out).

    Returns:
        A ``(yhat, hidden)`` pair: the predictions, shape (BS, D_out), and
        the post-ReLU hidden activations, shape (BS, H). The activations are
        returned so the backward pass can reuse them.
    """
    # Hidden layer: linear map, then ReLU clamps negative pre-activations to 0.
    hidden = np.maximum(np.matmul(x, w1), 0)
    # Output layer is purely linear (no activation).
    yhat = np.matmul(hidden, w2)
    return yhat, hidden
def backward(loss, hidden, x, y, yhat, w2=None):
    """Backpropagate the summed squared-error loss through the two-layer ReLU net.

    The loss being differentiated is ``sum((yhat - y) ** 2)``. Its gradient
    with respect to ``yhat`` is ``2 * (yhat - y)``, and that gradient is
    pushed back through the linear output layer and the ReLU hidden layer.

    Args:
        loss: unused; accepted only so existing call sites keep working.
        hidden: post-ReLU hidden activations from the forward pass, shape (BS, H).
        x: input batch, shape (BS, D_in).
        y: targets, shape (BS, D_out).
        yhat: predictions from the forward pass, shape (BS, D_out).
        w2: second-layer weights, shape (H, D_out). If omitted, the
            module-level ``w2`` is used, which keeps older call sites that
            rely on that global working.

    Returns:
        ``(grad_w1, grad_w2)`` with shapes (D_in, H) and (H, D_out).
    """
    if w2 is None:
        # Backward compatibility: earlier versions silently read the global w2.
        w2 = globals()["w2"]
    # (BS, D_out): d(sum of squared errors) / d(yhat)
    grad_to_yhat = 2 * (yhat - y)
    # (H, BS) @ (BS, D_out) -> (H, D_out)
    grad_w2 = np.matmul(hidden.T, grad_to_yhat)
    # (BS, D_out) @ (D_out, H) -> (BS, H)
    grad_hidden = np.matmul(grad_to_yhat, w2.T)
    # ReLU derivative: units clamped to zero in the forward pass pass back no
    # gradient. Since hidden = max(raw, 0), hidden > 0 exactly where raw > 0.
    grad_hidden = grad_hidden * (hidden > 0)
    # (D_in, BS) @ (BS, H) -> (D_in, H)
    grad_w1 = np.matmul(x.T, grad_hidden)
    return grad_w1, grad_w2
# Problem sizes: N samples per batch, D_in input features,
# H hidden units, D_out output targets.
N, D_in, H, D_out = 64, 1000, 100, 10

# Synthetic regression data drawn from a standard normal distribution.
# The draw order (x, y, then w1, w2) is kept fixed so the random stream
# is consumed in the same sequence.
x, y = (np.random.randn(N, width) for width in (D_in, D_out))

# Standard-normal initial weights for the hidden and output layers.
w1, w2 = np.random.randn(D_in, H), np.random.randn(H, D_out)

# Fixed step size for plain gradient descent.
learning_rate = 1e-5
# Loss history; one value is appended per training iteration.
losses = []
# Full-batch gradient descent: every step uses all N samples at once.
for t in range(500):
    # Forward pass; `hidden` is kept because backward() needs the activations.
    yhat, hidden = forward(x, w1, w2)
    # Element-wise squared error between predictions and targets.
    loss_matrix = (yhat - y) ** 2
    # Gradients of the summed squared error w.r.t. both weight matrices.
    grad_w1, grad_w2 = backward(loss_matrix, hidden, x, y, yhat)
    # Gradient-descent step: move each weight matrix against its gradient.
    w2 = w2 - learning_rate * grad_w2
    w1 = w1 - learning_rate * grad_w1
    # Tracked metric: for each sample, the L2 norm of its row of squared
    # errors, averaged over the batch.
    # NOTE(review): despite the name this is not an RMS error (the errors are
    # effectively raised to the 4th power before the square root); an RMS would
    # be np.sqrt(loss_matrix.mean()). Confirm which metric is intended.
    loss_rms = np.sqrt((loss_matrix ** 2).sum(axis=1)).mean()
    losses.append(loss_rms)
# Print the full per-iteration loss history once, after training finishes.
print(losses)

# Visualize the tracked loss for every training iteration.
plt.plot(losses)
plt.xlabel('iteration')
# Use a real float format for the title. Slicing str(value) to 5 characters
# mangles values such as 12345.678 ("12345") or 1.5e-07 ("1.5e-").
final_loss = f"{losses[-1]:.4g}" if losses else "n/a"
plt.title('Loss for simple model\napproaches ' + final_loss)
plt.savefig('model_1.jpg')
plt.show()
Add Comment
Please, Sign In to add comment