Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from scipy.io import loadmat
- import matplotlib
- import matplotlib.pyplot as plt
- import numpy as np
def train(mnist_input, mnist_output, *, iterations=2000, hidden_size=64,
          learning_rate=0.00001, weights1_path="wagi1.txt",
          weights2_path="wagi2.txt"):
    """Train a one-hidden-layer network with full-batch gradient descent.

    The network computes ``softmax(wagi2 @ sigmoid(wagi1 @ x))`` and is trained
    with the softmax cross-entropy gradient. After training, each weight
    matrix is written to a text file: one matrix row per line, every value
    followed by a single space (the layout the weight-loading code reads).

    Args:
        mnist_input: Inputs of shape (n_features, n_samples), one sample per
            column (e.g. the MNIST pixels scaled by 1/255 at module level).
        mnist_output: One-hot targets of shape (n_classes, n_samples).
        iterations: Number of full-batch gradient-descent steps.
        hidden_size: Number of sigmoid units in the hidden layer.
        learning_rate: Step size applied to the summed (not averaged) batch
            gradient.
        weights1_path: File the hidden-layer weights are written to.
        weights2_path: File the output-layer weights are written to.

    Returns:
        Tuple ``(wagi1, wagi2)`` with shapes (hidden_size, n_features) and
        (n_classes, hidden_size).
    """
    n_features = mnist_input.shape[0]
    n_classes = mnist_output.shape[0]

    # Initialise once, outside the loop, so the updates accumulate across
    # iterations instead of being discarded by a fresh random draw.
    wagi1 = np.random.randn(hidden_size, n_features)
    wagi2 = np.random.randn(n_classes, hidden_size)

    for _ in range(iterations):
        # Forward pass.
        po1 = sigmoid(np.dot(wagi1, mnist_input))
        przed2 = np.dot(wagi2, po1)
        # Shift each column by its max before exponentiating: softmax is
        # invariant to the shift and np.exp can then no longer overflow.
        exp2 = np.exp(przed2 - przed2.max(axis=0, keepdims=True))
        po2 = exp2 / exp2.sum(axis=0, keepdims=True)

        # Backward pass. For softmax outputs with cross-entropy loss the
        # gradient w.r.t. the output pre-activations is (prediction - target).
        d_k = po2 - mnist_output
        # Uses wagi2 *before* its update below, as backprop requires.
        d_j = np.dot(wagi2.T, d_k) * po1 * (1 - po1)

        wagi2 -= learning_rate * np.dot(d_k, po1.T)
        wagi1 -= learning_rate * np.dot(d_j, mnist_input.T)

    # Persist both matrices; `with` guarantees the files are flushed/closed.
    for path, matrix in ((weights1_path, wagi1), (weights2_path, wagi2)):
        with open(path, "w", encoding="utf-8") as f:
            for row in matrix:
                f.write("".join(str(value) + " " for value in row) + "\n")

    return wagi1, wagi2
def sigmoid(x):
    """Return the element-wise logistic function 1 / (1 + e**-x).

    Works on scalars and NumPy arrays alike; the result has the shape of x.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator
# Load MNIST from the .mat file. Assumes the mnist-original.mat layout
# (data: pixels x samples, label: 1 x samples float digit labels, first 60000
# samples = training set) -- confirm if the data file changes.
mnist = loadmat("./mnist-original.mat")
mnist_input = mnist["data"].T / 255  # one sample per row, pixels scaled by 1/255

# One-hot encode the labels into a (10, n_samples) int matrix. astype(int)
# equals floor() for the non-negative digit labels; np.math.floor was used
# before, but the np.math alias no longer exists in NumPy 2.0.
labels = mnist["label"][0].astype(int)
mnist_output = np.eye(10, dtype=int)[labels].T

# Split into training (first 60000 samples) and test (the rest) sets, with one
# sample per column as the network code expects.
mnist_input_train = mnist_input[:60000].T
mnist_input_test = mnist_input[60000:].T
mnist_output_train = mnist_output[:, :60000]
mnist_output_test = mnist_output[:, 60000:]

# Shuffle the training samples with one permutation applied to both inputs and
# targets so that they stay aligned.
index = np.random.permutation(mnist_input_train.shape[1])
mnist_input_train = mnist_input_train[:, index]
mnist_output_train = mnist_output_train[:, index]
# Uncomment to retrain; train() writes wagi1.txt / wagi2.txt, read back below.
# train(mnist_input_train,mnist_output_train)

# Load the trained weights: whitespace-separated text, one matrix row per line.
# ndmin=2 keeps the result 2-D even for a single-row file.
wagi1 = np.loadtxt("wagi1.txt", ndmin=2)
wagi2 = np.loadtxt("wagi2.txt", ndmin=2)

# Forward pass over the whole test set. The output layer applies sigmoid here
# instead of the softmax used in training; both preserve the ordering within
# each column of pre-activations, so the argmax (predicted digit) is the same.
wyjsciowe = sigmoid(np.dot(mnist_input_test.T, wagi1.T))  # (n_test, hidden)
wyjsciowe2 = sigmoid(np.dot(wagi2, wyjsciowe.T))           # (n_classes, n_test)

# Show 10 random test images together with the predicted digit. The index
# range comes from the actual test-set size rather than a fixed constant.
n_test = mnist_input_test.shape[1]
for _ in range(10):
    value = np.random.randint(0, n_test)
    print("Rozpoznana liczba to", np.argmax(wyjsciowe2[:, value]))
    plt.imshow(mnist_input_test[:, value].reshape(28, 28), cmap=matplotlib.cm.binary)
    plt.axis("off")
    plt.show()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement