Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # coding=utf-8
- import csv
- import re
- import pandas as pd
- import numpy as np
- import json
- import io
- from sklearn.ensemble import ExtraTreesClassifier
- from sklearn.ensemble import RandomForestClassifier
- from sklearn.neural_network import MLPClassifier
- import datetime
# Load the newline-delimited JSON corpora: one JSON object per line,
# each describing a paragraph.
train_data = []
with io.open('train_data.json','r',encoding='utf8') as f:
    for line in f.readlines():
        d = json.loads(line)
        train_data.append(d)
test_data = []
with io.open('test_data.json','r',encoding='utf8') as f:
    for line in f.readlines():
        d = json.loads(line)
        test_data.append(d)
# Character -> feature-index vocabulary.  It is populated later from the
# characters observed around candidate punctuation marks (see the first
# pass over train_data below), not from whole paragraphs.
train_symbols = {}
# for paragraph in train_data:
#     real_paragraph = paragraph['Paragraph']
#     for symbol in real_paragraph:
#         if symbol not in train_symbols:
#             train_symbols[symbol] = len(train_symbols)
# Sentinel vocabulary entries used to pad feature windows that run past
# the beginning ('start') or end ('end') of a paragraph.
train_symbols['start'] = len(train_symbols)
train_symbols['end'] = len(train_symbols)
def get_substring_indices(text, s):
    """Return the index of the LAST character of every occurrence of *s* in *text*.

    Occurrences may overlap; indices are reported in ascending order.
    """
    tail_offset = len(s) - 1
    return [start + tail_offset
            for start in range(len(text))
            if text.startswith(s, start)]
# Width of the character window inspected on each side of a candidate mark.
prev_and_post_positions = 6


def get_prev_symbols(text, pos1):
    """Collect the set of distinct characters in a window ending at *pos1*.

    The window covers text[pos1] itself plus up to
    ``prev_and_post_positions`` characters to its left.  The scan stops
    once the cursor reaches index 0, so text[0] is never included (and a
    call with pos1 == 0 returns the empty set).
    """
    seen = set()
    cursor = pos1
    while cursor > 0 and pos1 - cursor <= prev_and_post_positions:
        seen.add(text[cursor])
        cursor -= 1
    return seen
def get_prev_symbols_for_text_and_position(text, pos):
    # Encode the prev_and_post_positions characters BEFORE text[pos] as one
    # flat feature list.  For each character, walking right-to-left from the
    # position just before *pos*, a slice of len(train_symbols) + 2 values is
    # appended: a one-hot symbol id (only for symbols present in the
    # train_symbols vocabulary), then an is-upper flag and an is-alpha flag.
    full_result = []
    start_pos = pos
    # Walk left; the condition is checked before the decrement, so at most
    # prev_and_post_positions characters are encoded.  The walk also stops
    # at pos == 0, i.e. after text[0] has been encoded.
    while pos > 0 and start_pos - pos < prev_and_post_positions:
        pos -= 1
        result = [0] * (len(train_symbols) + 2)
        if train_symbols.__contains__(text[pos]):
            result[train_symbols[text[pos]]] = 1
        result[len(train_symbols)] = 1 if text[pos].isupper() else 0
        result[len(train_symbols) + 1] = 1 if text[pos].isalpha() else 0
        for elem1 in result:
            full_result.append(elem1)
    # Pad the remainder of the window (when the paragraph start was reached
    # early) with the 'start' sentinel vector, so the output length is always
    # prev_and_post_positions * (len(train_symbols) + 2).  pos may go
    # negative here; it is only used for the offset arithmetic.
    while start_pos - pos < prev_and_post_positions:
        pos -= 1
        result = [0] * (len(train_symbols) + 2)
        result[train_symbols['start']] = 1
        # The two trailing flag slots are already 0; these assignments keep
        # the padding branch visually parallel to the loop above.
        result[len(train_symbols)] = 0
        result[len(train_symbols) + 1] = 0
        for elem1 in result:
            full_result.append(elem1)
    return full_result
def get_symbol_for_text_and_position(text, pos1):
    """One-hot encode the sentence-terminal mark found at text[pos1].

    Vector components correspond, in order, to: '!', '…', '.', '?', '»', '"'.
    Any other character maps to the all-zero vector.
    """
    marks = [u'!', u'…', u'.', u'?', u'»', u'"']
    encoding = [0] * len(marks)
    symbol = text[pos1]
    if symbol in marks:
        encoding[marks.index(symbol)] = 1
    return encoding
def get_post_symbols(text, pos1):
    """Collect the set of distinct characters in a window starting at *pos1*.

    The window covers text[pos1] itself plus up to
    ``prev_and_post_positions`` characters to its right, clipped at the
    end of *text*.  Mirror image of get_prev_symbols.
    """
    seen = set()
    cursor = pos1
    limit = len(text)
    while cursor < limit and cursor - pos1 <= prev_and_post_positions:
        seen.add(text[cursor])
        cursor += 1
    return seen
def get_post_symbols_for_text_and_position(text, pos):
    """Encode the characters AFTER text[pos] as one flat feature list.

    Mirror of get_prev_symbols_for_text_and_position: for each of up to
    prev_and_post_positions characters following *pos*, a slice of
    len(train_symbols) + 2 values is appended — a one-hot symbol id, then
    an is-upper flag and an is-alpha flag.  Positions past the end of
    *text* are padded with the 'end' sentinel vector, so the output length
    is always prev_and_post_positions * (len(train_symbols) + 2).

    Fixes two latent bugs (never triggered before, because the original
    get_punctuation_classification_vector mistakenly called the prev
    encoder twice and this function was dead code):
      * the per-character slice was allocated len(train_symbols) long,
        two slots too short for the flag writes -> IndexError;
      * the scan condition allowed pos to advance to len(text) before
        reading text[pos] -> IndexError on paragraph-final marks.
    """
    full_result = []
    start_pos = pos
    # Walk right, encoding each following character.  The bound is
    # len(text) - 1 because pos is incremented BEFORE text[pos] is read.
    while pos < len(text) - 1 and pos - start_pos < prev_and_post_positions:
        pos += 1
        # Leave room for the two trailing flag slots (was len(train_symbols)).
        result = [0] * (len(train_symbols) + 2)
        if text[pos] in train_symbols:
            result[train_symbols[text[pos]]] = 1
        result[len(train_symbols)] = 1 if text[pos].isupper() else 0
        result[len(train_symbols) + 1] = 1 if text[pos].isalpha() else 0
        full_result.extend(result)
    # Pad the rest of the window with the 'end' sentinel vector; the flag
    # slots stay 0 from initialization.
    while pos - start_pos < prev_and_post_positions:
        pos += 1
        result = [0] * (len(train_symbols) + 2)
        result[train_symbols['end']] = 1
        full_result.extend(result)
    return full_result


def get_punctuation_classification_vector(text, pos):
    """Build the full feature vector for the candidate mark at text[pos].

    Concatenates three parts: the encoded window of characters before the
    mark, the one-hot encoding of the mark itself, and the encoded window
    of characters after the mark.

    Fixes a copy-paste bug: the original called
    get_prev_symbols_for_text_and_position for the trailing window too, so
    the "post" features merely duplicated the "prev" features and the text
    after the mark was never looked at.
    """
    result = []
    result.extend(get_prev_symbols_for_text_and_position(text, pos))
    result.extend(get_symbol_for_text_and_position(text, pos))
    result.extend(get_post_symbols_for_text_and_position(text, pos))
    return result
def get_maybe_punctuation_positions(real_paragraph1, marks_pattern=None):
    """Return the positions of every candidate punctuation mark in the paragraph.

    Positions are grouped by mark (all occurrences of one mark, then the
    next), matching the original scan order rather than left-to-right order.

    marks_pattern: string of characters to treat as candidate marks (it is
        interpolated into a regex character class).  Defaults to the
        module-level ``regexp`` string, preserving the original behavior;
        the parameter makes the helper reusable with other mark sets.
    """
    if marks_pattern is None:
        marks_pattern = regexp
    # Distinct marks actually present in this paragraph.
    marks_found = set(re.findall("[" + marks_pattern + "]", real_paragraph1))
    return [pos1
            for mark1 in marks_found
            for pos1, symbol1 in enumerate(real_paragraph1)
            if symbol1 == mark1]
# Accumulators for the training matrix and its bookkeeping.
train_set = []
train_result = []
train_paragraphs = []  # Paragraphs for visualization of wrong classification
train_positions = []  # Positions in paragraphs for visualization of wrong classification
# Characters that may terminate a sentence; used as a regex character class
# body by get_maybe_punctuation_positions.
regexp = u'!' + u'…' + u'.' + u'?' + u'»' + u'"'
# Pass 1: grow the train_symbols vocabulary from every character observed
# in the windows around candidate punctuation marks of the training
# paragraphs.  Must run before any feature vectors are built, since the
# vector width is a function of len(train_symbols).
for paragraph in train_data:
    real_paragraph = paragraph['Paragraph']
    maybe_punctuation_positions = get_maybe_punctuation_positions(real_paragraph)
    for maybe_punctuation_position in maybe_punctuation_positions:
        pre = get_prev_symbols(real_paragraph, maybe_punctuation_position)
        for symbol in pre:
            if symbol not in train_symbols:
                train_symbols[symbol] = len(train_symbols)
        post = get_post_symbols(real_paragraph, maybe_punctuation_position)
        for symbol in post:
            if symbol not in train_symbols:
                train_symbols[symbol] = len(train_symbols)
print("Train symbols length = " + str(len(train_symbols)))
# Pass 2: build the training matrix.  A candidate position is labelled 1
# when it coincides with the last character of one of the paragraph's
# given sentences (a true sentence boundary), else 0.
for paragraph in train_data:
    real_paragraph = paragraph['Paragraph']
    # Inline duplicate of get_maybe_punctuation_positions.
    maybe_punctuation_marks = set(re.findall("[" + regexp + "]", real_paragraph))
    maybe_punctuation_positions = []
    for maybe_punctuation_mark in maybe_punctuation_marks:
        for pos, symbol in enumerate(real_paragraph):
            if symbol == maybe_punctuation_mark:
                maybe_punctuation_positions.append(pos)
    # Ground truth: index of the final character of each known sentence
    # inside the paragraph.
    real_punctuation_positions = []
    given_sentences = paragraph['Sentences']
    for sentence in given_sentences:
        tmp = get_substring_indices(real_paragraph, sentence)
        for elem in tmp:
            real_punctuation_positions.append(elem)
    for maybe_punctuation_position in maybe_punctuation_positions:
        train_set.append(get_punctuation_classification_vector(real_paragraph, maybe_punctuation_position))
        # Keep the paragraph/position alongside each row so misclassified
        # examples can be printed later.
        train_paragraphs.append(real_paragraph)
        train_positions.append(maybe_punctuation_position)
        if real_punctuation_positions.__contains__(maybe_punctuation_position):
            train_result.append(1)
        else:
            train_result.append(0)
print("Train set length = " + str(len(train_set)))
# Candidate index -> predicted 0/1 label for the test set.
out_data = {}
print datetime.datetime.now()
print "Start training"
clf = RandomForestClassifier(n_estimators=100, max_depth=None,min_samples_split = 2, random_state = 0)
clf.fit(train_set, train_result)
print "Trained"
print "Start working with test data"
different_marks = []  # NOTE(review): collected but never read afterwards
for p in test_data:
    par = p['Paragraph']
    for cand in p['Marks']:
        different_marks.append([cand['Mark']])
        # Predict whether the candidate mark at cand['Pos'] ends a sentence.
        res = clf.predict([get_punctuation_classification_vector(par, cand['Pos'])])
        out_data[cand['Index']] = res[0]
# Sanity check: re-classify the training set and print every mismatch.
# This is training-set accuracy, not a held-out validation score.
print "Analyzing train set"
i = 0
mistakes = 0
while i < len(train_set):
    train_elem = train_set[i]
    train_res = train_result[i]
    got_res = clf.predict([train_elem])
    if train_res != got_res[0]:
        print("Expected: " + str(train_res) + ", Actual: " + str(got_res[0]) + " on paragraph (position = " + str(train_positions[i]) + "):\n" + train_paragraphs[i] + "\n")
        mistakes += 1
    i += 1
# Fraction of training rows classified correctly.
print 1 - 1.0 * mistakes / len(train_set)
print "Printing to CSV"
# 'wb' mode is the Python 2 csv convention (this file uses Python 2
# print statements throughout).
with open('sampleSubmission.csv', 'wb') as fout:
    writer = csv.writer(fout)
    writer.writerow(['Id', 'Mark'])
    q = 0  # NOTE(review): unused
    for item in out_data.keys():
        writer.writerow([item, out_data[item]])
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement