Advertisement
eurismarpires

CNN simples com AUC

May 12th, 2017
263
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.48 KB | None | 0 0
  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Sat Apr 29 15:54:45 2017
  4.  
  5. @author: www.deeplearningbrasil.com.br
  6. """
  7.  
  8. # Rede Neural convolucional simples para o problema de reconhecimento de dígitos (MNIST)
  9. import tensorflow as tf
  10. import random
  11. # import matplotlib.pyplot as plt
  12.  
  13. from tensorflow.examples.tutorials.mnist import input_data
  14.  
  15. tf.set_random_seed(777)  # reproducibility
  16.  
  17. mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
  18. # Verifique o site https://www.tensorflow.org/get_started/mnist/beginners para
  19. # mais informações sobre o conjunto de dados
  20.  
  21. # Parâmetros de aprendizagem
  22. taxa_aprendizado = 0.001
  23. quantidade_maxima_epocas = 10
  24. batch_size = 100
  25.  
  26. # entrada dos place holders
  27. X = tf.placeholder(tf.float32, [None, 784])
  28. X_img = tf.reshape(X, [-1, 28, 28, 1])   # imagem 28x28x1 (preto e branca)
  29. Y = tf.placeholder(tf.float32, [None, 10])
  30.  
  31. # L1 ImgIn shape=(?, 28, 28, 1)
  32. W1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
  33. #    Conv     -> (?, 28, 28, 32)
  34. #    Pool     -> (?, 14, 14, 32)
  35. L1 = tf.nn.conv2d(X_img, W1, strides=[1, 1, 1, 1], padding='SAME')
  36. L1 = tf.nn.relu(L1)
  37. L1 = tf.nn.avg_pool(L1, ksize=[1, 2, 2, 1],
  38.                     strides=[1, 2, 2, 1], padding='SAME')
  39. '''
  40. Tensor("Conv2D:0", shape=(?, 28, 28, 32), dtype=float32)
  41. Tensor("Relu:0", shape=(?, 28, 28, 32), dtype=float32)
  42. Tensor("MaxPool:0", shape=(?, 14, 14, 32), dtype=float32)
  43. '''
  44.  
  45. # L2 ImgIn shape=(?, 14, 14, 32)
  46. W2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
  47. #    Conv      ->(?, 14, 14, 64)
  48. #    Pool      ->(?, 7, 7, 64)
  49. L2 = tf.nn.conv2d(L1, W2, strides=[1, 1, 1, 1], padding='SAME')
  50. L2 = tf.nn.relu(L2)
  51. L2 = tf.nn.avg_pool(L2, ksize=[1, 2, 2, 1],
  52.                     strides=[1, 2, 2, 1], padding='SAME')
  53. L2_flat = tf.reshape(L2, [-1, 7 * 7 * 64])
  54. '''
  55. Tensor("Conv2D_1:0", shape=(?, 14, 14, 64), dtype=float32)
  56. Tensor("Relu_1:0", shape=(?, 14, 14, 64), dtype=float32)
  57. Tensor("MaxPool_1:0", shape=(?, 7, 7, 64), dtype=float32)
  58. Tensor("Reshape_1:0", shape=(?, 3136), dtype=float32)
  59. '''
  60.  
  61. # Classificador - Camada Fully Connected entrada 7x7x64 -> 10 saídas
  62. W3 = tf.get_variable("W3", shape=[7 * 7 * 64, 10],
  63.                      initializer=tf.contrib.layers.xavier_initializer())
  64. b = tf.Variable(tf.random_normal([10]))
  65. logits = tf.matmul(L2_flat, W3) + b
  66.  
  67. # Define a função de custo e o método de otimização
  68. cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=Y))
  69. optimizer = tf.train.AdamOptimizer(taxa_aprendizado).minimize(cost)
  70.  
  71. # inicializa
  72. sess = tf.Session()
  73. sess.run(tf.global_variables_initializer())
  74.  
  75. # treina a rede
  76. print('Rede inicialiada. Treinamento inicializado. Tome um cafe...')
  77. for epoca in range(quantidade_maxima_epocas):
  78.     custo_medio = 0
  79.     total_batch = int(mnist.train.num_examples / batch_size)
  80.  
  81.     for i in range(total_batch):
  82.         batch_xs, batch_ys = mnist.train.next_batch(batch_size)
  83.         feed_dict = {X: batch_xs, Y: batch_ys}
  84.         c, _ = sess.run([cost, optimizer], feed_dict=feed_dict)
  85.         custo_medio += c / total_batch
  86.  
  87.     print('Epoca:', '%04d' % (epoca + 1), 'perda =', '{:.9f}'.format(custo_medio))
  88.  
  89. print('Treinamento finalizado!')
  90.  
  91. # Teste o modelo e verifica a taxa de acerto
  92. correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(Y, 1))
  93. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  94.  
  95. #Saída da rede
  96. d = tf.cast(tf.argmax(logits, 1),tf.float32)
  97. #Respostas Corretas
  98. y = tf.cast(tf.argmax(Y,1),tf.float32)
  99.  
  100. #Calcula a área da curva ROC (AUC)
  101. auc, update_auc = tf.contrib.metrics.streaming_auc(d,y)
  102.  
  103. sess.run(tf.local_variables_initializer())
  104.  
  105. #Avalia Acurácia
  106. print('Taxa de acerto:', sess.run(accuracy, feed_dict={
  107.       X: mnist.test.images, Y: mnist.test.labels}))
  108.  
  109. #Avalia AUC
  110. print('Area da curva ROC:', sess.run(update_auc, feed_dict={
  111.       X: mnist.test.images, Y: mnist.test.labels}))
  112.  
  113.  
  114. '''
  115. #Se for necessário "quebrar" o teste em batchs substitua os dois prints acima pelo código abaixo
  116. total_batch = int(mnist.test.num_examples / batch_size)
  117. acc = 0
  118. for i in range(total_batch):
  119.    batch_xs, batch_ys = mnist.test.next_batch(batch_size)
  120.    feed_dict = {X: batch_xs, Y: batch_ys}
  121.    
  122.    #Acurácia parcial
  123.    acc += sess.run(accuracy, feed_dict)/total_batch
  124.    #Atualiza curva ROC
  125.    roc = sess.run(update_auc, feed_dict)
  126.  
  127. #Resultado final da AUC
  128. roc = sess.run(auc, feed_dict)
  129.  
  130. #Avalia Acurácia
  131. print('Taxa de acerto:', acc)
  132.  
  133. #Avalia AUC
  134. print('Area da curva ROC:', roc)
  135.  
  136. '''
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement