{ "cells": [ { "cell_type": "markdown", "source": [ "# Binary classification\n", "\n", "Binary classification of handwritten digits\n", "in Julia." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Setup\n", "\n", "Add the Julia packages used in this demo.\n", "Change `false` to `true` in the following code block\n", "if you are using any of the following packages for the first time." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if false\n", " import Pkg\n", " Pkg.add([\n", " \"InteractiveUtils\"\n", " \"LaTeXStrings\"\n", " \"LinearAlgebra\"\n", " \"MIRTjim\"\n", " \"MLDatasets\"\n", " \"Plots\"\n", " \"Random\"\n", " \"StatsBase\"\n", " ])\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Tell Julia to use the following packages.\n", "Run `Pkg.add()` in the preceding code block first, if needed." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "using InteractiveUtils: versioninfo\n", "using LaTeXStrings # nice plot labels\n", "using LinearAlgebra: dot\n", "using MIRTjim: jim, prompt\n", "using MLDatasets: MNIST\n", "using Plots: default, gui, savefig\n", "using Plots: histogram, histogram!, plot\n", "using Plots: RGB, cgrad\n", "using Plots.PlotMeasures: px\n", "using Random: seed!, randperm\n", "using StatsBase: mean\n", "default(); default(markersize=5, markerstrokecolor=:auto, label=\"\",\n", " tickfontsize=14, labelfontsize=18, legendfontsize=18, titlefontsize=18)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The following line is helpful when running this file as a script;\n", "this way it will prompt the user to hit a key after each figure is displayed." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "isinteractive() ? jim(:prompt, true) : prompt(:draw);" ], "metadata": {}, "execution_count": null },
{ "cell_type": "markdown", "source": [ "## Load data\n", "\n", "Read the MNIST data for some handwritten digits.\n", "This code will automatically download the data from the web if needed\n", "and store it in a folder like: `~/.julia/datadeps/MNIST/`." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "if !@isdefined(data)\n", " digitn = (0, 1) # which digits to use\n", " isinteractive() || (ENV[\"DATADEPS_ALWAYS_ACCEPT\"] = true) # avoid prompt\n", " dataset = MNIST(Float32, :train)\n", " nrep = 100 # how many of each digit\n", " # function to extract the 1st `nrep` examples of digit n:\n", " data = n -> dataset.features[:,:,findall(==(n), dataset.targets)[1:nrep]]\n", " data = cat(dims=4, data.(digitn)...)\n", " labels = vcat([fill(d, nrep) for d in digitn]...) # to check later\n", " nx, ny, nrep, ndigit = size(data)\n", " data = data[:,2:ny,:,:] # drop y=1 so images are non-square, to help catch nx/ny mix-ups\n", " ny = size(data,2)\n", " data = reshape(data, nx, ny, :)\n", " seed!(0)\n", " tmp = randperm(nrep * ndigit)\n", " data = data[:,:,tmp]\n", " labels = labels[tmp]\n", " size(data) # (nx, ny, nrep*ndigit)\n", "end" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Look at the \"unlabeled\" image data:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "pd = jim(data, \"Data\"; size=(600,300), tickfontsize=8,)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Separate the training data by class:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "data0 = data[:,:,labels .== 0]\n", "data1 = data[:,:,labels .== 1];\n", "\n", "pd0 = jim(data0[:,:,1:36]; nrow=6, colorbar=nothing, size=(400,400))\n", "pd1 = jim(data1[:,:,1:36]; nrow=6, colorbar=nothing, size=(400,400))\n", "# savefig(pd0, \"class01-0.pdf\")\n", "# savefig(pd1, \"class01-1.pdf\")" ], "metadata": {}, "execution_count": null },
{ "cell_type": "markdown", "source": [ "Define a red-black-blue colormap for displaying the signed weights:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "RGB255(args...) = RGB((args ./ 255)...)\n", "color = cgrad([RGB255(230, 80, 65), :black, RGB255(23, 120, 232)]);" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "## Weights\n", "Compute the sample average of each training class\n", "and define the classifier weights as the difference of the two means." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "μ0 = mean(data0, dims=3)\n", "μ1 = mean(data1, dims=3)\n", "w = μ1 - μ0; # hand-crafted weights" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Images of the class means and the weights:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "siz = (540,400)\n", "args = (xaxis = false, yaxis = false) # book\n", "p0 = jim(μ0; clim=(0,1), size=siz, cticks=[0,1], args...)\n", "p1 = jim(μ1; clim=(0,1), size=siz, cticks=[0,1], args...)\n", "pw = jim(w; color, clim=(-1,1).*0.8, size=siz, cticks=(-1:1)*0.75, args...)\n", "pm = plot( p0, p1, pw;\n", " size = (1400, 350),\n", " layout = (1,3),\n", " rightmargin = 20px,\n", ")\n", "# savefig(p0, \"class01-0.pdf\")\n", "# savefig(p1, \"class01-1.pdf\")\n", "# savefig(pw, \"class01-w.pdf\")\n", "# savefig(pm, \"class01-mean.pdf\")" ], "metadata": {}, "execution_count": null },
{ "cell_type": "markdown", "source": [ "## Inner products\n", "Examine how well this simple linear classifier separates the two classes\n", "by plotting histograms of the inner products ⟨w, x⟩ for each class.\n", "(Strictly, this evaluation should use test data, not the training data used to compute the weights.)" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "i0 = [dot(w, x) for x in eachslice(data0, dims=3)]\n", "i1 = [dot(w, x) for x in eachslice(data1, dims=3)];\n", "\n", "bins = -80:20\n", "ph = plot(\n", " xaxis = (L\"⟨\\mathbf{\\mathit{v}},\\mathbf{\\mathit{x}}⟩\", (-80, 20), -80:20:20),\n", " yaxis = (\"\", (0, 25), 0:10:20),\n", " size = (600, 250), bottommargin = 20px,\n", ")\n", "histogram!(i0; bins, color=:red, label=\"0\")\n", "histogram!(i1; bins, color=:blue, label=\"1\")\n", "\n", "# savefig(ph, \"class01-hist.pdf\")" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "prompt()" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "---\n", "\n", "*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*" ], "metadata": {} } ], "nbformat_minor": 3, "metadata": { "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.11.1" }, "kernelspec": { "name": "julia-1.11", "display_name": "Julia 1.11.1", "language": "julia" } }, "nbformat": 4 }