Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Part 1: load the data
# `data` (the list returned by parse_file) holds all 2039 training rows, one dictionary per row
import copy
import csv
import os
from math import log
# Directory prefix for the data files (empty; not referenced by the visible code).
dataDir = ""
# Absolute path of the training CSV handed to parse_file by the main section.
dataFile = "D:\\A-Good-Place\\Python\\.vscode\\CSC1001\\src\\train.csv"
# Column names taken from the first CSV line; assigned by parse_file.
header = []
# Parsed rows, one dictionary per CSV line; filled by parse_file.
dataSet = []
# Per-feature split records produced by featureGiniSplit and read by changeData.
splitValue = []
def parse_file(datafile):
    """Load a comma-separated file into a list of row dictionaries.

    The first line is the header; its column names are stripped of
    surrounding whitespace and stored in the module-level ``header``.
    Every following non-blank line becomes a dict mapping column name to
    the stripped field text (values stay strings).

    The module-level ``dataSet`` is rebound to a fresh list on every call,
    so loading a file twice no longer duplicates rows.

    Args:
        datafile: Path of the CSV file to read.

    Returns:
        The list of row dictionaries (the same object as ``dataSet``).
    """
    global header
    global dataSet
    rows = []
    # "utf-8-sig" reads plain UTF-8 too, but drops a leading BOM that would
    # otherwise end up glued to the first column name.
    with open(datafile, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.reader(f)
        header = [name.strip() for name in next(reader, [])]
        for fields in reader:
            # Blank lines (e.g. a trailing newline) carry no data.
            if not any(field.strip() for field in fields):
                continue
            # zip() ignores surplus fields (such as one created by a trailing
            # comma) instead of failing on a header index that does not exist.
            rows.append(
                {name: value.strip() for name, value in zip(header, fields)}
            )
    dataSet = rows
    return dataSet
def changeCheckPoint(data, threshold=6.0):
    """Binarize the label column (the last header column) in place.

    Each row's label is replaced by ``1`` when its numeric value is strictly
    greater than ``threshold`` and by ``0`` otherwise.

    Args:
        data: List of row dictionaries keyed by the names in ``header``.
        threshold: Cut-off the label is compared with; defaults to 6.0,
            the value that was previously hard-coded.

    Returns:
        The same ``data`` list, with the label values rewritten.
    """
    label = header[-1].strip()  # resolved once; it is the same for every row
    for row in data:
        row[label] = 1 if float(row[label]) > threshold else 0
    return data
def _binary_entropy(positives, total):
    """Return the natural-log entropy of a two-class sample (0.0 if empty or pure)."""
    if total == 0 or positives == 0 or positives == total:
        return 0.0
    p = positives / total
    return -(p * log(p) + (1 - p) * log(1 - p))


def featureGiniSplit(data):
    """Find, for every feature present in ``data``, its best binary threshold.

    Candidate thresholds are the midpoints between consecutive distinct
    numeric values of the feature.  The threshold kept is the one with the
    largest entropy reduction (information gain) for the label column, which
    is the last column of ``header``; rows whose label is 0 form one class,
    all other rows the second.

    The module-level ``splitValue`` is rebound to a fresh list on every call
    so that it only describes the data passed in.

    Args:
        data: List of row dictionaries; feature values must be convertible
            to ``float``.

    Returns:
        ``splitValue``: one dict per examined feature with the keys
        ``"value"`` (chosen threshold, 0 if no candidate improves on the
        unsplit data), ``"Gini"`` (despite the key name, the information
        gain of that threshold) and ``"feature"`` (the column name).
        An empty list is returned for empty ``data``.
    """
    global splitValue
    splitValue = []
    if not data:
        return splitValue

    label = header[-1].strip()
    total = len(data)
    # True marks a row of the "non-zero label" class.
    positive = [float(row[label]) != 0 for row in data]
    base_entropy = _binary_entropy(sum(positive), total)

    # Every column except the label is a candidate feature.
    for feature in header[:-1]:
        name = feature.strip()
        if name not in data[0]:
            continue  # feature already removed from this subset

        values = [float(row[name]) for row in data]
        # De-duplicate after the float conversion so "1" and "1.0" count once.
        distinct = sorted(set(values))
        thresholds = [(lo + hi) / 2 for lo, hi in zip(distinct, distinct[1:])]

        best_threshold = 0
        best_gain = 0
        for t in thresholds:
            low_total = low_pos = high_total = high_pos = 0
            for value, is_pos in zip(values, positive):
                if value < t:
                    low_total += 1
                    low_pos += is_pos
                elif value >= t:
                    high_total += 1
                    high_pos += is_pos
            gain = (
                base_entropy
                - _binary_entropy(high_pos, high_total) * (high_total / total)
                - _binary_entropy(low_pos, low_total) * (low_total / total)
            )
            if best_gain < gain:
                best_gain = gain
                best_threshold = t
        splitValue.append(
            {"value": best_threshold, "Gini": best_gain, "feature": feature}
        )
    return splitValue
def changeData(datachanged):
    """Replace every feature value by 0/1 according to its split threshold.

    A value becomes ``1`` when it is strictly greater than the threshold
    recorded for that feature in the module-level ``splitValue`` and ``0``
    otherwise.  The label column (last entry of ``header``) is left alone.

    Thresholds are matched to features by name, not by list position, so the
    result stays correct after features have been removed from the rows.
    When ``splitValue`` holds several records for one feature the most
    recently appended one is used.

    Args:
        datachanged: List of row dictionaries; modified in place.

    Returns:
        The same ``datachanged`` list.
    """
    if not datachanged:
        return datachanged

    # Later records overwrite earlier ones, so the newest threshold wins.
    thresholds = {
        record["feature"].strip(): float(record["value"])
        for record in splitValue
    }
    features = [
        name
        for name in (column.strip() for column in header[:-1])
        if name in datachanged[0] and name in thresholds
    ]
    for row in datachanged:
        for name in features:
            row[name] = 1 if float(row[name]) > thresholds[name] else 0
    return datachanged
# Part 2: build the tree
def splitDataSet(data, value):
    """Count the rows of ``data`` whose label equals ``value``.

    The label column is the last entry of ``header``, as in the other
    functions of this file; ``"quality"`` is used when no header has been
    loaded yet.

    Args:
        data: List of row dictionaries.
        value: Label value to look for.

    Returns:
        The number of matching rows (0 for empty ``data``).
    """
    label = header[-1].strip() if header else "quality"
    return sum(1 for row in data if row[label] == value)
def cart_chooseTheBestFeatureToSplit(data):
    """Pick the 0/1-coded feature whose split has the lowest weighted Gini index.

    For each feature (every ``header`` column except the last, which is the
    label) the rows are divided into those coded ``0`` and those coded ``1``.
    The Gini impurity of the label inside each side is weighted by that
    side's share of ``len(data)`` and the two terms are added.  An empty side
    contributes nothing instead of causing a division by zero.  Features that
    are absent from the rows, or that have no 0/1-coded value at all, are
    skipped.

    Args:
        data: List of row dictionaries with 0/1 feature and label values.

    Returns:
        The name of the best feature (the first one wins on ties), or ``-1``
        when no feature can be evaluated.
    """
    if not data:
        return -1

    label = header[-1].strip()
    total = float(len(data))
    best_gini = 9999
    best_feature = -1

    for column in header[:-1]:
        name = column.strip()
        if name not in data[0]:
            continue
        left = [row for row in data if row[name] == 0]
        right = [row for row in data if row[name] == 1]
        if not left and not right:
            continue  # feature is not 0/1-coded in this data

        gini = 0.0
        for side in (left, right):
            if not side:
                continue
            zero_share = sum(1 for row in side if row[label] == 0) / float(len(side))
            impurity = 1.0 - zero_share ** 2 - (1 - zero_share) ** 2
            gini += (len(side) / total) * impurity

        if gini < best_gini:
            best_gini = gini
            best_feature = name
    return best_feature
def majorityClass(classList):
    """Return the label text of the more frequent class among 0/1 votes.

    Votes equal to 1 count for "Above 6", votes equal to 0 for
    "Below or equal to 6"; any other value is ignored.  A tie (including an
    empty list) resolves to "Below or equal to 6".
    """
    above = 0
    below = 0
    for vote in classList:
        if vote == 1:
            above += 1
        elif vote == 0:
            below += 1
    return "Above 6" if above > below else "Below or equal to 6"
def buildTheTree(data):
    """Recursively build a binary decision tree for the 0/1 label column.

    At each node the thresholds for the remaining features are recomputed
    with ``featureGiniSplit``, a copy of the rows is 0/1-coded with
    ``changeData``, and ``cart_chooseTheBestFeatureToSplit`` selects the
    feature to split on.  The children receive the uncoded rows without
    that feature.  The rows passed in are not modified.

    Args:
        data: Non-empty list of row dictionaries whose label (last column of
            ``header``) is already 0 or 1.

    Returns:
        A leaf string (``"Above 6"`` or ``"Below or equal to 6"``) or a node
        dict with the keys ``"feature"``, ``"threshold"``, ``"low"`` (subtree
        for rows coded 0) and ``"high"`` (subtree for rows coded 1).

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("cannot build a tree from an empty data set")

    label = header[-1].strip()
    classList = [example[label] for example in data]

    # Every row carries the same label: this is a leaf.
    if classList.count(classList[0]) == len(classList):
        if classList[0] == 0:
            return "Below or equal to 6"
        if classList[0] == 1:
            return "Above 6"
    # Only the label column is left: fall back to a majority vote.
    if len(data[0]) == 1:
        return majorityClass(classList)

    # One threshold computation per node is enough; the rows do not change
    # between candidate features.
    featureGiniSplit(data)
    coded = changeData(copy.deepcopy(data))
    bestFeature = cart_chooseTheBestFeatureToSplit(coded)
    if bestFeature == -1:
        return majorityClass(classList)

    low_rows = []
    high_rows = []
    for raw, flags in zip(data, coded):
        child = {key: val for key, val in raw.items() if key != bestFeature}
        side = float(flags[bestFeature])
        if side == 0.0:
            low_rows.append(child)
        elif side == 1.0:
            high_rows.append(child)

    # A split that leaves one side empty separates nothing; stop here rather
    # than recursing into an empty subset.
    if not low_rows or not high_rows:
        return majorityClass(classList)

    # Newest record for the chosen feature, if one was stored.
    threshold = next(
        (
            record["value"]
            for record in reversed(splitValue)
            if record["feature"].strip() == bestFeature
        ),
        None,
    )
    return {
        "feature": bestFeature,
        "threshold": threshold,
        "low": buildTheTree(low_rows),
        "high": buildTheTree(high_rows),
    }
# main: load the training file, binarize the label column, then build the tree
data = changeCheckPoint(parse_file(dataFile))
buildTheTree(data)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement