Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import pandas as pd
- import numpy as np
- import spacy
- from collections import Counter
- import re
- import pickle
- from rank_bm25 import BM25Plus
- from nltk.util import ngrams
- import unicodedata
- import multiprocessing
- from sklearn.metrics.pairwise import cosine_similarity
- from loguru import logger
- import copy
def preprocessing(labels=True, spacy_lang='pt_core_news_sm', train_file='train.csv'):
    """Tokenize the `title` column of *train_file* and augment each document with bigrams.

    When *labels* is true, rows are first aggregated by `category` (all titles of
    one category joined into a single document) and the category names are
    returned alongside the token lists.

    Parameters
    ----------
    labels : bool
        If true, aggregate by category and also return the category labels.
    spacy_lang : str
        Name of the spaCy model whose tokenizer / stopword lists are used.
    train_file : str
        CSV file with at least a `title` column (and `category` when *labels* is true).

    Returns
    -------
    (tokens, token_labels) when *labels* is true, else tokens — where each
    element of `tokens` is a document's unigrams followed by its space-joined
    bigrams.
    """
    nlp = spacy.load(spacy_lang)
    # NOTE(review): Defaults.create_tokenizer is a spaCy 2.x API — on spaCy 3+
    # this should be `nlp.tokenizer`; confirm the pinned spaCy version.
    tokenizer = nlp.Defaults.create_tokenizer(nlp)
    df = pd.read_csv(train_file)
    token_labels = None
    if labels:  # was `labels == True` — non-idiomatic boolean comparison
        df = df.groupby('category', as_index=False).agg({'title': ' '.join})
        token_labels = df['category'].values.tolist()
    tokens = []
    for item in df['title'].values.tolist():
        # Mask every digit so numbers don't fragment the vocabulary.
        processed_item = re.sub(r'[0-9]', '__label_DIGIT__', item.lower())
        doc = tokenizer(processed_item)
        tokens.append([str(tok) for tok in doc if not (tok.is_punct or tok.is_stop)])
    # Append space-joined bigrams to each document's unigram list.
    for idx, unigrams in enumerate(tokens):
        tokens[idx] = unigrams + [' '.join(pair) for pair in ngrams(unigrams, 2)]
    if labels:
        return tokens, token_labels
    return tokens
class Expansion(object):
    """Query-expansion helper built on a pairwise cosine-similarity matrix.

    Call `run()` (or `load()`) once to populate the matrix, then call the
    instance with a row id to get the ids of the most similar rows.
    """

    def __init__(self):
        # Square similarity matrix; populated by run() or load().
        self.matrix = None

    def run(self, matrix):
        """Compute the pairwise cosine-similarity matrix over the rows of *matrix*."""
        # cosine_similarity already returns an ndarray — the original wrapped
        # it in np.array(), forcing a redundant copy; asarray is a no-op here.
        self.matrix = np.asarray(cosine_similarity(matrix))

    def __call__(self, query_id, k=5):
        """Return the ids of the *k* rows most similar to *query_id*.

        The top hit (the row itself, similarity 1.0) is skipped via `[1:]`.
        """
        return np.argsort(self.matrix[query_id])[::-1][1:][:k]

    def save(self, target):
        """Persist the similarity matrix to *target* (numpy .npy format)."""
        np.save(target, self.matrix)

    def load(self, target):
        """Restore a similarity matrix previously written by save()."""
        self.matrix = np.load(target)
def calc_query(submission):
    # Return the single best-matching corpus document for one tokenized query.
    # NOTE(review): reads module-level globals `bm25` and `corpus`, which are
    # assigned in the __main__ block; this function is never called there —
    # presumably it was meant as a worker for multiprocessing (imported at the
    # top of the file but otherwise unused). Confirm before removing.
    return bm25.get_top_n(submission, corpus, n=1)[0]
def load_vectors(cache=True, name='vectors.npy'):
    """Return the BM25 score matrix of every test document against the corpus.

    When *cache* is true, try to load a previously saved matrix from *name*;
    otherwise — or if the load fails — recompute it from the module-level
    globals `bm25` and `test_tokens` and save it to *name* for next time.

    Returns an ndarray when loaded from cache, else a list of per-document
    score arrays; callers treat both shapes identically.
    """
    # The original abused a bare `except:` plus `raise Exception(" ")` as
    # control flow to force recomputation; an explicit branch with narrow
    # exception types expresses the same caching logic safely.
    if cache:
        try:
            return np.load(name)
        except (OSError, ValueError):
            pass  # missing or unreadable cache file — fall through and recompute
    expansion_matrix = [bm25.get_scores(item) for item in test_tokens]
    np.save(name, expansion_matrix)
    return expansion_matrix
def load_expansion(expansion_matrix, cache=True, name='expansion.npy'):
    """Return an `Expansion` instance ready for similarity queries.

    When *cache* is true, try to restore the cosine-similarity matrix from
    the *name* file; otherwise — or if the load fails — compute it from
    *expansion_matrix* and save it to *name* for next time.
    """
    expansion = Expansion()
    # The original forced recomputation by raising a generic Exception inside
    # a bare try/except; an explicit branch with narrow exception types keeps
    # the same caching behavior without swallowing unrelated errors.
    if cache:
        try:
            expansion.load(name)
            return expansion
        except (OSError, ValueError):
            pass  # no usable cache — compute below
    expansion.run(expansion_matrix)
    expansion.save(name)
    return expansion
if __name__ == '__main__':
    # Train side: one aggregated document per category, plus its label.
    tokens, labels = preprocessing()
    # Map the space-joined form of each training document back to its label;
    # bm25.get_top_n returns corpus strings, which are looked up here later.
    label_dict = { ' '.join(i) : v for i,v in list(zip(tokens, labels)) }
    bm25 = BM25Plus(tokens, k1=4.5)
    test_tokens = preprocessing(labels=False, train_file='test.csv')
    corpus = [ ' '.join(x) for x in tokens ]
    # BM25 scores of every test document against the corpus (cache disabled,
    # so this recomputes and overwrites vectors.npy on every run).
    expansion_matrix = load_vectors(cache=False)
    logger.debug('Calculating cosine similarity matrix.')
    expansion = load_expansion(expansion_matrix, cache=False)
    logger.debug('Starting queries...')
    arg_list = []  # NOTE(review): never used — presumably leftover from a multiprocessing attempt
    answers = []
    model_output = []
    for tokenized_query_idx, tokenized_query in enumerate(test_tokens):
        # Copy so the expansion below never mutates the shared test_tokens lists.
        submission = copy.deepcopy(tokenized_query)
        # Hand-tuned heuristics for (apparently Portuguese) product titles.
        first_word_list = ['antena', 'receptor', 'conversor', 'estante', 'comando', 'retentor', 'radiador', 'receiver']
        if tokenized_query[0] in first_word_list:
            # Highly discriminative first word: query on it alone.
            submission = list([tokenized_query[0]])
        elif tokenized_query[0] == 'mangueira' and 'radiador' in tokenized_query:
            # "radiator hose": drop the misleading 'radiador' token.
            submission = list(filter(lambda x: x != 'radiador', tokenized_query))
        elif ( 'escova dente' in tokenized_query ) or ( 'escova dental' in tokenized_query ):
            # Toothbrush bigram variants collapse to one canonical query.
            # (Bigrams are space-joined tokens, so `in` on the list works here.)
            submission = ['escova dental']
        else:
            # Default path: expand the query with the tokens of the 10 most
            # similar test documents (cosine similarity over BM25 score rows).
            expansion_ids = expansion(tokenized_query_idx, k=10)
            for idx in expansion_ids:
                submission.extend(test_tokens[idx])
        answer = bm25.get_top_n(submission, corpus, n=1)[0]
        answers.append(answer)
        # Progress log every 100 queries (also fires at index 0).
        if not tokenized_query_idx % 100:
            logger.debug("{0} / {1} ( {2} )".format(tokenized_query_idx, len(test_tokens),tokenized_query_idx / len(test_tokens) ))
    # Translate each winning corpus document back to its category label.
    for answer in answers:
        classification = label_dict[answer]
        model_output.append(classification)
    final_df = []
    for idx, item in enumerate(model_output):
        final_df.append({'id': idx, 'category': item})
    pd.DataFrame(final_df).to_csv('submission.csv', index=False)
    logger.debug("End of queries.")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement