{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# DNN Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Declare Factory" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ROOT import TMVA, TFile, TTree, TCut, TString" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "TMVA.Tools.Instance()\n", "\n", "inputFile = TFile.Open(\"https://raw.githubusercontent.com/iml-wg/tmvatutorials/master/inputdata.root\")\n", "outputFile = TFile.Open(\"TMVAOutputDNN.root\", \"RECREATE\")\n", "\n", "factory = TMVA.Factory(\"TMVAClassification\", outputFile,\n", " \"!V:!Silent:Color:!DrawProgressBar:AnalysisType=Classification\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Declare Variables in DataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "loader = TMVA.DataLoader(\"dataset_dnn\")\n", "\n", "loader.AddVariable(\"var1\")\n", "loader.AddVariable(\"var2\")\n", "loader.AddVariable(\"var3\")\n", "loader.AddVariable(\"var4\")\n", "loader.AddVariable(\"var5 := var1-var3\")\n", "loader.AddVariable(\"var6 := var1+var2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Dataset(s)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tsignal = inputFile.Get(\"Sig\")\n", "tbackground = inputFile.Get(\"Bkg\")\n", "\n", "loader.AddSignalTree(tsignal)\n", "loader.AddBackgroundTree(tbackground) \n", "loader.PrepareTrainingAndTestTree(TCut(\"\"),\n", " \"nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Configure Network Layout " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# General layout\n", "layoutString = TString(\"Layout=TANH|128,TANH|128,TANH|128,LINEAR\");\n", "\n", "# Training strategies\n", "training0 = TString(\"LearningRate=1e-1,Momentum=0.9,Repetitions=1,\"\n", " \"ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,\"\n", " \"WeightDecay=1e-4,Regularization=L2,\"\n", " \"DropConfig=0.0+0.5+0.5+0.5, Multithreading=True\")\n", "training1 = TString(\"LearningRate=1e-2,Momentum=0.9,Repetitions=1,\"\n", " \"ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,\"\n", " \"WeightDecay=1e-4,Regularization=L2,\"\n", " \"DropConfig=0.0+0.0+0.0+0.0, Multithreading=True\")\n", "trainingStrategyString = TString(\"TrainingStrategy=\")\n", "trainingStrategyString += training0 + TString(\"|\") + training1\n", "\n", "# General Options\n", "dnnOptions = TString(\"!H:!V:ErrorStrategy=CROSSENTROPY:VarTransform=N:\"\n", " \"WeightInitialization=XAVIERUNIFORM\")\n", "dnnOptions.Append(\":\")\n", "dnnOptions.Append(layoutString)\n", "dnnOptions.Append(\":\")\n", "dnnOptions.Append(trainingStrategyString)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Booking Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Standard implementation, no dependencies.\n", "stdOptions = dnnOptions + \":Architecture=CPU\"\n", "factory.BookMethod(loader, TMVA.Types.kDNN, \"DNN\", stdOptions)\n", "\n", "# CPU implementation, using BLAS\n", "#cpuOptions = dnnOptions + \":Architecture=CPU\"\n", "#factory.BookMethod(loader, TMVA.Types.kDNN, \"DNN CPU\", cpuOptions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "factory.TrainAllMethods()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test and Evaluate Methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "factory.TestAllMethods()\n", "factory.EvaluateAllMethods()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot ROC Curve\n", "We enable JavaScript visualisation for the plots" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%jsroot on" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "c = factory.GetROCCurve(loader)\n", "c.Draw()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 1 }