import tensorflow as tf
from tensorflow.contrib import learn
from tflearn.data_utils import to_categorical, pad_sequences
import pickle
from cnn_url_classifier.utils import *  # utils file in cnn_url_classifier (provides get_word_vocab, get_words, char_id_x, batch_iter, pad_seq, pad_seq_in_word, softmax used below)
import time
from url_utils import format_url
from tqdm import tqdm  # for progress bar
import argparse
import numpy as np
import sys  # needed for redirecting stdout below
import os   # needed for os.devnull below


# Class to get a URL safety score for a single URL (can be used for a subject score by changing parameters).
# Easier to use and faster than running the full batch evaluation script.
class CNNURLScore():

    def __init__(self, max_length_words=200, max_length_chars=200, root_dir='', max_length_subwords=20,
                 data_directory="test_1000.txt",
                 delimit_mode_no=1, subword_dict_directory="runs/10000/subwords_dict.p",
                 word_dict_directory="runs/10000/words_dict.p", char_dict_directory="runs/10000/chars_dict.p",
                 emb_dimension=32, emb_mode_no=1, batch_size_no=128,
                 log_output_directory="runs/url/urleval.txt", log_checkpoint_directory="runs/10000/checkpoints/"):

        # Set default parameters and load the trained model's dictionaries

        # data arguments
        self.max_len_words = max_length_words        # maximum length of a URL in words
        self.max_len_chars = max_length_chars        # maximum length of a URL in characters
        self.max_len_subwords = max_length_subwords  # maximum length of a word in subwords/characters

        self.data_dir = root_dir + data_directory    # location of the data file - not used when testing a single URL/subject
        self.delimit_mode = delimit_mode_no          # 0: delimit by special chars, 1: delimit by special chars + each char as a word
        self.subword_dict_dir = root_dir + subword_dict_directory  # path of the subword dictionary
        self.word_dict_dir = root_dir + word_dict_directory        # path of the word dictionary
        self.char_dict_dir = root_dir + char_dict_directory        # path of the character dictionary

        # model args
        self.emb_dim = emb_dimension  # embedding dimension size
        self.emb_mode = emb_mode_no   # 1: character CNN (see test_step for all modes)

        # test args
        self.batch_size = batch_size_no  # for testing multiple URLs (useful when classifying a list of URLs)

        # log args
        self.log_output_dir = root_dir + log_output_directory          # directory to save the test results - not used for a single test
        self.log_checkpoint_dir = root_dir + log_checkpoint_directory  # directory of the trained model checkpoints

        # Load the trained model's dictionaries
        # self.ngram_dict = pickle.load(open(self.subword_dict_dir, "rb"))  # trained model's ngram dictionary
        self.word_dict = pickle.load(open(self.word_dict_dir, "rb"))        # trained model's word dictionary
        self.chars_dict = pickle.load(open(self.char_dict_dir, "rb"))       # trained model's character dictionary

    def test_step(self, x, emb_mode):
        '''
        :param x: (list) batch of model inputs, ordered as expected by the given emb_mode
        :param emb_mode: (int) 1: only character-based CNN, 2: only word-based CNN, 3: character and word CNN,
                         4: character-level word CNN, 5: character and character-level word CNN
        :return: (tuple) predictions and raw output scores for the batch
        '''
        p = 1.0  # dropout keep probability (no dropout at test time)
        if emb_mode == 1:
            feed_dict = {
                self.input_x_char_seq: x[0],
                self.dropout_keep_prob: p}
        elif emb_mode == 2:
            feed_dict = {
                self.input_x_word: x[0],
                self.dropout_keep_prob: p}
        elif emb_mode == 3:
            feed_dict = {
                self.input_x_char_seq: x[0],
                self.input_x_word: x[1],
                self.dropout_keep_prob: p}
        elif emb_mode == 4:
            feed_dict = {
                self.input_x_word: x[0],
                self.input_x_char: x[1],
                self.input_x_char_pad_idx: x[2],
                self.dropout_keep_prob: p}
        elif emb_mode == 5:
            feed_dict = {
                self.input_x_char_seq: x[0],
                self.input_x_word: x[1],
                self.input_x_char: x[2],
                self.input_x_char_pad_idx: x[3],
                self.dropout_keep_prob: p}
        preds, s = self.sess.run([self.predictions, self.scores], feed_dict)
        return preds, s

    def test_url(self, url_test):
        '''
        :param url_test: (string) URL to be tested
        :return: (float) Normalized score indicating benign/malicious level
        '''
        sys.stdout = open(os.devnull, 'w')  # disable printing

        urls, labels = [url_test], [0]  # TO-DO

        x, word_reverse_dict = get_word_vocab(urls, self.max_len_words)
        word_x = get_words(x, word_reverse_dict, self.delimit_mode, urls)  # TO-DO
        chared_id_x = char_id_x(urls, self.chars_dict, self.max_len_chars)  # TO-DO
        # NOTE: worded_id_x / ngramed_id_x (needed for emb_mode 2-5) are not computed here;
        # only emb_mode 1 (character CNN) is currently supported by this method.

        checkpoint_file = tf.train.latest_checkpoint(self.log_checkpoint_dir)
        graph = tf.Graph()
        with graph.as_default():
            session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
            session_conf.gpu_options.allow_growth = True
            self.sess = tf.Session(config=session_conf)  # CHECK - if self is needed
            with self.sess.as_default():
                saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
                saver.restore(self.sess, checkpoint_file)

                if self.emb_mode in [1, 3, 5]:
                    self.input_x_char_seq = graph.get_operation_by_name("input_x_char_seq").outputs[0]
                if self.emb_mode in [2, 3, 4, 5]:
                    self.input_x_word = graph.get_operation_by_name("input_x_word").outputs[0]
                if self.emb_mode in [4, 5]:
                    self.input_x_char = graph.get_operation_by_name("input_x_char").outputs[0]
                    self.input_x_char_pad_idx = graph.get_operation_by_name("input_x_char_pad_idx").outputs[0]
                self.dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

                self.predictions = graph.get_operation_by_name("output/predictions").outputs[0]
                self.scores = graph.get_operation_by_name("output/scores").outputs[0]

                if self.emb_mode == 1:
                    batches = batch_iter(list(chared_id_x), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 2:
                    batches = batch_iter(list(worded_id_x), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 3:
                    batches = batch_iter(list(zip(chared_id_x, worded_id_x)), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 4:
                    batches = batch_iter(list(zip(ngramed_id_x, worded_id_x)), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 5:
                    batches = batch_iter(list(zip(ngramed_id_x, worded_id_x, chared_id_x)), self.batch_size, 1,
                                         shuffle=False)
                all_predictions = []
                all_scores = []

                nb_batches = int(len(labels) / self.batch_size)
                if len(labels) % self.batch_size != 0:
                    nb_batches += 1
                # print("Number of batches in total: {}".format(nb_batches))

                # A single URL always fits in one batch
                batch = next(batches)

                if self.emb_mode == 1:
                    x_char_seq = batch
                elif self.emb_mode == 2:
                    x_word = batch
                elif self.emb_mode == 3:
                    x_char_seq, x_word = zip(*batch)
                elif self.emb_mode == 4:
                    x_char, x_word = zip(*batch)
                elif self.emb_mode == 5:
                    x_char, x_word, x_char_seq = zip(*batch)

                x_batch = []
                if self.emb_mode in [1, 3, 5]:
                    x_char_seq = pad_seq_in_word(x_char_seq, self.max_len_chars)
                    x_batch.append(x_char_seq)
                if self.emb_mode in [2, 3, 4, 5]:
                    x_word = pad_seq_in_word(x_word, self.max_len_words)
                    x_batch.append(x_word)
                if self.emb_mode in [4, 5]:
                    x_char, x_char_pad_idx = pad_seq(x_char, self.max_len_words, self.max_len_subwords,
                                                     self.emb_dim)
                    x_batch.extend([x_char, x_char_pad_idx])

                batch_predictions, batch_scores = self.test_step(x_batch, self.emb_mode)
                all_predictions = np.concatenate([all_predictions, batch_predictions])
                all_scores.extend(batch_scores)

        class_pred = all_predictions[0]    # class of the URL -> benign/malicious
        score = softmax(all_scores)[0][1]  # maliciousness score of the URL (using softmax)

        state = ''
        if class_pred == 0:
            state = "Benign"
        elif class_pred == 1:
            state = "Malicious"

        sys.stdout = sys.__stdout__  # re-enable printing
        # print(str(urls[0]) + "\t" + "Prediction: " + state + "\tScore:\t" + str(score))
        # print("all_scores for single URL: " + str(all_scores))
        return score

    def test_url_lst(self, url_lst):
        '''
        :param url_lst: (list of strings) List of URLs to be tested
        :return: (list of floats) List of normalized scores indicating benign/malicious level of each URL
        '''
        sys.stdout = open(os.devnull, 'w')  # disable printing
        url_lst = [format_url(url) for url in url_lst]
        urls, labels = url_lst, [0 for url in url_lst]  # TO-DO

        x, word_reverse_dict = get_word_vocab(urls, self.max_len_words)
        word_x = get_words(x, word_reverse_dict, self.delimit_mode, urls)  # TO-DO
        chared_id_x = char_id_x(urls, self.chars_dict, self.max_len_chars)  # TO-DO
        # NOTE: as in test_url, worded_id_x / ngramed_id_x (needed for emb_mode 2-5) are not computed here.

        checkpoint_file = tf.train.latest_checkpoint(self.log_checkpoint_dir)
        graph = tf.Graph()

        with graph.as_default():
            session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
            session_conf.gpu_options.allow_growth = True
            self.sess = tf.Session(config=session_conf)  # CHECK - if self is needed

            with self.sess.as_default():
                saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
                saver.restore(self.sess, checkpoint_file)

                if self.emb_mode in [1, 3, 5]:
                    self.input_x_char_seq = graph.get_operation_by_name("input_x_char_seq").outputs[0]
                if self.emb_mode in [2, 3, 4, 5]:
                    self.input_x_word = graph.get_operation_by_name("input_x_word").outputs[0]
                if self.emb_mode in [4, 5]:
                    self.input_x_char = graph.get_operation_by_name("input_x_char").outputs[0]
                    self.input_x_char_pad_idx = graph.get_operation_by_name("input_x_char_pad_idx").outputs[0]
                self.dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]

                self.predictions = graph.get_operation_by_name("output/predictions").outputs[0]
                self.scores = graph.get_operation_by_name("output/scores").outputs[0]

                if self.emb_mode == 1:
                    batches = batch_iter(list(chared_id_x), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 2:
                    batches = batch_iter(list(worded_id_x), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 3:
                    batches = batch_iter(list(zip(chared_id_x, worded_id_x)), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 4:
                    batches = batch_iter(list(zip(ngramed_id_x, worded_id_x)), self.batch_size, 1, shuffle=False)
                elif self.emb_mode == 5:
                    batches = batch_iter(list(zip(ngramed_id_x, worded_id_x, chared_id_x)), self.batch_size, 1,
                                         shuffle=False)
                all_predictions = []
                all_scores = []

                nb_batches = int(len(labels) / self.batch_size)
                if len(labels) % self.batch_size != 0:
                    nb_batches += 1
                # print("Number of batches in total: {}".format(nb_batches))

                # Optional progress bar, kept for possible future use:
                # it = tqdm(range(nb_batches),
                #           desc="emb_mode {} delimit_mode {} test_size {}".format(self.emb_mode,
                #                                                                  self.delimit_mode,
                #                                                                  len(labels)), ncols=0)
                # for idx in it:
                for i in range(nb_batches):
                    batch = next(batches)

                    if self.emb_mode == 1:
                        x_char_seq = batch
                    elif self.emb_mode == 2:
                        x_word = batch
                    elif self.emb_mode == 3:
                        x_char_seq, x_word = zip(*batch)
                    elif self.emb_mode == 4:
                        x_char, x_word = zip(*batch)
                    elif self.emb_mode == 5:
                        x_char, x_word, x_char_seq = zip(*batch)

                    x_batch = []
                    if self.emb_mode in [1, 3, 5]:
                        x_char_seq = pad_seq_in_word(x_char_seq, self.max_len_chars)
                        x_batch.append(x_char_seq)
                    if self.emb_mode in [2, 3, 4, 5]:
                        x_word = pad_seq_in_word(x_word, self.max_len_words)
                        x_batch.append(x_word)
                    if self.emb_mode in [4, 5]:
                        x_char, x_char_pad_idx = pad_seq(x_char, self.max_len_words, self.max_len_subwords,
                                                         self.emb_dim)
                        x_batch.extend([x_char, x_char_pad_idx])

                    batch_predictions, batch_scores = self.test_step(x_batch, self.emb_mode)
                    all_predictions = np.concatenate([all_predictions, batch_predictions])
                    all_scores.extend(batch_scores)

        # Leaving commented code for possible future use/testing
        # class_pred = all_predictions[0]    # class of URL -> benign/malicious
        # score = softmax(all_scores)[0][1]  # maliciousness score of URL (using softmax)
        # print("all_predictions " + str(all_predictions))
        # print("class_pred " + str(class_pred))
        # print("score " + str(score))

        softmax_scores = [softmax(i) for i in all_scores]
        # print("softmax_scores: " + str(softmax_scores))
        # for i in range(len(labels)):
        #     print("Softmax score: " + str(softmax_scores[i][1]))

        sys.stdout = sys.__stdout__  # re-enable printing

        scores_lst = [i[1] for i in softmax_scores]
        return scores_lst


####### Examples #######
if __name__ == "__main__":
    a = CNNURLScore()
    start = time.time()
    print("Score: " + "apple.com " + str(a.test_url("apple.com")))
    print("Score: " + "google.com " + str(a.test_url("google.com")))
    end = time.time()
    # print("Time: " + str(end - start))
    s = time.time()
    print(a.test_url_lst(["apple.com", "google.com"]))
    e = time.time()
    # print("Time 2: " + str(e - s))
    # Subject-scoring variant:
    # s = CNNURLScore(delimit_mode_no=0, subword_dict_directory="runs/subjects/subwords_dict.p",
    #                 word_dict_directory="runs/subjects/words_dict.p",
    #                 char_dict_directory="runs/subjects/chars_dict.p", emb_mode_no=2)
    # print(s.emb_mode)
    # print(s.test_url("Subject testing"))
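
    # Sketch (added example, not part of the original script): turning the softmax
    # scores returned by test_url_lst back into the same Benign/Malicious labels
    # that test_url uses internally. The 0.5 cutoff is an illustrative assumption,
    # not a threshold defined by the trained model.
    example_urls = ["apple.com", "google.com"]
    example_scores = a.test_url_lst(example_urls)
    example_labels = ["Malicious" if sc >= 0.5 else "Benign" for sc in example_scores]
    for u, lab, sc in zip(example_urls, example_labels, example_scores):
        print("{}\t{}\t{:.4f}".format(u, lab, sc))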