Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
import re

# One nesting level in the generated source.  A leading tab in a dump line
# (one tab per level of tree depth) is rendered as one of these.
_INDENT = "    "

# Split node, e.g. "\t2:[f1<2.5] yes=3,no=4,missing=3".  Anything after the
# "missing=" field (such as optional statistics) is ignored.
_SPLIT_RE = re.compile(
    r"^(?P<tabs>\t*)(?P<node>\d+):"
    r"\[(?P<feature>.+)<(?P<threshold>[^\]<]+)\]"
    r"\s+yes=(?P<yes>\d+),no=(?P<no>\d+),missing=(?P<missing>\d+)"
)

# Leaf node, e.g. "\t1:leaf=-0.25" or "0:leaf=1e-05".  The value runs up to
# the next comma or whitespace so exponents and signs are kept intact.
_LEAF_RE = re.compile(r"^(?P<tabs>\t*)(?P<node>\d+):leaf=(?P<value>[^,\s]+)")


def string_parser(s):
    """Translate one line of a tree dump into Python source text.

    Two line shapes are understood (the format is assumed to be the text
    dump returned by ``model.get_dump()``):

    * split node: ``<tabs><id>:[<feature><<threshold>] yes=<id>,no=<id>,missing=<id>``
    * leaf node:  ``<tabs><id>:leaf=<value>``

    The emitted snippet is an ``if state == <id>:`` guard followed by either
    a ``state = (...)`` transition (split) or a ``return <value>`` (leaf).
    It is indented two levels (to sit inside ``xgb_tree`` and its
    ``if num_booster == ...`` branch) plus one level per leading tab.

    Args:
        s: A single dump line, without its trailing newline.

    Returns:
        The generated source text, ending with a newline.

    Raises:
        ValueError: If the line matches neither shape.
    """
    leaf = _LEAF_RE.match(s)
    if leaf:
        pad = _INDENT * (2 + len(leaf.group("tabs")))
        return "{pad}if state == {node}:\n{pad}{ind}return {value}\n".format(
            pad=pad,
            ind=_INDENT,
            node=leaf.group("node"),
            value=leaf.group("value"),
        )

    split = _SPLIT_RE.match(s)
    if split is None:
        raise ValueError("unrecognised tree dump line: {!r}".format(s))

    pad = _INDENT * (2 + len(split.group("tabs")))
    # repr() yields 'name' for ordinary feature names and stays a valid
    # string literal when the name itself contains a quote or backslash.
    feature = repr(split.group("feature"))

    # A NaN never satisfies "<", so missing values fall through to the "no"
    # branch on their own.  Only when the dump routes missing values to the
    # "yes" branch does the condition need an explicit NaN test.
    if split.group("yes") == split.group("missing"):
        nan_clause = " or np.isnan(x[" + feature + "]) "
    else:
        nan_clause = ""

    return (
        "{pad}if state == {node}:\n"
        "{pad}{ind}state = ({yes} if x[{feature}]<{threshold}{nan_clause}"
        " else {no})\n"
    ).format(
        pad=pad,
        ind=_INDENT,
        node=split.group("node"),
        yes=split.group("yes"),
        no=split.group("no"),
        feature=feature,
        threshold=split.group("threshold"),
        nan_clause=nan_clause,
    )


def tree_parser(tree, i):
    """Translate one dumped tree into a branch of the generated ``xgb_tree``.

    Args:
        tree: The multi-line text dump of a single tree.
        i: Index of the tree; tree 0 opens the chain with ``if``, every
            other tree continues it with ``elif``.

    Returns:
        Source text for ``if/elif num_booster == i:`` followed by the
        translated nodes, indented to sit inside the ``xgb_tree`` function.
    """
    keyword = "if" if i == 0 else "elif"
    header = "{ind}{kw} num_booster == {i}:\n{ind}{ind}state = 0\n".format(
        ind=_INDENT, kw=keyword, i=i
    )
    # Blank lines are skipped rather than assuming the dump ends with exactly
    # one newline, so the last node is kept when no trailing newline exists.
    body = "".join(
        string_parser(line) for line in tree.split("\n") if line.strip()
    )
    return header + body


def model_to_py(base_score, model, out_file):
    """Write a standalone Python scorer for ``model`` to ``out_file``.

    The generated module defines ``xgb_tree(x, num_booster)``, which walks
    one tree, and ``xgb_predict(x)``, which starts from ``base_score`` and
    adds the output of every tree.  ``x`` is indexed by feature name.

    Args:
        base_score: Initial prediction value; written via ``str()``.
        model: Any object whose ``get_dump()`` returns a sequence of
            per-tree text dumps.
        out_file: Path of the file to write.  The file is opened in append
            mode, so existing content is kept and the new code is added
            after it.
    """
    trees = model.get_dump()

    parts = ["import numpy as np\n\n", "def xgb_tree(x, num_booster):\n"]
    parts.extend(tree_parser(tree, i) for i, tree in enumerate(trees))
    if not trees:
        # Keep the generated function syntactically valid for an empty model.
        parts.append(_INDENT + "pass\n")

    parts.append(
        (
            "\ndef xgb_predict(x):\n"
            "{ind}# initialize prediction with base score\n"
            "{ind}predict = {base}\n"
            "{ind}for i in range({count}):\n"
            "{ind}{ind}predict = predict + xgb_tree(x, i)\n"
            "{ind}return predict\n"
        ).format(ind=_INDENT, base=base_score, count=len(trees))
    )

    with open(out_file, "a") as the_file:
        the_file.write("".join(parts))
Add Comment
Please, Sign In to add comment