SHARE
TWEET

Untitled

a guest Jul 24th, 2019 70 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import re
  2. import xml.etree.cElementTree as ET
  3. regex_float_pattern = r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?'
  4.  
  5. def build_tree(xgtree, base_xml_element, var_indices):
  6.     parent_element_dict = {'0':base_xml_element}
  7.     pos_dict = {'0':'s'}
  8.     for line in xgtree.split('\n'):
  9.         if not line: continue
  10.         if ':leaf=' in line:
  11.             #leaf node
  12.             result = re.match(r'(\t*)(\d+):leaf=({0})$'.format(regex_float_pattern), line)
  13.             if not result:
  14.                 print line
  15.             depth = result.group(1).count('\t')
  16.             inode = result.group(2)
  17.             res = result.group(3)
  18.             node_elementTree = ET.SubElement(parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
  19.                                              depth=str(depth), NCoef="0", IVar="-1", Cut="0.0e+00", cType="1", res=str(res), rms="0.0e+00", purity="0.0e+00", nType="-99")
  20.         else:
  21.             #\t\t3:[var_topcand_mass<138.19] yes=7,no=8,missing=7
  22.             result = re.match(r'(\t*)([0-9]+):\[(?P<var>.+)<(?P<cut>{0})\]\syes=(?P<yes>\d+),no=(?P<no>\d+)'.format(regex_float_pattern),line)
  23.             if not result:
  24.                 print line
  25.             depth = result.group(1).count('\t')
  26.             inode = result.group(2)
  27.             var = result.group('var')
  28.             cut = result.group('cut')
  29.             lnode = result.group('yes')
  30.             rnode = result.group('no')
  31.             pos_dict[lnode] = 'l'
  32.             pos_dict[rnode] = 'r'
  33.             node_elementTree = ET.SubElement(parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
  34.                                              depth=str(depth), NCoef="0", IVar=str(var_indices[var]), Cut=str(cut),
  35.                                              cType="1", res="0.0e+00", rms="0.0e+00", purity="0.0e+00", nType="0")
  36.             parent_element_dict[lnode] = node_elementTree
  37.             parent_element_dict[rnode] = node_elementTree
  38.            
  39. def convert_model(model, input_variables, output_xml):
  40.     NTrees = len(model)
  41.     var_list = input_variables
  42.     var_indices = {}
  43.    
  44.     # <MethodSetup>
  45.     MethodSetup = ET.Element("MethodSetup", Method="BDT::BDT")
  46.  
  47.     # <Variables>
  48.     Variables = ET.SubElement(MethodSetup, "Variables", NVar=str(len(var_list)))
  49.     for ind, val in enumerate(var_list):
  50.         name = val[0]
  51.         var_type = val[1]
  52.         var_indices[name] = ind
  53.         Variable = ET.SubElement(Variables, "Variable", VarIndex=str(ind), Type=val[1],
  54.             Expression=name, Label=name, Title=name, Unit="", Internal=name,
  55.             Min="0.0e+00", Max="0.0e+00")
  56.  
  57.     # <GeneralInfo>
  58.     GeneralInfo = ET.SubElement(MethodSetup, "GeneralInfo")
  59.     Info_Creator = ET.SubElement(GeneralInfo, "Info", name="Creator", value="xgboost2TMVA")
  60.     Info_AnalysisType = ET.SubElement(GeneralInfo, "Info", name="AnalysisType", value="Classification")
  61.  
  62.     # <Options>
  63.     Options = ET.SubElement(MethodSetup, "Options")
  64.     Option_NodePurityLimit = ET.SubElement(Options, "Option", name="NodePurityLimit", modified="No").text = "5.00e-01"
  65.     Option_BoostType = ET.SubElement(Options, "Option", name="BoostType", modified="Yes").text = "Grad"
  66.    
  67.     # <Weights>
  68.     Weights = ET.SubElement(MethodSetup, "Weights", NTrees=str(NTrees), AnalysisType="1")
  69.    
  70.     for itree in range(NTrees):
  71.         BinaryTree = ET.SubElement(Weights, "BinaryTree", type="DecisionTree", boostWeight="1.0e+00", itree=str(itree))
  72.         build_tree(model[itree], BinaryTree, var_indices)
  73.        
  74.     tree = ET.ElementTree(MethodSetup)
  75.     tree.write(output_xml)
  76.     # format it with 'xmllint --format'
  77.    
# Example usage:
#   bst = xgb.train(param, d_train, num_round, watchlist)
#   model = bst.get_dump()
#   convert_model(model, input_variables=[('var1', 'F'), ('var2', 'I')], output_xml='xgboost.xml')
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top