Advertisement
Guest User

Untitled

a guest
Jul 17th, 2019
87
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.04 KB | None | 0 0
  1. # X: input; y: output
  2. import itertools
  3. from collections import defaultdict
  4. from itertools import zip_longest
  5.  
  def genX(texts, fss, ypreds, maxlenTest):
      """Pack per-token feature vectors into one zero-padded model input.

      Each token ``t`` of each text contributes one row: the first row of
      ``elmo.predict(gE([t]))`` followed by that text's ``fs`` and ``ypred``
      arrays. Rows from all texts are laid out one after another in a single
      sequence; unused trailing positions stay zero.

      Args:
          texts: iterable of token lists.
          fss: per-text feature arrays, aligned with ``texts``.
          ypreds: per-text prediction arrays, aligned with ``texts``.
          maxlenTest: length of the sequence axis of the result.

      Returns:
          ndarray of shape ``(1, maxlenTest, 1024 + 4 + 13)``.

      NOTE(review): the feature width is hard-coded; presumably 1024 is the
      ELMo vector size and 4 + 13 the ``fs`` / ``ypred`` lengths -- confirm.
      A total token count above ``maxlenTest`` makes the slice assignment
      raise ValueError.
      """
      width = 1024 + 4 + 13
      X = np.zeros((1, maxlenTest, width))
      rows = [
          np.concatenate([elmo.predict(gE([token]))[0], fs, ypred])
          for text, fs, ypred in zip(texts, fss, ypreds)
          for token in text
      ]
      if rows:
          X[0, :len(rows), :] = rows
      return X
  20.  
  def nameFormat(word):
      """Return *word* with its first character upper-cased and the rest lower-cased.

      Safe on the empty string (returns ``""``).
      """
      return word[:1].upper() + word[1:].lower()


  def collectEntityChunks(words, heat):
      """Decode per-word class scores into ``(tag, text)`` entity chunks.

      For each word, ``argmax`` of its row in *heat* selects a class:
        * 0 -- the word belongs to no entity and is skipped;
        * odd ``2*i + 1`` -- starts a new chunk for ``tags[i]``, first closing
          that tag's pending chunk if it has one;
        * even ``2*i + 2`` -- appends the word to ``tags[i]``'s pending chunk.
      Pending chunks are closed at the end. Chunk text is the space-joined words.

      NOTE(review): a pending chunk is not closed by class 0 or by words of
      other tags, so continuation words separated by other tokens end up in
      the same chunk (e.g. classes 5, 1, 6 on "a b c" give NM "a c").
      Confirm whether that merge is intended.

      Returns:
          list of ``(tag, text)`` pairs ordered by tag (in ``tags`` order),
          then by the order in which chunks were closed. A class index past
          the last tag raises IndexError.
      """
      tags = ("SK", "LG", "NM", "AD", "PH", "MA", "CP", "WP", "PO", "IN", "FI", "EP")
      finished = [[] for _ in tags]
      pending = [[] for _ in tags]
      for word, scores in zip(words, heat):
          cls = int(np.argmax(scores))
          if cls == 0:
              continue
          i = (cls + 1) // 2 - 1  # classes 2*i+1 and 2*i+2 both map to tags[i]
          if cls % 2 == 1 and pending[i]:
              # Start of a new chunk: close this tag's previous chunk first.
              finished[i].append(' '.join(pending[i]))
              pending[i] = []
          pending[i].append(word)
      for i, leftover in enumerate(pending):
          if leftover:
              finished[i].append(' '.join(leftover))
      return [(tag, text) for tag, texts in zip(tags, finished) for text in texts]


  def genTags(words, heat):
      """Extract tagged entities from *words* and embed each distinct one.

      Args:
          words: tokens, aligned with the rows of *heat*.
          heat: per-token class scores, decoded by ``collectEntityChunks``.

      Returns:
          ``(d, embDict)``: two ``defaultdict(list)`` keyed by tag code
          (e.g. ``"NM"``). ``d[tag]`` lists the distinct chunk texts for that
          tag in first-seen order. ``embDict[tag][k]`` is
          ``list(elmo.predict(gE([nameFormat(d[tag][k])])))``.
      """
      d = defaultdict(list)
      embDict = defaultdict(list)
      # dict.fromkeys drops repeated (tag, text) pairs, keeping first-seen order.
      for tag, text in dict.fromkeys(collectEntityChunks(words, heat)):
          d[tag].append(text)
          embDict[tag].append(list(elmo.predict(gE([nameFormat(text)]))))
      return d, embDict
  69.  
  def getEntities(fileName, model2):
      """Run the extraction pipeline on one file from ``ManuallyTagged``.

      Steps (every helper and model used here is defined outside this
      function):
        1. read ``ManuallyTagged/<fileName>``, extract CSS rules and plain
           text, and split the text into sentences with ``getSents``;
        2. embed each sentence with ELMo, append the scaled features and score
           the sentences with the global ``model`` (``clusts`` is the argmax);
        3. keep sentences whose score for any class in ``KEEP`` exceeds 0.5,
           tokenize them and score each token with ``modelEnt``;
        4. group the scored tokens into entities with ``genTags`` and
           post-process them into an ``entities`` dict, which is printed.

      Args:
          fileName: file name inside the ``ManuallyTagged`` directory.
          model2: NOTE(review): not used -- the body calls the global
              ``model``; confirm which model was intended.

      Returns:
          The ``curr`` summary dict with keys ``"Name"``, ``"HTML"``,
          ``"Texts"``, ``"Embeds"``, ``"EmbedsPos"``, ``"Clusts"``,
          ``"Properties"`` and ``"PropEmbs"``.
      """
      path = "ManuallyTagged" + "/" + fileName
      with open(path, encoding="utf8") as fh:
          d = fh.read()
      cssRules = getCssRules(d)
      text = remove_nl(remove_html_tags(d, "span"))
      # testSents holds (features, text) pairs; the features are presumably
      # [x, y, fontSize, CvPage] -- confirm against getSents. ids is unused.
      testSents, ids = getSents(text, cssRules)
      test_texts = [t for _, t in testSents]

      test_embs = elmo.predict(gE(test_texts))
      test_fs = scaler.transform([f for f, _ in testSents])  # Positions are here TODO
      test_sents = np.hstack([test_embs, test_fs])

      # Every row of test_docs receives the same sequence: all sentences of this
      # file, zero-padded to 259 steps. One broadcast assignment replaces the
      # former O(n^2) Python loop with identical contents.
      # NOTE(review): more than 259 sentences, or a row width other than 1028,
      # raises here.
      test_docs = np.zeros((len(test_sents), 259, 1028), dtype='float32')
      test_docs[:, :len(test_sents), :] = test_sents

      # NOTE(review): global ``model`` is used, not the ``model2`` argument.
      ypred = model.predict([test_docs, test_sents])
      clusts = np.argmax(ypred, axis=1)

      # Keep sentences scoring above 0.5 for at least one class listed in KEEP.
      mask = np.any([ypred[:, vi] > 0.5 for vi in KEEP], axis=0)

      procCurr = [[tokenize(t) for (f, t), m in zip(testSents, mask) if m],
                  test_fs[mask], ypred[mask]]

      X = genX(*procCurr, maxlenTest)

      heat = modelEnt.predict(X)

      # Drop the leading batch axis of size 1: one row of class scores per step.
      heat = heat.reshape(len(heat[0]), len(data_out))
      words = [t for tokens in procCurr[0] for t in tokens]

      relevantClusts = [2, 3, 5, 6, 7, 8, 9]
      skDict, embDict = genTags(words, heat)

      # Postprocessing the data
      NAME = "NM"
      ADDRESS = "AD"
      PHONE = "PH"
      MAIL = "MA"
      SKILLS = "SK"  # Embedding
      COMPANY = "CP"  # Embedding
      WORK_PERIOD = "WP"
      POSITION = "PO"  # Embedding
      LANGUAGES = "LG"  # Embedding

      UNI = "IN"
      FIELD = "FI"  # Embedding
      EDUCATION_PERIOD = "EP"

      def get_sentence_embedding(sentence_string):
          """Return the first row of the ELMo prediction for *sentence_string*."""
          return elmo.predict(gE([sentence_string]))[0]

      def getEmbeddingsForList(datas):
          """Pair every string in *datas* with its sentence embedding."""
          return [(data, get_sentence_embedding(data)) for data in datas]

      entities = {}
      entities[NAME] = " ".join(skDict[NAME])
      entities[ADDRESS] = " ".join(skDict[ADDRESS])
      entities[PHONE] = " ".join(skDict[PHONE])
      entities[MAIL] = "".join(skDict[MAIL])
      entities[SKILLS] = getEmbeddingsForList(skDict[SKILLS])
      entities[LANGUAGES] = getEmbeddingsForList(skDict[LANGUAGES])

      # Work experience: companies, positions and periods are paired by position
      # in their lists; missing entries become "UNKNOWN".
      companies = getEmbeddingsForList(skDict[COMPANY])
      workPeriods = skDict[WORK_PERIOD]
      positions = getEmbeddingsForList(skDict[POSITION])
      entities["WORK_EXPERIENCE"] = list(zip_longest(companies, positions, workPeriods, fillvalue="UNKNOWN"))

      unis, educationPeriods = skDict[UNI], skDict[EDUCATION_PERIOD]
      fields = getEmbeddingsForList(skDict[FIELD])
      entities["EDUCATION"] = list(zip_longest(unis, fields, educationPeriods, fillvalue="UNKNOWN"))

      curr = {
          "Name": fileName,
          "HTML": getTaggedHtml(path, clusts),
          "Texts": test_texts,
          "Embeds": test_embs,
          "EmbedsPos":
              [np.mean(test_sents[np.isin(clusts, relevantClusts)], axis=0)] +
              [np.mean(test_sents[clusts == rc], axis=0) for rc in relevantClusts],
          "Clusts": clusts,
          "Properties": skDict,
          "PropEmbs": embDict
      }

      print(entities)
      return curr
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement