{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    ""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ROOT import TMVA, TFile, TTree, TCut, TString, TH1F, TCanvas, gStyle\n",
    "# Hide the statistics box on every histogram drawn below.\n",
    "gStyle.SetOptStat(0)\n",
    "%jsroot on\n",
    "from subprocess import call\n",
    "from os.path import isfile\n",
    "from array import array\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up TMVA and its Python method bridge before any method is booked.\n",
    "TMVA.Tools.Instance()\n",
    "TMVA.PyMethodBase.PyInitialize()\n",
    "\n",
    "# ROOT file handed to the factory; 'RECREATE' overwrites any previous TMVA.root.\n",
    "output = TFile.Open('TMVA.root', 'RECREATE')\n",
    "factory = TMVA.Factory('TMVARegression', output,\n",
    "                       '!V:!Silent:Color:!DrawProgressBar:Transformations=D,G:AnalysisType=Regression')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the example dataset only if it is not already in the working directory.\n",
    "if not isfile('tmva_reg_example.root'):\n",
    "    # -f: fail on HTTP errors instead of saving the error page as the dataset,\n",
    "    # -L: follow redirects. Without these a bad response would be cached on disk\n",
    "    # and the isfile() guard above would never retry the download.\n",
    "    status = call(['curl', '-f', '-L', '-O', 'http://root.cern.ch/files/tmva_reg_example.root'])\n",
    "    if status != 0:\n",
    "        raise RuntimeError('Download of tmva_reg_example.root failed (curl exit code %d)' % status)\n",
    "\n",
    "data = TFile.Open('tmva_reg_example.root')\n",
    "tree = data.Get('TreeR')\n",
    "\n",
    "dataloader = TMVA.DataLoader('dataset')\n",
    "# Every branch except 'fvalue' becomes an input variable; 'fvalue' is the target.\n",
    "for branch in tree.GetListOfBranches():\n",
    "    name = branch.GetName()\n",
    "    if name != 'fvalue':\n",
    "        dataloader.AddVariable(name)\n",
    "dataloader.AddTarget('fvalue')\n",
    "\n",
    "dataloader.AddRegressionTree(tree, 1.0)\n",
    "dataloader.PrepareTrainingAndTestTree(TCut(''),\n",
    "                                      'nTrain_Regression=4000:SplitMode=Random:NormMode=NumEvents:!V')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One hidden layer followed by a single linear output for the regression target.\n",
    "# NOTE(review): input_dim=2 assumes TreeR has exactly two non-target branches -- confirm.\n",
    "model = Sequential()\n",
    "model.add(Dense(64, activation='tanh', input_dim=2))\n",
    "model.add(Dense(1, activation='linear'))\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(loss='mean_squared_error', optimizer=\"sgd\")\n",
    "\n",
    "# The PyKeras method booked below reads the model from this file (FilenameModel).\n",
    "model.save(\"model.h5\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "factory.BookMethod(dataloader, TMVA.Types.kPyKeras, 'PyKeras',\n",
    "                   'H:!V:VarTransform=D,G:FilenameModel=model.h5:NumEpochs=20:BatchSize=32')\n",
    "# NOTE(review): the method is titled 'BDTG' but no BoostType is given, so TMVA's\n",
    "# default boosting is used -- confirm whether gradient boosting was intended.\n",
    "factory.BookMethod(dataloader, TMVA.Types.kBDT, 'BDTG',\n",
    "                   '!H:!V:NTrees=1000:nCuts=20:MaxDepth=3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "factory.TrainAllMethods()\n",
    "factory.TestAllMethods()\n",
    "factory.EvaluateAllMethods()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reader = TMVA.Reader(\"Color:!Silent\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One single-float buffer per branch, shared between the tree and the reader:\n",
    "# tree.GetEntry() fills the buffers and the reader evaluates on the same memory.\n",
    "branches = {}\n",
    "for branch in tree.GetListOfBranches():\n",
    "    branchName = branch.GetName()\n",
    "    branches[branchName] = array('f', [-999])\n",
    "    tree.SetBranchAddress(branchName, branches[branchName])\n",
    "    if branchName != 'fvalue':\n",
    "        reader.AddVariable(branchName, branches[branchName])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the weight files written by the training step above.\n",
    "reader.BookMVA('PyKeras', TString('dataset/weights/TMVARegression_PyKeras.weights.xml'))\n",
    "reader.BookMVA('BDTG', TString('dataset/weights/TMVARegression_BDTG.weights.xml'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = TCanvas()\n",
    "h_keras = TH1F(\"PyKeras\", \"PyKeras\", 30, 0, 400)\n",
    "h_bdtg = TH1F(\"BDTG\", \"BDTG\", 30, 0, 400)\n",
    "h_truth = TH1F(\"Ground truth\", \"Ground truth\", 30, 0, 400)\n",
    "for i in range(tree.GetEntries()):\n",
    "    tree.GetEntry(i)\n",
    "    # Regression methods are queried with EvaluateRegression (EvaluateMVA is the\n",
    "    # classification entry point); element 0 is the estimate for the single target.\n",
    "    h_keras.Fill(reader.EvaluateRegression(\"PyKeras\")[0])\n",
    "    h_bdtg.Fill(reader.EvaluateRegression(\"BDTG\")[0])\n",
    "    h_truth.Fill(branches[\"fvalue\"][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give each histogram its own line colour so the overlaid curves can be told apart.\n",
    "h_bdtg.SetLineColor(4)   # blue\n",
    "h_keras.SetLineColor(2)  # red\n",
    "h_truth.SetLineColor(1)  # black\n",
    "\n",
    "# The first histogram drawn defines the frame, so it carries the title and axis label.\n",
    "h_bdtg.SetTitle(\"BDTG vs PyKeras\")\n",
    "h_bdtg.GetXaxis().SetTitle(\"Target\")\n",
    "h_bdtg.Draw()\n",
    "h_keras.Draw(\"SAME\")\n",
    "h_truth.Draw(\"SAME\")\n",
    "c.Draw()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}