Advertisement
Guest User

Untitled

a guest
Mar 25th, 2017
76
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 7.99 KB | None | 0 0
  1. # coding=utf-8
  2. import csv
  3. import re
  4.  
  5. import pandas as pd
  6. import numpy as np
  7. import json
  8. import io
  9.  
  10. from sklearn.ensemble import ExtraTreesClassifier
  11. from sklearn.ensemble import RandomForestClassifier
  12. from sklearn.neural_network import MLPClassifier
  13. import datetime
  14.  
  15. train_data = []
  16. with io.open('train_data.json','r',encoding='utf8') as f:
  17.     for line in f.readlines():
  18.         d = json.loads(line)
  19.         train_data.append(d)
  20.  
  21. test_data = []
  22. with io.open('test_data.json','r',encoding='utf8') as f:
  23.     for line in f.readlines():
  24.         d = json.loads(line)
  25.         test_data.append(d)
  26.  
  27.  
  28. train_symbols = {}
  29. # for paragraph in train_data:
  30. #     real_paragraph = paragraph['Paragraph']
  31. #     for symbol in real_paragraph:
  32. #         if symbol not in train_symbols:
  33. #             train_symbols[symbol] = len(train_symbols)
  34.  
  35. train_symbols['start'] = len(train_symbols)
  36. train_symbols['end'] = len(train_symbols)
  37.  
  38.  
  39. def get_substring_indices(text, s):
  40.     result = [i for i in range(len(text)) if text.startswith(s, i)]
  41.     res = []
  42.     for elem1 in result:
  43.         res.append(elem1 + len(s) - 1)
  44.     return res
  45.  
  46.  
  47. prev_and_post_positions = 6
  48.  
  49.  
  50. def get_prev_symbols(text, pos1):
  51.     result = set()
  52.     start_pos = pos1
  53.     while pos1 > 0 and start_pos - pos1 <= prev_and_post_positions:
  54.         result.add(text[pos1])
  55.         pos1 -= 1
  56.     return result
  57.  
  58.  
def get_prev_symbols_for_text_and_position(text, pos):
    """Encode up to `prev_and_post_positions` symbols BEFORE `pos` as one
    flat feature vector.

    Walking backwards from pos-1, each position contributes a sub-vector of
    length len(train_symbols) + 2: a one-hot slot for the symbol (only if it
    is in the train_symbols vocabulary), then an isupper flag and an isalpha
    flag. When the walk reaches the start of the text, the remaining window
    positions are padded with sub-vectors marking the synthetic 'start'
    symbol (both flags 0).

    NOTE(review): len(train_symbols) is read at call time, so the vector
    layout is only stable once the vocabulary has been fully built — confirm
    callers run after the vocabulary-building pass.
    """
    full_result = []
    start_pos = pos  # anchor, used to count how far back we have walked
    # Real symbols: pos is decremented BEFORE indexing, so text[pos] itself
    # (the candidate mark) is never encoded here.
    while pos > 0 and start_pos - pos < prev_and_post_positions:
        pos -= 1
        result = [0] * (len(train_symbols) + 2)
        if train_symbols.__contains__(text[pos]):
            result[train_symbols[text[pos]]] = 1
        result[len(train_symbols)] = 1 if text[pos].isupper() else 0
        result[len(train_symbols) + 1] = 1 if text[pos].isalpha() else 0
        for elem1 in result:
            full_result.append(elem1)
    # Padding: fill the rest of the window with the 'start' marker sub-vector.
    while start_pos - pos < prev_and_post_positions:
        pos -= 1
        result = [0] * (len(train_symbols) + 2)
        result[train_symbols['start']] = 1
        result[len(train_symbols)] = 0
        result[len(train_symbols) + 1] = 0
        for elem1 in result:
            full_result.append(elem1)
    return full_result
  80.  
  81.  
  82. def get_symbol_for_text_and_position(text, pos1):
  83.     if text[pos1] == u'!':
  84.         return [1, 0, 0, 0, 0, 0]
  85.     if text[pos1] == u'…':
  86.         return [0, 1, 0, 0, 0, 0]
  87.     if text[pos1] == u'.':
  88.         return [0, 0, 1, 0, 0, 0]
  89.     if text[pos1] == u'?':
  90.         return [0, 0, 0, 1, 0, 0]
  91.     if text[pos1] == u'»':
  92.         return [0, 0, 0, 0, 1, 0]
  93.     if text[pos1] == u'"':
  94.         return [0, 0, 0, 0, 0, 1]
  95.     return [0, 0, 0, 0, 0, 0]
  96.  
  97.  
  98. def get_post_symbols(text, pos1):
  99.     result = set()
  100.     start_pos = pos1
  101.     while pos1 < len(text) and pos1 - start_pos <= prev_and_post_positions:
  102.         result.add(text[pos1])
  103.         pos1 += 1
  104.     return result
  105.  
  106.  
  107. def get_post_symbols_for_text_and_position(text, pos):
  108.     full_result = []
  109.     start_pos = pos
  110.     while pos < len(text) and pos - start_pos < prev_and_post_positions:
  111.         pos += 1
  112.         result = [0] * len(train_symbols)
  113.         if train_symbols.__contains__(text[pos]):
  114.             result[train_symbols[text[pos]]] = 1
  115.         result[len(train_symbols)] = 1 if text[pos].isupper() else 0
  116.         result[len(train_symbols) + 1] = 1 if text[pos].isalpha() else 0
  117.         for elem1 in result:
  118.             full_result.append(elem1)
  119.     while pos - start_pos < prev_and_post_positions:
  120.         pos += 1
  121.         result = [0] * (len(train_symbols) + 2)
  122.         result[train_symbols['end']] = 1
  123.         result[len(train_symbols)] = 0
  124.         result[len(train_symbols) + 1] = 0
  125.         for elem1 in result:
  126.             full_result.append(elem1)
  127.     return full_result
  128.  
  129.  
  130. def get_punctuation_classification_vector(text, pos):
  131.     result = []
  132.     prev = get_prev_symbols_for_text_and_position(text, pos)
  133.     for elem1 in prev:
  134.         result.append(elem1)
  135.     symb = get_symbol_for_text_and_position(text, pos)
  136.     for elem1 in symb:
  137.         result.append(elem1)
  138.     post = get_prev_symbols_for_text_and_position(text, pos)
  139.     for elem1 in post:
  140.         result.append(elem1)
  141.     return result
  142.  
  143.  
  144. def get_maybe_punctuation_positions(real_paragraph1):
  145.     maybe_punctuation_marks1 = set(re.findall("[" + regexp + "]", real_paragraph1))
  146.     maybe_punctuation_positions1 = []
  147.     for maybe_punctuation_mark1 in maybe_punctuation_marks1:
  148.         for pos1, symbol1 in enumerate(real_paragraph1):
  149.             if symbol1 == maybe_punctuation_mark1:
  150.                 maybe_punctuation_positions1.append(pos1)
  151.     return maybe_punctuation_positions1
  152.  
  153. train_set = []
  154. train_result = []
  155. train_paragraphs = [] # Paragraphs for visualization of wrong classification
  156. train_positions = [] # Positions in paragraphs for visualization of wrong classification
  157. regexp = u'!' + u'…' + u'.' + u'?' + u'»' + u'"'
  158. for paragraph in train_data:
  159.     real_paragraph = paragraph['Paragraph']
  160.     maybe_punctuation_positions = get_maybe_punctuation_positions(real_paragraph)
  161.     for maybe_punctuation_position in maybe_punctuation_positions:
  162.         pre = get_prev_symbols(real_paragraph, maybe_punctuation_position)
  163.         for symbol in pre:
  164.             if symbol not in train_symbols:
  165.                 train_symbols[symbol] = len(train_symbols)
  166.         post = get_post_symbols(real_paragraph, maybe_punctuation_position)
  167.         for symbol in post:
  168.             if symbol not in train_symbols:
  169.                 train_symbols[symbol] = len(train_symbols)
  170. print("Train symbols length = " + str(len(train_symbols)))
  171.  
  172. for paragraph in train_data:
  173.     real_paragraph = paragraph['Paragraph']
  174.     maybe_punctuation_marks = set(re.findall("[" + regexp + "]", real_paragraph))
  175.     maybe_punctuation_positions = []
  176.     for maybe_punctuation_mark in maybe_punctuation_marks:
  177.         for pos, symbol in enumerate(real_paragraph):
  178.             if symbol == maybe_punctuation_mark:
  179.                 maybe_punctuation_positions.append(pos)
  180.     real_punctuation_positions = []
  181.     given_sentences = paragraph['Sentences']
  182.     for sentence in given_sentences:
  183.         tmp = get_substring_indices(real_paragraph, sentence)
  184.         for elem in tmp:
  185.             real_punctuation_positions.append(elem)
  186.  
  187.     for maybe_punctuation_position in maybe_punctuation_positions:
  188.         train_set.append(get_punctuation_classification_vector(real_paragraph, maybe_punctuation_position))
  189.         train_paragraphs.append(real_paragraph)
  190.         train_positions.append(maybe_punctuation_position)
  191.         if real_punctuation_positions.__contains__(maybe_punctuation_position):
  192.             train_result.append(1)
  193.         else:
  194.             train_result.append(0)
  195.  
  196. print("Train set length = " + str(len(train_set)))
  197.  
  198. out_data = {}
  199.  
  200. print datetime.datetime.now()
  201. print "Start training"
  202. clf = RandomForestClassifier(n_estimators=100, max_depth=None,min_samples_split = 2, random_state = 0)
  203. clf.fit(train_set, train_result)
  204. print "Trained"
  205.  
  206.  
  207. print "Start working with test data"
  208. different_marks = []
  209. for p in test_data:
  210.     par = p['Paragraph']
  211.     for cand in p['Marks']:
  212.         different_marks.append([cand['Mark']])
  213.         res = clf.predict([get_punctuation_classification_vector(par, cand['Pos'])])
  214.         out_data[cand['Index']] = res[0]
  215.  
  216. print "Analyzing train set"
  217. i = 0
  218. mistakes = 0
  219. while i < len(train_set):
  220.     train_elem = train_set[i]
  221.     train_res = train_result[i]
  222.     got_res = clf.predict([train_elem])
  223.     if train_res != got_res[0]:
  224.         print("Expected: " + str(train_res) + ", Actual: " + str(got_res[0]) + " on paragraph (position = " + str(train_positions[i]) + "):\n" + train_paragraphs[i] + "\n")
  225.         mistakes += 1
  226.     i += 1
  227. print 1 - 1.0 * mistakes / len(train_set)
  228.  
  229. print "Printing to CSV"
  230. with open('sampleSubmission.csv', 'wb') as fout:
  231.     writer = csv.writer(fout)
  232.     writer.writerow(['Id', 'Mark'])
  233.     q = 0
  234.     for item in out_data.keys():
  235.         writer.writerow([item, out_data[item]])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement