Advertisement
Fenny_Theo

csc1001 group project part2

May 15th, 2020
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 8.31 KB | None | 0 0
  1. #part1 input the data
  2. #data is a list containing all the 2039 training lines as dictionary
  3. import os
  4. from math import log
  5. import copy
  6. dataDir=""
  7. dataFile="D:\\A-Good-Place\\Python\\.vscode\\CSC1001\\src\\train.csv"
  8. header=[]
  9. dataSet=[]
  10. splitValue=[]
  11.  
  12. def parse_file(datafile):
  13. global header
  14. global dataSet
  15. data=[]
  16. with open(datafile,"r",encoding="UTF-8") as f:
  17. header=f.readline().strip("\n").split(",") #获取表头 header is a list
  18. header=[ea.strip() for ea in header]
  19. counter=0
  20. for line in f:
  21. #if counter==1119:
  22. # break
  23. fields=line.split(",")
  24. entry={}
  25. for i, value in enumerate(fields):
  26. entry[header[i].strip()]=value.strip()
  27. dataSet.append(entry)
  28. counter+=1
  29. return dataSet
  30.  
  31.  
  32. def changeCheckPoint(data):# change the quility of wines by compares with 6
  33. #global data
  34. for line in data:
  35. if float(line[header[-1].strip()])>6.0: #when using header, remember to add '.strip()'
  36. line[header[-1].strip()]=1
  37. else:
  38. line[header[-1].strip()]=0
  39. return data
  40.  
  41.  
  42. def featureGiniSplit(data):#get each feature's split value: (splitValue)list the parameter can used for subrigion like not in the root
  43. global splitValue
  44. # get the Ent(D)
  45. fail=0
  46. allCheck=[float(c[header[-1].strip()]) for c in data]
  47. for i in allCheck:
  48. if i==0:
  49. fail+=1
  50. success=len(allCheck)-fail
  51. p1=success/len(allCheck)
  52. p2=1-p1
  53. EntD=-(p1*log(p1)+p2*log(p2))
  54. #get feature's gini
  55. for feature in header[:11]: #except the last one, add it at last
  56. if feature.strip() in data[0].keys():###new add, check
  57. allPosibility=[]
  58. Ta=[]
  59. for i in range(len(data)):#add all the data of a feature into a list
  60. allPosibility.append(data[i][feature.strip()])
  61. all=set(allPosibility)
  62. all=[float(element) for element in all]#加float是应为不加则为string排序,除了大小还要考虑位数于是会从10.0分段
  63. all.sort()#make the entries of feature from small to big
  64. for j in range(len(all)-1):#get Ta
  65. a=(all[j]+all[j+1])/2
  66. Ta.append(a)
  67. tForMax=0
  68. EntDt=0
  69. for t in Ta: #get each gini for partition
  70. Ts=[]
  71. Tb=[]#list of element bigger than t
  72. for v in data:#get Ts/b s:small b:bigger
  73. if float(v[feature.strip()])<t:
  74. Ts.append(v)
  75. elif float(v[feature.strip()])>=t:
  76. Tb.append(v)
  77. #EntDbigger one
  78. pb=len(Tb)/len(allPosibility)
  79. fail2=0
  80. for q in Tb:
  81. if q[header[-1].strip()]==0:
  82. fail2+=1
  83. success2=len(Tb)-fail2
  84. p3=success2/len(Tb)
  85. p4=1-p3
  86. if p3==0:
  87. EntDb=-(p4*log(p4))
  88. if p4==0:
  89. EntDb=-(p3*log(p3))
  90. if p3!=0 and p4!=0:
  91. EntDb=-(p3*log(p3)+p4*log(p4))
  92.  
  93. #Entdsmaller one
  94. ps=len(Ts)/len(allPosibility)
  95. fail3=0
  96. for w in Ts:
  97. if w[header[-1].strip()]==0:
  98. fail3+=1
  99. success3=len(Ts)-fail3
  100. p5=success3/len(Ts)
  101. p6=1-p5
  102. if p5==0:
  103. EntDs=-(p6*log(p6))
  104. if p6==0:
  105. EntDs=-(p5*log(p5))
  106. if p5!=0 and p6!=0:
  107. EntDs=-(p5*log(p5)+p6*log(p6))
  108. #EntD of this partition
  109. if EntDt<EntD-EntDb*pb-EntDs*ps:
  110. EntDt=EntD-EntDb*pb-EntDs*ps
  111. tForMax=t
  112. splitValue.append({"value":tForMax, "Gini": EntDt, "feature": feature})
  113. return splitValue
  114.  
  115. def changeData(datachanged):#change the complex number into two type by check their gini choose the smallest
  116. for line in range(len(datachanged)):
  117. for i in range(len(datachanged[0])-1): #for the first 11 feature
  118. if header[i].strip() in datachanged[0].keys():
  119. if float(datachanged[line][header[i].strip()])>float(splitValue[i]["value"]):
  120. datachanged[line][header[i].strip()]=1
  121. else:
  122. datachanged[line][header[i].strip()]=0
  123. return datachanged
  124.  
  125.  
  126.  
  127. #Part2: build the tree
  128.  
  129. def splitDataSet(data,value):#去掉进行分类的那个属性
  130. counter=0
  131. for featVec in data:
  132. if featVec["quality"]==value:
  133. counter+=1
  134. return counter
  135. #reducedFeatVec=featVec.pop(axis)#去掉axis特征
  136. #retDataSet.append(reducedFeatVec)#len(reDataSet)=该feature中的这一类的个数
  137.  
  138.  
  139.  
  140. def cart_chooseTheBestFeatureToSplit(data):
  141. bestGini=9999
  142. bestFeature=-1
  143. bestLeft=None
  144. bestRight=None
  145. for i in header[:11]:
  146. gini=0
  147. left=[]
  148. right=[]
  149. if i.strip() in data[0].keys():
  150. for l in range(len(data)):
  151. if data[l][i.strip()]==0:
  152. left.append(data[l])
  153. elif data[l][i.strip()]==1:
  154. right.append(data[l])
  155. print(right)
  156. p7=len(left)/float(len(data)) #len(subDataSet)= 该feature中的这一类的个数 len(data)总个数(Dv/D)
  157. subp7=float(splitDataSet(left,0))/float(len(left)) #该feature某类下满足大于6的比例(Gini(Dv))
  158. p8=len(right)/float(len(data)) #len(subDataSet)= 该feature中的这一类的个数 len(data)总个数(Dv/D)
  159. subp8=float(splitDataSet(right,0))/float(len(right)) #该feature某类下满足大于6(0)的比例(Gini(Dv))
  160. gini=gini+p7*(1.0-subp7**2-(1-subp7)**2)+p8*(1.0-subp8**2-(1-subp8)**2)
  161. if(gini<bestGini):
  162. bestGini=gini
  163. bestFeature=i.strip()
  164. return bestFeature
  165.  
  166. #featList=[example[i.strip()] for example in data] #call the element in the dictionary
  167. #uniqueFeature=set(featList) #这个属性中包含的几种元素
  168. #gini=0
  169. #for value in featureList:
  170. # subDataSet=splitDataSet(data, i.strip(), value)#i: 第i个feature; value:i中的某一类
  171.  
  172. def majorityClass(classList):
  173. success=0
  174. fail=0
  175. for vote in classList:
  176. if vote==0:
  177. fail+=1
  178. if vote==1:
  179. success+=1
  180. if success>fail:
  181. return "Above 6"
  182. if success<=fail: #approximate
  183. return "Below or equal to 6"
  184.  
  185.  
  186. def buildTheTree(data):
  187. elementList1=[]
  188. elementList2=[]
  189. dataRemain=[]
  190. classList=[example[header[-1].strip()] for example in data]
  191. if classList.count(classList[0])==len(classList):#every element in the same class
  192. if classList[0]==0:
  193. return "Below or equal to 6"
  194. elif classList[0]==1:
  195. return "Above 6"
  196. if len(data[0])==1:
  197. return majorityClass(classList)
  198. for i in header[:11]:
  199. left=[]
  200. right=[]
  201. if i.strip() in data[0].keys():
  202. for l in range(len(data)):
  203. if data[l][i.strip()]==0:
  204. left.append(data[l])
  205. elif data[l][i.strip()]==1:
  206. right.append(data[l])
  207. if len(right)==0 or len(left)==0:
  208.  
  209. split=featureGiniSplit(data)
  210. dataRemain=copy.deepcopy(data)
  211. dataChanged=changeData(data)
  212. bestFeature=cart_chooseTheBestFeatureToSplit(dataChanged)
  213. print(dataRemain[0])
  214. for elem in range(len(dataRemain)):
  215. dataRemain[elem].pop(bestFeature)
  216. for ele in range(len(data)):
  217. if float(data[ele][bestFeature])==0.0:
  218. elementList1.append(dataRemain[ele])
  219. elif float(data[ele][bestFeature])==1.0:
  220. elementList2.append(dataRemain[ele])
  221. print(elementList1[0])
  222.  
  223. buildTheTree(elementList2)
  224. buildTheTree(elementList1)
  225.  
  226.  
  227.  
  228.  
  229.  
  230.  
  231.  
  232.  
  233. #main
  234. data=parse_file(dataFile)
  235. data=changeCheckPoint(data)
  236. buildTheTree(data)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement