Advertisement
Guest User

GAN-generated cat detector (April 2021)

a guest
Apr 10th, 2021
833
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.26 KB | None | 0 0
  1. import numpy as np
  2. import pandas as pd
  3. import os
  4. import sys
  5. import random
  6. from PIL import Image
  7. from sklearn.neural_network import MLPClassifier
  8. import pickle
  9.  
  10. SIZE = 512
  11. SIZE_TEXT = str (SIZE) + "x" + str (SIZE)
  12. PIXELS = SIZE * SIZE
  13. BORDER = 24
  14.  
  15. def count_and_sort (df, columns):
  16. g = df.groupby (columns)
  17. df = pd.DataFrame ({"count" : g.size ()}).reset_index ()
  18. return df.sort_values ("count", ascending=False).reset_index ()
  19.  
  20. def image_mask (pixels):
  21. shape = pixels.shape
  22. if shape[0] != SIZE or shape[1] != SIZE or shape[2] !=3:
  23. return None
  24. rgb = np.sum (pixels, axis=(0, 1))
  25. rgb = (rgb[0] / PIXELS, rgb[1] / PIXELS, rgb[2] / PIXELS)
  26. data = np.abs (pixels - rgb)
  27. return data * (1 / np.max (data))
  28.  
  29. def calculate_peaks (data):
  30. data = data * (2, 5, 3)
  31. data = np.sum (data, axis=2)
  32. peaks = np.sum (data, axis=0) + np.sum (data, axis=1)
  33. peaks = np.diff (peaks)
  34. peaks = peaks[BORDER:SIZE - BORDER]
  35. return peaks / np.max (np.abs (peaks))
  36.  
  37. def create_dataset (path):
  38. print ("loading images from " + path)
  39. rows = []
  40. if not path.endswith ("/"):
  41. path = path + "/"
  42. for f in os.listdir (path):
  43. try:
  44. if f[f.find ("."):].lower () in (".jpeg", ".jpg", ".png"):
  45. image = Image.open (path + f)
  46. data = np.array (image)
  47. image.close ()
  48. mask = image_mask (data)
  49. if mask is None:
  50. print ("skipped " + f + ", image not " + SIZE_TEXT)
  51. else:
  52. peaks = calculate_peaks (mask)
  53. rows.append ((peaks, path + f))
  54. if len (rows) % 100 == 0:
  55. print (str (len (rows)) + " images processed")
  56. except:
  57. print ("skipped: " + f)
  58. if len (rows) % 100 != 0:
  59. print (str (len (rows)) + " images processed")
  60. return rows
  61.  
  62. def train_model (realcats, gancats, training=0.5):
  63. data = [(x[0], 0, x[1]) for x in realcats]
  64. data.extend ([(x[0], 1, x[1]) for x in gancats])
  65. random.shuffle (data)
  66. cutoff = int (training * len (data))
  67. train = data[:cutoff]
  68. test = data[cutoff:]
  69. model = MLPClassifier (random_state=random.randint (0, 2147483647),
  70. max_iter=1000, hidden_layer_sizes=[48, 12, 3])
  71. data = np.stack ([t[0] for t in train], axis=0)
  72. labels = [t[1] for t in train]
  73. model = model.fit (data, labels)
  74. data = np.stack ([t[0] for t in test], axis=0)
  75. labels = [t[1] for t in test]
  76. images = [t[2] for t in test]
  77. predictions = model.predict (data)
  78. df = pd.DataFrame ({"actual" : labels,
  79. "predicted" : predictions,
  80. "image" : images})
  81. return (model, df)
  82.  
  83. def train_loop (realcats, gancats, iterations=100, training=0.7):
  84. accuracy = 0
  85. for i in range (iterations):
  86. result = train_model (realcats, gancats, training=training)
  87. df = result[1]
  88. score = len (df[df["actual"] == df["predicted"]]) / len (df.index)
  89. if score > accuracy:
  90. accuracy = score
  91. best = result
  92. if (i + 1) % 20 == 0:
  93. print (str (i + 1) + "/" + str (iterations ) + ", best accuracy " + str (accuracy))
  94. print ("accuracy: " + str (accuracy))
  95. return best
  96.  
  97. def classify_images (model, path):
  98. test = create_dataset (path)
  99. data = np.stack ([t[0] for t in test], axis=0)
  100. images = [t[1] for t in test]
  101. predictions = model.predict (data)
  102. df = pd.DataFrame ({"image" : images,
  103. "predicted" : predictions})
  104. df["predictedCategory"] = df["predicted"].apply (
  105. lambda x: "GAN" if x == 1 else "real")
  106. return df[["image", "predicted", "predictedCategory"]]
  107.  
  108. argv = sys.argv
  109. mode = argv[1].lower ().strip ()
  110. if mode == "train":
  111. realcats = create_dataset (argv[2])
  112. gancats = create_dataset (argv[3])
  113. result = train_loop (realcats, gancats)
  114. with open (argv[4], "wb") as fd:
  115. pickle.dump (result[0], fd)
  116. elif mode == "test":
  117. with (open (argv[3], "rb")) as fd:
  118. model = pickle.load (fd)
  119. df = classify_images (model, argv[2])
  120. df.to_csv (argv[4], index=False)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement