Advertisement
andybuckley

Semi-clean TMVA training from Python

Apr 4th, 2013
159
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.61 KB | None | 0 0
  1. #! /usr/bin/env python
  2.  
  3. import sys, time
  4. from ROOT import TFile, TTree, TCut, TMVA
  5.  
  6. ## Files
  7. inputFile = TFile("tmva_reg_example.root")
  8. outputFile = TFile("out.root", 'RECREATE')
  9.  
  10. ## Create instance of TMVA factory
  11. factory = TMVA.Factory("TMVARegression", outputFile, "!V:!Silent:Color:DrawProgressBar")
  12. factory.SetVerbose(True)
  13.  
  14. ## Define the input variables that shall be used for the MVA training
  15. ## note that you may also use variable expressions, such as: "3*var1/var2*abs(var3)"
  16. ## [all types of expressions that can also be parsed by TTree::Draw( "expression" )]
  17. factory.AddVariable("var1", "Variable 1", "units", 'F')
  18. factory.AddVariable("var2", "Variable 2", "units", 'F')
  19.  
  20. ## You can add so-called "Spectator variables", which are not used in the MVA training,
  21. ## but will appear in the final "TestTree" produced by TMVA. This TestTree will contain the
  22. ## input variables, the response values of all trained MVAs, and the spectator variables
  23. #factory.AddSpectator("spec1:=var1*2", "Spectator 1", "units", 'F')
  24. #factory.AddSpectator("spec2:=var1*3", "Spectator 2", "units", 'F')
  25.  
  26. ## Add the variable carrying the regression target
  27. factory.AddTarget("fvalue")
  28. ## For MLP, can declare additional targets for multi-dimensional regression, e.g.
  29. # factory.AddTarget("fvalue2")
  30.  
  31. ## Set up regression tree
  32. # for k in inputFile.GetListOfKeys():
  33. #    print k
  34. regTree = inputFile.Get("TreeR")
  35. regWeight = 1.0
  36. factory.AddRegressionTree(regTree, regWeight);
  37. ## This would set individual event weights (the variables defined in the
  38. ## expression need to exist in the original TTree)
  39. factory.SetWeightExpression("var1", "Regression")
  40.  
  41. ## Prepare for regression with no cut
  42. ## Cut *could* be e.g. TCut("abs(var1)<0.5 && abs(var2-0.5)<1")
  43. ## Tell the factory to use all remaining events in the trees after training for testing:
  44. factory.PrepareTrainingAndTestTree(TCut(""), "nTrain_Regression=0:nTest_Regression=0:SplitMode=Random:NormMode=NumEvents:!V")
  45. ## If no numbers of events are given, half are used for training, and the other half for testing:
  46. # factory.PrepareTrainingAndTestTree(mycut, "SplitMode=random:!V")
  47.  
  48. ## Set up machine learning method
  49. #config = "!H:!V:VarTransform=Norm:NeuronType=tanh:NCycles=20000:HiddenLayers=N+20"
  50. config = "!H:!V:VarTransform=Norm:NeuronType=tanh:NCycles=20000:HiddenLayers=N+20,N+20"
  51. config += ":TestRate=6:TrainingMethod=BFGS:Sampling=0.3:SamplingEpoch=0.8"
  52. config += ":ConvergenceImprove=1e-6:ConvergenceTests=15:!UseRegulator"
  53. factory.BookMethod(TMVA.Types.kMLP, "MLP", config)
  54. # TODO: add CalculateError
  55. #factory.BookMethod( TMVA.Types.kMLP, "MLPBFGS", "H:!V:NeuronType=tanh:VarTransform=N:NCycles=600:HiddenLayers=N+5:TestRate=5:TrainingMethod=BFGS:!UseRegulator" )
  56. #factory.BookMethod( TMVA.Types.kMLP, "MLPBNN", "H:!V:NeuronType=tanh:VarTransform=N:NCycles=600:HiddenLayers=N+5:TestRate=5:TrainingMethod=BFGS:UseRegulator" ) # BFGS training with bayesian regulators
  57. #
  58. #factory.BookMethod(TMVA.Types.kSVM, "SVM", "Gamma=0.25:Tol=0.001:VarTransform=Norm")
  59. #factory.BookMethod(TMVA.Types.kBDT, "BDT", "!H:!V:NTrees=100:nEventsMin=5:BoostType=AdaBoostR2:SeparationType=RegressionVariance:nCuts=20:PruneMethod=CostComplexity:PruneStrength=30")
  60. #factory.BookMethod( TMVA.Types.kBDT, "BDTG", "!H:!V:NTrees=2000::BoostType=Grad:Shrinkage=0.1:UseBaggedGrad:GradBaggingFraction=0.5:nCuts=20:MaxDepth=3:NNodesMax=15")
  61.  
  62. ## Train MVAs using the set of training events
  63. factory.TrainAllMethods()
  64. ## Evaluate all MVAs using the set of test events
  65. factory.TestAllMethods()
  66. ## Evaluate and compare performance of all configured MVAs
  67. factory.EvaluateAllMethods()
  68.  
  69. ## Save the output
  70. outputFile.Close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement