Advertisement
Guest User

Untitled

a guest
May 4th, 2015
253
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.98 KB | None | 0 0
  1. #! c:\python27\python
  2. #Usage:
  3. #Training: NB.py 1 TrainingDataFile ModelFile
  4. #Testing: NB.py 0 TestDataFile ModelFile OutFile
  5.  
  6. import sys
  7. import os
  8. import math
  9.  
  10.  
# Smoothing pseudo-count (Laplace-style additive smoothing) used in ComputeModel.
DefaultFreq = 0.1
# Default data/model paths; all four are overridden by sys.argv in the main block.
TrainingDataFile = "newdata.train"
ModelFile = "newdata.model"
TestDataFile = "newdata.test"
TestOutFile = "newdata.out"
ClassFeaDic = {}        # classid -> {wid: raw count of word wid in class classid}
ClassFreq = {}          # classid -> number of training instances of that class
WordDic = {}            # vocabulary: wid -> 1 (membership set)
ClassFeaProb = {}       # classid -> {wid: smoothed P(wid | classid)}
ClassDefaultProb = {}   # classid -> smoothed probability for words unseen in that class
ClassProb = {}          # classid -> prior P(classid)
  22.  
  23. def Dedup(items):
  24. tempDic = {}
  25. for item in items:
  26. if item not in tempDic:
  27. tempDic[item] = True
  28. return tempDic.keys()
  29.  
  30. def LoadData():
  31. i =0
  32. infile = file(TrainingDataFile, 'r')
  33. sline = infile.readline().strip()
  34. while len(sline) > 0:
  35. pos = sline.find("#")
  36. if pos > 0:
  37. sline = sline[:pos].strip()
  38. words = sline.split(' ')
  39. if len(words) < 1:
  40. print "Format error!"
  41. break
  42. classid = int(words[0])
  43. if classid not in ClassFeaDic:
  44. ClassFeaDic[classid] = {}
  45. ClassFeaProb[classid] = {}
  46. ClassFreq[classid] = 0
  47. ClassFreq[classid] += 1
  48. words = words[1:] #take the words in certain news except newsid
  49. #remove duplicate words, binary distribution
  50. #words = Dedup(words)
  51. for word in words:
  52. if len(word) < 1:
  53. continue
  54. wid = int(word)
  55. if wid not in WordDic:
  56. WordDic[wid] = 1
  57. if wid not in ClassFeaDic[classid]:
  58. ClassFeaDic[classid][wid] = 1
  59. else:
  60. ClassFeaDic[classid][wid] += 1
  61. i += 1
  62. sline = infile.readline().strip()
  63. infile.close()
  64. print i, "instances loaded!"
  65. print len(ClassFreq), "classes!", len(WordDic), "words!"
  66.  
  67.  
  68. def ComputeModel():
  69. sum = 0.0
  70. for freq in ClassFreq.values():
  71. sum += freq
  72. for classid in ClassFreq.keys():
  73. ClassProb[classid] = (float)(ClassFreq[classid])/(float)(sum)
  74. for classid in ClassFeaDic.keys():
  75. #Multinomial Distribution
  76. sum = 0.0
  77. for wid in ClassFeaDic[classid].keys():
  78. sum += ClassFeaDic[classid][wid]
  79. newsum = (float)(sum+len(WordDic)*DefaultFreq)
  80. #Binary Distribution
  81. #newsum = (float)(ClassFreq[classid]+2*DefaultFreq)
  82. for wid in ClassFeaDic[classid].keys():
  83. ClassFeaProb[classid][wid] = (float)(ClassFeaDic[classid][wid]+DefaultFreq)/newsum
  84. ClassDefaultProb[classid] = (float)(DefaultFreq) / newsum
  85. return
  86.  
  87.  
  88. def SaveModel():
  89. outfile = file(ModelFile, 'w')
  90. for classid in ClassFreq.keys():
  91. outfile.write(str(classid))
  92. outfile.write(' ')
  93. outfile.write(str(ClassProb[classid]))
  94. outfile.write(' ')
  95. outfile.write(str(ClassDefaultProb[classid]))
  96. outfile.write(' ' )
  97. outfile.write('\n')
  98. for classid in ClassFeaDic.keys():
  99. for wid in ClassFeaDic[classid].keys():
  100. outfile.write(str(wid)+' '+str(ClassFeaProb[classid][wid]))
  101. outfile.write(' ')
  102. outfile.write('\n')
  103. outfile.close()
  104.  
  105.  
  106. def LoadModel():
  107. global WordDic
  108. WordDic = {}
  109. global ClassFeaProb
  110. ClassFeaProb = {}
  111. global ClassDefaultProb
  112. ClassDefaultProb = {}
  113. global ClassProb
  114. ClassProb = {}
  115. infile = file(ModelFile, 'r')
  116. sline = infile.readline().strip()
  117. items = sline.split(' ')
  118. if len(items) < 6:
  119. print "Model format error!"
  120. return
  121. i = 0
  122. while i < len(items):
  123. classid = int(items[i])
  124. ClassFeaProb[classid] = {}
  125. i += 1
  126. if i >= len(items):
  127. print "Model format error!"
  128. return
  129. ClassProb[classid] = float(items[i])
  130. i += 1
  131. if i >= len(items):
  132. print "Model format error!"
  133. return
  134. ClassDefaultProb[classid] = float(items[i])
  135. i += 1
  136. for classid in ClassProb.keys():
  137. sline = infile.readline().strip()
  138. items = sline.split(' ')
  139. i = 0
  140. while i < len(items):
  141. wid = int(items[i])
  142. if wid not in WordDic:
  143. WordDic[wid] = 1
  144. i += 1
  145. if i >= len(items):
  146. print "Model format error!"
  147. return
  148. ClassFeaProb[classid][wid] = float(items[i])
  149. i += 1
  150. infile.close()
  151. print len(ClassProb), "classes!", len(WordDic), "words!"
  152.  
  153. def Predict():
  154. global WordDic
  155. global ClassFeaProb
  156. global ClassDefaultProb
  157. global ClassProb
  158.  
  159. TrueLabelList = []
  160. PredLabelList = []
  161. i =0
  162. infile = file(TestDataFile, 'r')
  163. outfile = file(TestOutFile, 'w')
  164. sline = infile.readline().strip()
  165. scoreDic = {}
  166. iline = 0
  167. while len(sline) > 0:
  168. iline += 1
  169. if iline % 10 == 0:
  170. print iline," lines finished!\r",
  171. pos = sline.find("#")
  172. if pos > 0:
  173. sline = sline[:pos].strip()
  174. words = sline.split(' ')
  175. if len(words) < 1:
  176. print "Format error!"
  177. break
  178. classid = int(words[0])
  179. TrueLabelList.append(classid)
  180. words = words[1:]
  181. #remove duplicate words, binary distribution
  182. #words = Dedup(words)
  183. for classid in ClassProb.keys():
  184. scoreDic[classid] = math.log(ClassProb[classid])
  185. for word in words:
  186. if len(word) < 1:
  187. continue
  188. wid = int(word)
  189. if wid not in WordDic:
  190. #print "OOV word:",wid
  191. continue
  192. for classid in ClassProb.keys():
  193. if wid not in ClassFeaProb[classid]:
  194. scoreDic[classid] += math.log(ClassDefaultProb[classid])
  195. else:
  196. scoreDic[classid] += math.log(ClassFeaProb[classid][wid])
  197. #binary distribution
  198. #wid = 1
  199. #while wid < len(WordDic)+1:
  200. # if str(wid) in words:
  201. # wid += 1
  202. # continue
  203. # for classid in ClassProb.keys():
  204. # if wid not in ClassFeaProb[classid]:
  205. # scoreDic[classid] += math.log(1-ClassDefaultProb[classid])
  206. # else:
  207. # scoreDic[classid] += math.log(1-ClassFeaProb[classid][wid])
  208. # wid += 1
  209. i += 1
  210. maxProb = max(scoreDic.values())
  211. for classid in scoreDic.keys():
  212. if scoreDic[classid] == maxProb:
  213. PredLabelList.append(classid)
  214. sline = infile.readline().strip()
  215. infile.close()
  216. outfile.close()
  217. print len(PredLabelList),len(TrueLabelList)
  218. return TrueLabelList,PredLabelList
  219.  
  220. def Evaluate(TrueList, PredList):
  221. accuracy = 0
  222. i = 0
  223. while i < len(TrueList):
  224. if TrueList[i] == PredList[i]:
  225. accuracy += 1
  226. i += 1
  227. accuracy = (float)(accuracy)/(float)(len(TrueList))
  228. print "Accuracy:",accuracy
  229.  
  230. def CalPreRec(TrueList,PredList,classid):
  231. correctNum = 0
  232. allNum = 0
  233. predNum = 0
  234. i = 0
  235. while i < len(TrueList):
  236. if TrueList[i] == classid:
  237. allNum += 1
  238. if PredList[i] == TrueList[i]:
  239. correctNum += 1
  240. if PredList[i] == classid:
  241. predNum += 1
  242. i += 1
  243. return (float)(correctNum)/(float)(predNum),(float)(correctNum)/(float)(allNum)
  244.  
  245. #main framework
  246.  
  247. if len(sys.argv) < 4:
  248. print "Usage incorrect!"
  249. elif sys.argv[1] == '1':
  250. print "start training:"
  251. TrainingDataFile = sys.argv[2]
  252. ModelFile = sys.argv[3]
  253. LoadData()
  254. ComputeModel()
  255. SaveModel()
  256. elif sys.argv[1] == '0':
  257. print "start testing:"
  258. TestDataFile = sys.argv[2]
  259. ModelFile = sys.argv[3]
  260. TestOutFile = sys.argv[4]
  261. LoadModel()
  262. TList,PList = Predict()
  263. i = 0
  264. outfile = file(TestOutFile, 'w')
  265. while i < len(TList):
  266. outfile.write(str(TList[i]))
  267. outfile.write(' ')
  268. outfile.write(str(PList[i]))
  269. outfile.write('\n')
  270. i += 1
  271. outfile.close()
  272. Evaluate(TList,PList)
  273. for classid in ClassProb.keys():
  274. pre,rec = CalPreRec(TList, PList,classid)
  275. print "Precision and recall for Class",classid,":",pre,rec
  276. else:
  277. print "Usage incorrect!"
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement