daily pastebin goal
70%
SHARE
TWEET

Untitled

a guest Oct 19th, 2018 64 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2.  
  3. def generate(count):
  4.     X = np.random.random_integers(0, high=255, size=(count, 9))
  5.     Y = X.dot(np.array([1, 1, 1, 0, 0, 0, -1, -1, -1]))
  6.     Y[Y > 0] = 1
  7.     Y[Y < 0] = 0
  8.     return X, Y
  9.  
  10. def run():
  11.     m = 5000
  12.     c = 9
  13.     epochs = 50000
  14.  
  15.     # weight vector
  16.     w = np.random.randn(c).reshape(c, 1)
  17.  
  18.     # training loop
  19.     lr = 1e-3
  20.     print('\n\n{:^8s} | {:^8s} | {:^6s}'.format('epoch', 'loss', 'acc'))
  21.     print('----------------------------')
  22.     for t in range(epochs):
  23.         # get new training data
  24.         X, y = generate(m)
  25.         X = X / 255
  26.         y = y.reshape(m, 1) * 2 - 1
  27.  
  28.         # model function
  29.         h = X.dot(w)
  30.  
  31.         # compute loss
  32.         loss = np.square(h - y).mean()
  33.  
  34.         # compute accuracy
  35.         acc = (np.sign(h) == y).mean()
  36.  
  37.         if t % 5000 == 0:
  38.             print('{:>8d} | {:>8f} | {:>.4f}'.format(t, loss, acc))
  39.  
  40.         # no more to do
  41.         if acc >= 1:
  42.             print('\nStopping:\n{:>8d} | {:>8f} | {:>.4f}'.format(t, loss, acc))
  43.             break
  44.  
  45.         # grad + update
  46.         grad = 2 * (h - y)
  47.         w = w - lr * X.T.dot(grad) / m
  48.  
  49.     print('\nFinal W = \n\n{}'.format(w))
  50.  
  51. if __name__ == "__main__":
  52.     run()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top