Advertisement
Guest User

Untitled

a guest
Jul 24th, 2019
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.71 KB | None | 0 0
  1. import re
  2. import xml.etree.cElementTree as ET
  3. regex_float_pattern = r'[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?'
  4.  
  5. def build_tree(xgtree, base_xml_element, var_indices):
  6. parent_element_dict = {'0':base_xml_element}
  7. pos_dict = {'0':'s'}
  8. for line in xgtree.split('\n'):
  9. if not line: continue
  10. if ':leaf=' in line:
  11. #leaf node
  12. result = re.match(r'(\t*)(\d+):leaf=({0})$'.format(regex_float_pattern), line)
  13. if not result:
  14. print line
  15. depth = result.group(1).count('\t')
  16. inode = result.group(2)
  17. res = result.group(3)
  18. node_elementTree = ET.SubElement(parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
  19. depth=str(depth), NCoef="0", IVar="-1", Cut="0.0e+00", cType="1", res=str(res), rms="0.0e+00", purity="0.0e+00", nType="-99")
  20. else:
  21. #\t\t3:[var_topcand_mass<138.19] yes=7,no=8,missing=7
  22. result = re.match(r'(\t*)([0-9]+):\[(?P<var>.+)<(?P<cut>{0})\]\syes=(?P<yes>\d+),no=(?P<no>\d+)'.format(regex_float_pattern),line)
  23. if not result:
  24. print line
  25. depth = result.group(1).count('\t')
  26. inode = result.group(2)
  27. var = result.group('var')
  28. cut = result.group('cut')
  29. lnode = result.group('yes')
  30. rnode = result.group('no')
  31. pos_dict[lnode] = 'l'
  32. pos_dict[rnode] = 'r'
  33. node_elementTree = ET.SubElement(parent_element_dict[inode], "Node", pos=str(pos_dict[inode]),
  34. depth=str(depth), NCoef="0", IVar=str(var_indices[var]), Cut=str(cut),
  35. cType="1", res="0.0e+00", rms="0.0e+00", purity="0.0e+00", nType="0")
  36. parent_element_dict[lnode] = node_elementTree
  37. parent_element_dict[rnode] = node_elementTree
  38.  
  39. def convert_model(model, input_variables, output_xml):
  40. NTrees = len(model)
  41. var_list = input_variables
  42. var_indices = {}
  43.  
  44. # <MethodSetup>
  45. MethodSetup = ET.Element("MethodSetup", Method="BDT::BDT")
  46.  
  47. # <Variables>
  48. Variables = ET.SubElement(MethodSetup, "Variables", NVar=str(len(var_list)))
  49. for ind, val in enumerate(var_list):
  50. name = val[0]
  51. var_type = val[1]
  52. var_indices[name] = ind
  53. Variable = ET.SubElement(Variables, "Variable", VarIndex=str(ind), Type=val[1],
  54. Expression=name, Label=name, Title=name, Unit="", Internal=name,
  55. Min="0.0e+00", Max="0.0e+00")
  56.  
  57. # <GeneralInfo>
  58. GeneralInfo = ET.SubElement(MethodSetup, "GeneralInfo")
  59. Info_Creator = ET.SubElement(GeneralInfo, "Info", name="Creator", value="xgboost2TMVA")
  60. Info_AnalysisType = ET.SubElement(GeneralInfo, "Info", name="AnalysisType", value="Classification")
  61.  
  62. # <Options>
  63. Options = ET.SubElement(MethodSetup, "Options")
  64. Option_NodePurityLimit = ET.SubElement(Options, "Option", name="NodePurityLimit", modified="No").text = "5.00e-01"
  65. Option_BoostType = ET.SubElement(Options, "Option", name="BoostType", modified="Yes").text = "Grad"
  66.  
  67. # <Weights>
  68. Weights = ET.SubElement(MethodSetup, "Weights", NTrees=str(NTrees), AnalysisType="1")
  69.  
  70. for itree in range(NTrees):
  71. BinaryTree = ET.SubElement(Weights, "BinaryTree", type="DecisionTree", boostWeight="1.0e+00", itree=str(itree))
  72. build_tree(model[itree], BinaryTree, var_indices)
  73.  
  74. tree = ET.ElementTree(MethodSetup)
  75. tree.write(output_xml)
  76. # format it with 'xmllint --format'
  77.  
  78. # example
  79. # bst = xgb.train( param, d_train, num_round, watchlist );
  80. # model = bst.get_dump()
  81. # convert_model(model,input_variables=[('var1','F'),('var2','I')],output_xml='xgboost.xml')
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement