Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
# Part 1: read the input data.
# The parsed data is a list holding every training row (2039 lines) as a dictionary.
import copy
import csv
import os
from bisect import bisect_left
from itertools import accumulate
from math import log
# Not referenced anywhere in the visible code of this file.
dataDir=""
# NOTE(review): absolute, machine-specific path to the training CSV; adjust before running elsewhere.
dataFile="D:\\A-Good-Place\\Python\\.vscode\\CSC1001\\src\\train.csv"
# Column names from the CSV header row (whitespace-stripped); assigned by parse_file.
header=[]
# Training rows, one dict per line (column name -> value); assigned by parse_file.
dataSet=[]
# One {"value", "Gini", "feature"} dict is appended per feature on every featureGiniSplit call.
# The list is never cleared, so entries accumulate across the whole tree build.
splitValue=[]
class Node:
    """A node of a linked binary tree: an element plus parent/child links."""

    def __init__(self, element, parent=None, left=None, right=None):
        self.element = element
        # Must be named `parent`: LBTree.parent() and LBTree.delete() read p.parent.
        self.parent = parent
        self.left = left
        self.right = right


class LBTree:
    """A linked binary tree whose positions are Node objects.

    ``size`` is the number of nodes currently reachable from ``root``.
    """

    def __init__(self):
        self.root = None
        self.size = 0

    def __len__(self):
        """Return the number of nodes in the tree."""
        return self.size

    def find_root(self):
        """Return the root node, or None if the tree is empty."""
        return self.root

    def parent(self, p):
        """Return the parent of node ``p`` (None for the root)."""
        return p.parent

    def left(self, p):
        """Return the left child of ``p``, or None."""
        return p.left

    def right(self, p):
        """Return the right child of ``p``, or None."""
        return p.right

    def num_child(self, p):
        """Return how many of ``p``'s two child slots are occupied (0, 1 or 2)."""
        return sum(child is not None for child in (p.left, p.right))

    def add_root(self, e):
        """Create a root holding ``e`` and return it.

        Returns None (and changes nothing) if the tree already has a root.
        """
        if self.root is not None:
            return None
        self.size = 1
        self.root = Node(e)
        return self.root

    def add_left(self, p, e):
        """Create a left child of ``p`` holding ``e`` and return it.

        Raises:
            ValueError: if ``p`` already has a left child (overwriting it would
                silently drop that subtree and leave ``size`` wrong).
        """
        if p.left is not None:
            raise ValueError("left child already exists")
        self.size += 1
        p.left = Node(e, p)
        return p.left

    def add_right(self, p, e):
        """Create a right child of ``p`` holding ``e`` and return it.

        Raises:
            ValueError: if ``p`` already has a right child.
        """
        if p.right is not None:
            raise ValueError("right child already exists")
        self.size += 1
        p.right = Node(e, p)
        return p.right

    def replace(self, p, e):
        """Store ``e`` in node ``p`` and return the element it replaced."""
        old = p.element
        p.element = e
        return old

    def _subtree_size(self, p):
        """Return the number of nodes in the subtree rooted at ``p``."""
        count, stack = 0, [p]
        while stack:
            node = stack.pop()
            count += 1
            stack.extend(child for child in (node.left, node.right) if child is not None)
        return count

    def delete(self, p):
        """Detach ``p`` together with its whole subtree and return its former parent.

        Returns None when ``p`` was the root (the tree becomes empty).

        Raises:
            ValueError: if ``p`` is not linked into this tree.
        """
        parent = p.parent
        if parent is None:
            if p is not self.root:
                raise ValueError("node does not belong to this tree")
            self.root = None
        elif parent.left is p:
            parent.left = None
        elif parent.right is p:
            parent.right = None
        else:
            raise ValueError("node is not a child of its recorded parent")
        # Every node of the detached subtree leaves the tree, not just p itself.
        self.size -= self._subtree_size(p)
        p.parent = None
        return parent
def parse_file(datafile):
    """Load a comma-separated file into a fresh list of row dicts.

    The first line is the header. Its whitespace-stripped names are stored in
    the module-level ``header`` and become the keys of every row dict. Each
    field value is kept as a whitespace-stripped string. Blank or
    whitespace-only lines (such as a trailing newline) are skipped.

    The rows are stored in the module-level ``dataSet``. That list is replaced,
    not appended to, so calling this twice does not duplicate rows.

    Args:
        datafile: path of the CSV file.

    Returns:
        The list of row dicts (also available as ``dataSet``).

    Raises:
        ValueError: if a data row has a different number of fields than the header.
    """
    global header
    global dataSet
    rows = []
    # utf-8-sig drops a leading BOM (common in Excel exports) so the first
    # column name is not polluted; files without a BOM read exactly as UTF-8.
    with open(datafile, "r", encoding="utf-8-sig", newline="") as f:
        reader = csv.reader(f)
        header = [name.strip() for name in next(reader, [])]
        for fields in reader:
            if not any(field.strip() for field in fields):
                continue
            if len(fields) != len(header):
                raise ValueError(
                    f"{datafile}: line {reader.line_num} has {len(fields)} fields, "
                    f"expected {len(header)}"
                )
            rows.append({name: value.strip() for name, value in zip(header, fields)})
    dataSet = rows
    return dataSet
def changeCheckPoint(data, threshold=6.0, label_key=None):
    """Replace each row's quality label with a binary class, in place.

    The label becomes 1 when float(label) > threshold and 0 otherwise. With
    the default threshold this separates "above 6" from "6 or below".

    Args:
        data: list of row dicts.
        threshold: cut-off value; defaults to 6.0.
        label_key: column holding the label; defaults to the last column of
            the module-level ``header``.

    Returns:
        The same list, with its label values mutated.
    """
    if not data:
        return data
    if label_key is None:
        label_key = header[-1].strip()
    for row in data:
        row[label_key] = 1 if float(row[label_key]) > threshold else 0
    return data
def _binary_entropy(p):
    """Return the natural-log entropy of the two-class distribution (p, 1 - p)."""
    q = 1 - p
    if p == 0:
        return -(q * log(q))
    if q == 0:
        return -(p * log(p))
    return -(p * log(p) + q * log(q))


def featureGiniSplit(data):
    """Find the best binary threshold for every feature still present in ``data``.

    Despite the name, the score is information gain (natural-log entropy), not
    the Gini index. The result key stays "Gini" for compatibility.

    For each feature among header[:11] that is a key of data[0], the candidate
    thresholds are the midpoints between consecutive distinct numeric values.
    Rows split into value < t and value >= t. The threshold with the largest
    strictly positive gain is kept; the first one wins on ties, and the result
    is 0 with gain 0 if no threshold gains anything.

    A label counts as "fail" when float(label) == 0 and "success" otherwise.

    Args:
        data: list of row dicts with raw (numeric-string or number) feature
            values and a binarised label in the last header column.

    Returns:
        The module-level ``splitValue`` list, after appending one
        {"value", "Gini", "feature"} dict per examined feature. The list is
        never cleared; other code relies on the newest entries being last.
    """
    global splitValue
    if not data:
        return splitValue
    label_key = header[-1].strip()
    total = len(data)
    failed = [float(row[label_key]) == 0 for row in data]
    entropy_d = _binary_entropy((total - sum(failed)) / total)
    for feature in header[:11]:  # the label column is excluded
        key = feature.strip()
        if key not in data[0]:
            continue  # feature was already consumed higher up the tree
        # Sort (value, failed) pairs once. Each threshold is then answered by a
        # binary search plus a prefix sum, instead of rescanning every row.
        pairs = sorted(zip((float(row[key]) for row in data), failed))
        values = [value for value, _ in pairs]
        fail_prefix = [0]
        fail_prefix.extend(accumulate(int(is_fail) for _, is_fail in pairs))
        total_fail = fail_prefix[-1]
        # Deduplicate numerically, not as strings. "0.5" and "0.50" are the same
        # value and must not produce a degenerate threshold equal to a data value.
        distinct = sorted(set(values))
        best_threshold, best_gain = 0, 0
        for low, high in zip(distinct, distinct[1:]):
            threshold = (low + high) / 2
            n_small = bisect_left(values, threshold)  # rows with value < threshold
            n_big = total - n_small                   # rows with value >= threshold
            if n_small == 0 or n_big == 0:
                continue  # midpoint rounded onto a boundary: not a real split
            fail_small = fail_prefix[n_small]
            fail_big = total_fail - fail_small
            ent_big = _binary_entropy((n_big - fail_big) / n_big)
            ent_small = _binary_entropy((n_small - fail_small) / n_small)
            gain = entropy_d - ent_big * (n_big / total) - ent_small * (n_small / total)
            if best_gain < gain:
                best_gain = gain
                best_threshold = threshold
        splitValue.append({"value": best_threshold, "Gini": best_gain, "feature": feature})
    return splitValue
def changeData(datachanged):
    """Binarise the feature columns (header[:11]) of every row in place.

    Each present feature value becomes 1 if float(value) >= that feature's
    threshold, and 0 otherwise. This matches featureGiniSplit's partition
    (value < t versus value >= t).

    The threshold for a feature is its most recent entry in the module-level
    ``splitValue`` list, which is the one computed by the latest
    featureGiniSplit call. Looking it up by position instead would always pick
    the root's thresholds, because ``splitValue`` accumulates across calls.

    Args:
        datachanged: list of row dicts with raw feature values.

    Returns:
        The same list, mutated.

    Raises:
        KeyError: if a present feature has no entry in ``splitValue``.
    """
    if not datachanged:
        return datachanged
    # Later entries overwrite earlier ones, so each feature maps to its newest threshold.
    thresholds = {entry["feature"].strip(): float(entry["value"]) for entry in splitValue}
    present = [name.strip() for name in header[:11] if name.strip() in datachanged[0]]
    for row in datachanged:
        for key in present:
            row[key] = 1 if float(row[key]) >= thresholds[key] else 0
    return datachanged
# Part 2: build the tree
def splitDataSet(data, value, label_key="quality"):
    """Count the rows in ``data`` whose class label equals ``value``.

    Despite the name, nothing is split or removed. This is a counter that
    cart_chooseTheBestFeatureToSplit uses to compute per-class proportions.

    Args:
        data: iterable of row dicts.
        value: label value to count (0 or 1 once labels are binarised).
        label_key: column holding the label; defaults to "quality".

    Returns:
        The number of matching rows.
    """
    return sum(1 for row in data if row[label_key] == value)
def cart_chooseTheBestFeatureToSplit(data):
    """Pick the feature whose 0/1 split gives the lowest weighted Gini impurity.

    Expects feature values already binarised to 0/1 and binary labels. Only
    features among header[:11] that are still keys of the rows are considered.
    A feature that puts every row on the same side cannot split the data and
    is skipped, so the search continues with the remaining features.

    The impurity of a side with label-0 proportion p is 1 - p^2 - (1 - p)^2.
    Each side is weighted by its share of the rows. Ties keep the earlier
    feature.

    Args:
        data: list of binarised row dicts.

    Returns:
        The stripped name of the best feature, or -1 when no feature can split
        ``data`` (including when ``data`` is empty).
    """
    bestGini = float("inf")
    bestFeature = -1
    if not data:
        return bestFeature
    total = float(len(data))
    for name in header[:11]:
        key = name.strip()
        if key not in data[0]:
            continue
        left = [row for row in data if row[key] == 0]
        right = [row for row in data if row[key] == 1]
        if not left or not right:
            continue
        gini = 0.0
        for side in (left, right):
            p0 = splitDataSet(side, 0) / float(len(side))
            gini += len(side) / total * (1.0 - p0 ** 2 - (1 - p0) ** 2)
        if gini < bestGini:
            bestGini = gini
            bestFeature = key
    return bestFeature
def majorityClass(classList):
    """Return the leaf label for the majority class among ``classList`` votes.

    Votes equal to 1 count toward "Above 6" and votes equal to 0 toward
    "Below or equal to 6". Any other value is ignored. A tie, including an
    empty input, resolves to "Below or equal to 6".
    """
    votes = list(classList)  # materialise once so any iterable can be tallied twice
    above = sum(1 for vote in votes if vote == 1)
    below = sum(1 for vote in votes if vote == 0)
    return "Above 6" if above > below else "Below or equal to 6"
def _attach_to_tree(direction, parent, element):
    """Add ``element`` to the global tree ``t`` at the position named by ``direction``.

    ``direction`` is "root", "left" or "right". For "left"/"right" the new node
    becomes that child of ``parent``. Returns the node now at that position;
    for "root" with an existing root, that existing root is returned.

    Raises:
        ValueError: for any other ``direction``.
    """
    if direction == "root":
        t.add_root(element)
        return t.find_root()
    if direction == "left":
        return t.add_left(parent, element)
    if direction == "right":
        return t.add_right(parent, element)
    raise ValueError(f"unknown direction: {direction!r}")


def buildTheTree(data, direction, parentNode):
    """Recursively grow the decision tree for ``data`` into the global tree ``t``.

    A leaf (a label string) is created when any of these holds:
      * every row has the same class (0 or 1);
      * only the label column is left in the rows;
      * no feature separates the rows.
    Otherwise an internal node {"feature": name, "value": threshold} is added.
    Rows whose binarised feature is 1 go to the right subtree and rows with 0
    go to the left. The split feature is removed from the rows handed to the
    children, which keep their raw (un-binarised) values.

    Args:
        data: non-empty list of row dicts with a binary label (0/1) in the
            last header column. Mutated in place: features are binarised by
            changeData.
        direction: "root", "left" or "right", the position relative to ``parentNode``.
        parentNode: Node to attach under (None for the root).

    Raises:
        ValueError: if ``data`` is empty or ``direction`` is unknown.
    """
    if not data:
        raise ValueError("cannot build a tree node from an empty dataset")
    label_key = header[-1].strip()
    classList = [row[label_key] for row in data]

    # Stop 1: all rows share one binary class -> pure leaf.
    if classList.count(classList[0]) == len(classList):
        if classList[0] == 0:
            _attach_to_tree(direction, parentNode, "Below or equal to 6")
            return
        if classList[0] == 1:
            _attach_to_tree(direction, parentNode, "Above 6")
            return

    # Stop 2: only the label column is left -> majority leaf.
    if len(data[0]) == 1:
        _attach_to_tree(direction, parentNode, majorityClass(classList))
        return

    split = featureGiniSplit(data)
    # Copy before changeData binarises `data` in place: the children need raw
    # values so they can compute their own thresholds.
    dataRemain = copy.deepcopy(data)
    dataChanged = changeData(data)
    bestFeature = cart_chooseTheBestFeatureToSplit(dataChanged)

    # Stop 3: no usable feature (the chooser returned a non-feature sentinel,
    # or the chosen feature leaves one side empty) -> majority leaf.
    if bestFeature not in dataChanged[0]:
        _attach_to_tree(direction, parentNode, majorityClass(classList))
        return
    leftRows, rightRows = [], []
    for raw, binarised in zip(dataRemain, dataChanged):
        raw.pop(bestFeature)
        side = float(binarised[bestFeature])
        if side == 0.0:
            leftRows.append(raw)
        elif side == 1.0:
            rightRows.append(raw)
    if not leftRows or not rightRows:
        _attach_to_tree(direction, parentNode, majorityClass(classList))
        return

    # The last matching entry is the threshold computed for this very subset.
    value = None
    for entry in split:
        if entry["feature"] == bestFeature:
            value = entry["value"]

    node = _attach_to_tree(direction, parentNode, {"feature": bestFeature, "value": value})
    buildTheTree(rightRows, "right", node)
    buildTheTree(leftRows, "left", node)
# Script entry: load the training CSV, binarise the quality label, then grow
# the decision tree into the module-level tree `t` (buildTheTree writes to it).
data = changeCheckPoint(parse_file(dataFile))
t = LBTree()
buildTheTree(data, "root", None)
# Debug helpers:
# print(t.root.right.element)
# print(t.root.left.element)
# print(t.size)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement