Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- # Norbert Batfai, 27 Nov 2016
- # Some modifications and additions to the original code:
- # https://github.com/tensorflow/tensorflow/blob/r0.11/tensorflow/examples/tutorials/mnist/mnist_softmax.py
- # See also http://progpater.blog.hu/2016/11/13/hello_samu_a_tensorflow-bol
- # ==============================================================================
- """A very simple MNIST classifier.
- See extensive documentation at
- http://tensorflow.org/tutorials/mnist/beginners/index.md
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import argparse
- import numpy
- # Import data
- from tensorflow.examples.tutorials.mnist import input_data
- import tensorflow as tf
- import matplotlib.pyplot as plt
- import matplotlib as mpl
FLAGS = None  # argparse namespace; assigned in the __main__ block, read by main()


def readimg(path="harom.png", channels=1):
    """Build a graph op that reads and decodes the PNG file at *path*.

    Args:
      path: PNG file to read, relative to the working directory. Defaults to
        "harom.png", the original hard-coded file.
      channels: number of colour channels to decode to. The default 1 forces
        greyscale, so the decoded image flattens to height * width values
        even when the file is stored as RGB/RGBA. Pass 0 to keep the file's
        own channel count.

    Returns:
      A uint8 tensor of shape [height, width, channels]. Evaluate it in a
      session to obtain the pixel array.
    """
    contents = tf.read_file(path)
    return tf.image.decode_png(contents, channels=channels)
def _plot_weight_column(weights, path, text=None, title=None):
    """Draw a 784-element weight vector as a 28x28 image and save it to *path*.

    Uses a three-colour red/black/blue ListedColormap. Returns the new figure
    so the caller decides whether to show and/or close it.
    """
    fig = plt.figure()
    if text is not None:
        plt.text(1, 1, text)
    plt.imshow(weights.reshape(28, 28),
               cmap=mpl.colors.ListedColormap(['red', 'black', 'blue']))
    if title is not None:
        # Set the title before savefig so it is part of the saved PNG.
        plt.title(title)
    plt.savefig(path)
    return fig


def _show_digit(image, path):
    """Plot a flattened 28x28 digit with the binary colormap, save it to
    *path*, then call plt.show() (blocks until closed on interactive
    backends) and close the figure."""
    fig = plt.figure()
    plt.imshow(image.reshape(28, 28), cmap=plt.cm.binary)
    plt.savefig(path)
    plt.show()
    plt.close(fig)


def _to_mnist_input(raw):
    """Convert a decoded PNG into an input vector for the model.

    The model is trained on read_data_sets() images, whose pixels are floats
    in [0, 1]. The PNG is inverted (255 - value) and scaled by 1/255 to match.

    Args:
      raw: uint8 pixel array of shape [height, width] or
        [height, width, channels].

    Returns:
      1-D float32 array of height * width values in [0, 1].
    """
    pixels = numpy.asarray(raw, dtype=numpy.float32)
    if pixels.ndim == 3:
        # Collapse colour channels to grey, ignoring a trailing alpha channel.
        colour = pixels[..., :3] if pixels.shape[-1] >= 3 else pixels[..., :1]
        pixels = colour.mean(axis=-1)
    # Invert in float: on uint8 pixels, (x - 255) wraps modulo 256 to x + 1.
    return (255.0 - pixels.reshape(-1)) / 255.0


def _classify(sess, prediction, x, image):
    """Return the class predicted by *prediction* for one flattened image."""
    return sess.run(prediction, feed_dict={x: [image]})[0]


def _train(sess, mnist, train_step, x, y_, w0, steps=1000, batch_size=4,
           snapshot_every=2):
    """Run steps + 1 SGD iterations on MNIST mini-batches of *batch_size*.

    Every *snapshot_every* iterations, before that iteration's update, the
    progress percentage is printed and the current W_0 column is saved as
    w0_<iteration>.png. The figure is closed rather than shown, so training
    does not stop for a window on every snapshot.
    """
    print("-- A halozat tanitasa")
    for i in range(steps + 1):
        if i % snapshot_every == 0:
            print(100.0 * i / max(steps, 1), "%")
            fig = _plot_weight_column(
                sess.run(w0), "w0_{:04d}.png".format(i),
                text="Iteracio: {}".format(i), title=str(i))
            plt.close(fig)
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    print("-"*20)


def main(_):
    """Train and evaluate the softmax MNIST model, then demonstrate it.

    Steps:
      1. Train, saving W_0 snapshots.
      2. Print test-set accuracy.
      3. Write the graph for TensorBoard.
      4. Plot the final W_0 weights to w0.png.
      5. Classify MNIST test images 42 and 41, and the PNG built by readimg().

    Args:
      _: argv list passed in by tf.app.run(); unused. Options are read from
        the module-level FLAGS.
    """
    model_path = "/tmp/model.ckpt"
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model: y = xW + b over flattened 784-pixel images.
    x = tf.placeholder(tf.float32, [None, 784])
    # W starts from small random normal values (std 0.001) instead of zeros.
    W = tf.Variable(numpy.random.normal(0, 0.001, size=(784, 10)),
                    dtype=tf.float32)
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b

    # Define loss and optimizer. softmax_cross_entropy_with_logits on the raw
    # logits avoids the numerical instability of the explicit
    # -reduce_sum(y_ * log(softmax(y))) formulation.
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

    # Ops reused below. Building them once, instead of inside the training
    # loop or per classification, keeps the graph from growing.
    w0 = W[:, 0]
    prediction = tf.argmax(y, 1)
    accuracy = tf.reduce_mean(
        tf.cast(tf.equal(prediction, tf.argmax(y_, 1)), tf.float32))

    # 'Saver' op to save and restore all the variables.
    # NOTE(review): currently unused; a checkpoint at model_path could be
    # loaded with saver.restore(sess, model_path) instead of retraining.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _train(sess, mnist, train_step, x, y_, w0)

        # Test trained model
        print("-- A halozat tesztelese")
        print("-- Pontossag: ", sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                              y_: mnist.test.labels}))
        print("-"*20)

        # Dump the graph for TensorBoard; close() flushes and releases it.
        writer = tf.summary.FileWriter("/tmp/mnist_softmax_UDPROG61", sess.graph)
        writer.close()

        image = sess.run(w0)
        print("-- W_0[14,14] =", image[14*28+14])
        print("-- Egy vedesi reszfeladat az elso heten: a W_0 sulyok abrazolasa, mutatom, a tovabblepeshez csukd be az ablakat")
        fig = _plot_weight_column(image, 'w0.png')
        # The message above tells the user the plot is shown and must be closed.
        plt.show()
        plt.close(fig)
        print("----------------------------------------------------------")

        print("-- A MNIST 42. tesztkepenek felismerese, mutatom a szamot, a tovabblepeshez csukd be az ablakat")
        image = mnist.test.images[42]
        _show_digit(image, "4.png")
        print("-- Ezt a halozat ennek ismeri fel: ", _classify(sess, prediction, x, image))
        print("----------------------------------------------------------")

        print("-- A sajat kezi 7-esem felismerese, mutatom a szamot, a tovabblepeshez csukd be az ablakat")
        # NOTE(review): despite the message ("my own handwritten 7"), this is
        # MNIST test image 41, not a hand-drawn file.
        image = mnist.test.images[41]
        _show_digit(image, "7.png")
        print("-- Ezt a halozat ennek ismeri fel: ", _classify(sess, prediction, x, image))
        print("----------------------------------------------------------")

        print("-- A sajat kezi 3-asom felismerese, mutatom a szamot, a tovabblepeshez csukd be az ablakat")
        image = _to_mnist_input(sess.run(readimg()))
        _show_digit(image, "3.png")
        print("-- Ezt a halozat ennek ismeri fel: ", _classify(sess, prediction, x, image))
        print("----------------------------------------------------------")
if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--data_dir',
        type=str,
        default='/tmp/tensorflow/mnist/input_data',
        help='Directory for storing input data',
    )
    # main() reads its options through the module-level FLAGS namespace.
    FLAGS = cli.parse_args()
    # Hand control to TensorFlow's app runner, which calls main().
    tf.app.run(main=main)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement