Advertisement
Fenny_Theo

csc1001 group project part3

May 15th, 2020
70
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 10.69 KB | None | 0 0
  1. #part1 input the data
  2. #data is a list containing all the 2039 training lines as dictionary
  3. import os
  4. from math import log
  5. import copy
  6. dataDir=""
  7. dataFile="D:\\A-Good-Place\\Python\\.vscode\\CSC1001\\src\\train.csv"
  8. header=[]
  9. dataSet=[]
  10. splitValue=[]
  11.  
  12. class Node:
  13. def __init__(self,element,parent=None,left=None,right=None):
  14. self.element=element
  15. self.pareant=parent
  16. self.left=left
  17. self.right=right
  18.  
  19. class LBTree:
  20. def __init__(self):
  21. self.root=None
  22. self.size=0
  23. def __len__(self):
  24. return self.size
  25. def find_root(self):
  26. return self.root
  27. def parent(self,p):
  28. return p.parent
  29. def left(self,p):
  30. return p.left
  31. def right(self,p):
  32. return p.right
  33. def num_child(self,p):
  34. count=0
  35. if p.left is not None:
  36. count+=1
  37. if p.right is not None:
  38. count+=1
  39. return count
  40. def add_root(self,e):
  41. if self.root is not None:
  42. return None
  43. self.size=1
  44. self.root=Node(e)
  45. return self.root
  46. def add_left(self,p,e):
  47. self.size+=1
  48. p.left=Node(e,p)
  49. return p.left
  50. def add_right(self,p,e):
  51. self.size+=1
  52. p.right=Node(e,p)
  53. return p.right
  54. def replace(self,p,e):
  55. old=p.element
  56. p.element=e
  57. return old
  58. def delete(self,p):
  59. if p.parent.left is p:
  60. p.parent.left=None
  61. if p.parent.right is p:
  62. p.parent.right=None
  63. return p.parent
  64.  
  65.  
  66.  
  67.  
  68.  
  69.  
  70.  
  71. def parse_file(datafile):
  72. global header
  73. global dataSet
  74. data=[]
  75. with open(datafile,"r",encoding="UTF-8") as f:
  76. header=f.readline().strip("\n").split(",") #获取表头 header is a list
  77. header=[ea.strip() for ea in header]
  78. counter=0
  79. for line in f:
  80. #if counter==1119:
  81. # break
  82. fields=line.split(",")
  83. entry={}
  84. for i, value in enumerate(fields):
  85. entry[header[i].strip()]=value.strip()
  86. dataSet.append(entry)
  87. counter+=1
  88. return dataSet
  89.  
  90.  
  91. def changeCheckPoint(data):# change the quility of wines by compares with 6
  92. #global data
  93. for line in data:
  94. if float(line[header[-1].strip()])>6.0: #when using header, remember to add '.strip()'
  95. line[header[-1].strip()]=1
  96. else:
  97. line[header[-1].strip()]=0
  98. return data
  99.  
  100.  
  101. def featureGiniSplit(data):#get each feature's split value: (splitValue)list the parameter can used for subrigion like not in the root
  102. global splitValue
  103. # get the Ent(D)
  104. fail=0
  105. allCheck=[float(c[header[-1].strip()]) for c in data]
  106. for i in allCheck:
  107. if i==0:
  108. fail+=1
  109. success=len(allCheck)-fail
  110. p1=success/len(allCheck)
  111. p2=1-p1
  112. if p1==0:
  113. EntD=-(p2*log(p2))
  114. elif p2==0:
  115. EntD=-(p1*log(p1))
  116. elif p1!=0 and p2!=0:
  117. EntD=-(p1*log(p1)+p2*log(p2))
  118. #get feature's gini
  119. for feature in header[:11]: #except the last one, add it at last
  120. if feature.strip() in data[0].keys():###new add, check
  121. allPosibility=[]
  122. Ta=[]
  123. for i in range(len(data)):#add all the data of a feature into a list
  124. allPosibility.append(data[i][feature.strip()])
  125. all=set(allPosibility)
  126. all=[float(element) for element in all]#加float是应为不加则为string排序,除了大小还要考虑位数于是会从10.0分段
  127. all.sort()#make the entries of feature from small to big
  128. for j in range(len(all)-1):#get Ta
  129. a=(all[j]+all[j+1])/2
  130. Ta.append(a)
  131. tForMax=0
  132. EntDt=0
  133. for t in Ta: #get each gini for partition
  134. Ts=[]
  135. Tb=[]#list of element bigger than t
  136. for v in data:#get Ts/b s:small b:bigger
  137. if float(v[feature.strip()])<t:
  138. Ts.append(v)
  139. elif float(v[feature.strip()])>=t:
  140. Tb.append(v)
  141. #EntDbigger one
  142. pb=len(Tb)/len(allPosibility)
  143. fail2=0
  144. for q in Tb:
  145. if q[header[-1].strip()]==0:
  146. fail2+=1
  147. success2=len(Tb)-fail2
  148. p3=success2/len(Tb)
  149. p4=1-p3
  150. if p3==0:
  151. EntDb=-(p4*log(p4))
  152. if p4==0:
  153. EntDb=-(p3*log(p3))
  154. if p3!=0 and p4!=0:
  155. EntDb=-(p3*log(p3)+p4*log(p4))
  156.  
  157. #Entdsmaller one
  158. ps=len(Ts)/len(allPosibility)
  159. fail3=0
  160. for w in Ts:
  161. if w[header[-1].strip()]==0:
  162. fail3+=1
  163. success3=len(Ts)-fail3
  164. p5=success3/len(Ts)
  165. p6=1-p5
  166. if p5==0:
  167. EntDs=-(p6*log(p6))
  168. if p6==0:
  169. EntDs=-(p5*log(p5))
  170. if p5!=0 and p6!=0:
  171. EntDs=-(p5*log(p5)+p6*log(p6))
  172. #EntD of this partition
  173. if EntDt<EntD-EntDb*pb-EntDs*ps:
  174. EntDt=EntD-EntDb*pb-EntDs*ps
  175. tForMax=t
  176. splitValue.append({"value":tForMax, "Gini": EntDt, "feature": feature})
  177. return splitValue
  178.  
  179. def changeData(datachanged):#change the complex number into two type by check their gini choose the smallest
  180. for line in range(len(datachanged)):
  181. for i in range(11): #for the first 11 feature
  182. if header[i].strip() in datachanged[0].keys():
  183. if float(datachanged[line][header[i].strip()])>float(splitValue[i]["value"]):
  184. datachanged[line][header[i].strip()]=1
  185. else:
  186. datachanged[line][header[i].strip()]=0
  187. return datachanged
  188.  
  189.  
  190.  
  191. #Part2: build the tree
  192.  
  193. def splitDataSet(data,value):#去掉进行分类的那个属性
  194. counter=0
  195. for featVec in data:
  196. if featVec["quality"]==value:
  197. counter+=1
  198. return counter
  199.  
  200.  
  201. def cart_chooseTheBestFeatureToSplit(data):
  202. bestGini=9999
  203. bestFeature=-1
  204. bestLeft=None
  205. bestRight=None
  206. for i in header[:11]:
  207. gini=0
  208. left=[]
  209. right=[]
  210. if i.strip() in data[0].keys():
  211. for l in range(len(data)):
  212. if data[l][i.strip()]==0:
  213. left.append(data[l])
  214. elif data[l][i.strip()]==1:
  215. right.append(data[l])
  216. if len(right)==0 or len(left)==0:
  217. return
  218.  
  219. p7=len(left)/float(len(data)) #len(subDataSet)= 该feature中的这一类的个数 len(data)总个数(Dv/D)
  220. subp7=float(splitDataSet(left,0))/float(len(left)) #该feature某类下满足大于6的比例(Gini(Dv))
  221. p8=len(right)/float(len(data)) #len(subDataSet)= 该feature中的这一类的个数 len(data)总个数(Dv/D)
  222. subp8=float(splitDataSet(right,0))/float(len(right)) #该feature某类下满足大于6(0)的比例(Gini(Dv))
  223. gini=gini+p7*(1.0-subp7**2-(1-subp7)**2)+p8*(1.0-subp8**2-(1-subp8)**2)
  224. if(gini<bestGini):
  225. bestGini=gini
  226. bestFeature=i.strip()
  227. return bestFeature
  228.  
  229.  
  230. def majorityClass(classList):
  231. success=0
  232. fail=0
  233. for vote in classList:
  234. if vote==0:
  235. fail+=1
  236. if vote==1:
  237. success+=1
  238. if success>fail:
  239. return "Above 6"
  240. if success<=fail: #approximate
  241. return "Below or equal to 6"
  242.  
  243.  
  244. def buildTheTree(data,direction,parentNode):
  245. elementList1=[]
  246. elementList2=[]
  247. dataRemain=[]
  248. p=parentNode
  249. #Three situations that will stop
  250. classList=[example[header[-1].strip()] for example in data]
  251. if classList.count(classList[0])==len(classList):#every element in the same class
  252. if classList[0]==0:
  253. if direction=="right":
  254. t.add_right(p,"Below or equal to 6")
  255. return
  256. elif direction=="left":
  257. t.add_left(p,"Below or equal to 6")
  258. return
  259. elif classList[0]==1:
  260. if direction=="right":
  261. t.add_right(p,"Above 6")
  262. return
  263. elif direction=="left":
  264. t.add_left(p,"Above 6")
  265. return
  266.  
  267. if len(data[0])==1:
  268. labal=majorityClass(classList)
  269. if direction=="right":
  270. t.add_right(p,labal)
  271. return
  272. elif direction=="left":
  273. t.add_left(p,labal)
  274. return
  275.  
  276. split=featureGiniSplit(data)
  277. dataRemain=copy.deepcopy(data)
  278. dataChanged=changeData(data)
  279. bestFeature=cart_chooseTheBestFeatureToSplit(dataChanged)
  280.  
  281. for i in header[:11]:
  282. left=[]
  283. right=[]
  284. if i.strip() in dataChanged[0].keys():
  285. for l in range(len(dataChanged)):
  286. if dataChanged[l][i.strip()]==0:
  287. left.append(dataChanged[l])
  288. elif dataChanged[l][i.strip()]==1:
  289. right.append(dataChanged[l])
  290. if len(right)==0 or len(left)==0:
  291. labal= majorityClass(classList)
  292. if direction=="right":
  293. t.add_right(p,labal)
  294. return
  295. elif direction=="left":
  296. t.add_left(p,labal)
  297. return
  298.  
  299.  
  300. for fea in range(len(split)):####
  301. if bestFeature==split[fea]["feature"]:
  302. value=split[fea]["value"]
  303. for elem in range(len(dataRemain)):
  304. dataRemain[elem].pop(bestFeature)
  305. for ele in range(len(data)):
  306. if float(data[ele][bestFeature])==0.0:
  307. elementList1.append(dataRemain[ele])
  308. elif float(data[ele][bestFeature])==1.0:
  309. elementList2.append(dataRemain[ele])
  310. if direction=="root":
  311. t.add_root({"feature":bestFeature,"value":value})
  312. p=t.find_root()
  313. elif direction=="right":
  314. t.add_right(p,{"feature":bestFeature,"value":value})
  315. p=p.right
  316. elif direction=="left":
  317. t.add_left(p,{"feature":bestFeature,"value":value})
  318. p=p.left
  319.  
  320. buildTheTree(elementList2,"right",p)
  321. buildTheTree(elementList1,"left",p)
  322.  
  323.  
  324.  
  325.  
  326.  
  327.  
  328.  
  329.  
  330. #main
  331. data=parse_file(dataFile)
  332. data=changeCheckPoint(data)
  333. t=LBTree()
  334. buildTheTree(data,"root",None)
  335. #print(t.root.right.element)
  336. #print(t.root.left.element)
  337. #print(t.size)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement