{ "cells": [ { "cell_type": "code", "execution_count": 196, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: using MLDataPattern.obsview in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.BatchView in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.RandomObs in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.batchsize in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.undersample in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.oversample in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.randobs in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.ObsView in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.RandomBatches in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.eachbatch in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.DataSubset in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.BufferGetObs in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.kfolds in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.splitobs in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.getobs! in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.batchview in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.shuffleobs in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.leaveout in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.datasubset in module Main conflicts with an existing identifier.\n", "WARNING: using MLDataPattern.eachobs in module Main conflicts with an existing identifier.\n" ] }, { "data": { "text/plain": [ "Plots.GRBackend()" ] }, "execution_count": 196, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using SwiftObjectStores\n", "using ColoringNames\n", "using Distributions\n", "using MLDataPattern\n", "using Iterators\n", "using MLLabelUtils\n", "using StaticArrays\n", "using Juno\n", "using StatsBase\n", "using Colors\n", "using DataFrames\n", "using Query\n", "using Plots\n", "gr()" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[1m\u001b[34mINFO: Recompiling stale cache file /home/ubuntu/.julia/lib/v0.5/MLLabelUtils.ji for module MLLabelUtils.\n", "\u001b[0m\u001b[1m\u001b[34mINFO: Recompiling stale cache file /home/ubuntu/.julia/lib/v0.5/MLDataPattern.ji for module MLDataPattern.\n", "\u001b[0mWARNING: Method definition getobs(Any) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:1 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:1.\n", "WARNING: Method definition nobs(Any) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:41 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:30.\n", "WARNING: Method definition #nobs(Array{Any, 1}, StatsBase.#nobs, Any) in module MLDataPattern overwritten in module MLDataUtils.\n", "WARNING: Method definition getobs(Any, Any) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:67 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:37.\n", "WARNING: Method definition getobs(Base.SubArray) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:92 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:385.\n", "WARNING: Method definition nobs(Tuple) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:173 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:445.\n", "WARNING: Method definition nobs(Tuple, Tuple) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:183 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:455.\n", "WARNING: Method definition getobs(Tuple, Any) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:198 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:462.\n", "WARNING: Method definition getobs(Tuple, Any, Tuple) in module MLDataPattern at /home/ubuntu/.julia/v0.5/MLDataPattern/src/container.jl:208 overwritten in module MLDataUtils at /home/ubuntu/.julia/v0.5/MLDataUtils/src/accesspattern/datasubset.jl:472.\n", "\u001b[1m\u001b[34mINFO: Recompiling stale cache file /home/ubuntu/.julia/lib/v0.5/SortingAlgorithms.ji for module SortingAlgorithms.\n", "\u001b[0m\u001b[1m\u001b[34mINFO: Recompiling stale cache file /home/ubuntu/.julia/lib/v0.5/DataFrames.ji for module DataFrames.\n", "\u001b[0mWARNING: Method definition describe(AbstractArray) in module StatsBase at /home/ubuntu/.julia/v0.5/StatsBase/src/scalarstats.jl:573 overwritten in module DataFrames at /home/ubuntu/.julia/v0.5/DataFrames/src/abstractdataframe/abstractdataframe.jl:407.\n", "WARNING: Method definition describe(AbstractArray) in module StatsBase at /home/ubuntu/.julia/v0.5/StatsBase/src/scalarstats.jl:573 overwritten in module DataFrames at /home/ubuntu/.julia/v0.5/DataFrames/src/abstractdataframe/abstractdataframe.jl:407.\n", "WARNING: Method definition require(Symbol) in module Base at loading.jl:345 overwritten in module Query at /home/ubuntu/.julia/v0.5/Requires/src/require.jl:12.\n" ] }, { "data": { "text/plain": [ "1523108-element Array{Any,1}:\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " \"acid green\"\n", " ⋮ \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" \n", " \"yuck\" " ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "\n", "const od =(MLDataUtils.ObsDim.First(), MLDataUtils.ObsDim.Last())\n", "const serv=SwiftService()\n", "\n", "const valid_raw = get_file(fh->readdlm(fh,'\\t'), serv, \"color\", \"monroe/dev.csv\")\n", "const valid_hsv, valid_terms_padded, encoding = prepare_data(valid_raw; do_demacate=false)\n", "const valid_text = valid_raw[:, 1]\n", "\n", "const train_raw = get_file(fh->readdlm(fh,'\\t'), serv, \"color\", \"monroe/train.csv\")\n", "const train_hsv, train_terms_padded, encoding = prepare_data(train_raw, encoding; do_demacate=false)\n", "const train_text = train_raw[:, 1]\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "name\n", "Summary Stats:\n", "Length: 829\n", "Type: String\n", "Number Unique: 829\n", "\n", "n_samples\n", "Summary Stats:\n", "Mean: 1837.283474\n", "Minimum: 70.000000\n", "1st Quartile: 109.000000\n", "Median: 214.000000\n", "3rd Quartile: 627.000000\n", "Maximum: 152953.000000\n", "Length: 829\n", "Type: Int64\n", "\n", "hs\n", "Summary Stats:\n", "Mean: 0.022951\n", "Minimum: -0.429653\n", "1st Quartile: -0.086345\n", "Median: 0.019934\n", "3rd Quartile: 0.121130\n", "Maximum: 0.577041\n", "Length: 829\n", "Type: Float64\n", "\n", "hv\n", "Summary Stats:\n", "Mean: 0.000171\n", "Minimum: -0.577392\n", "1st Quartile: -0.101860\n", "Median: 0.003642\n", "3rd Quartile: 0.103550\n", "Maximum: 0.497044\n", "Length: 829\n", "Type: Float64\n", "\n", "vs\n", "Summary Stats:\n", "Mean: -0.048629\n", "Minimum: -0.497048\n", "1st Quartile: -0.132881\n", "Median: -0.050129\n", "3rd Quartile: 0.031004\n", "Maximum: 0.455614\n", "Length: 829\n", "Type: Float64\n", "\n" ] } ], "source": [ "\n", "\n", "function pairwise_stats(fun, names, hsvs)\n", " dt = DataFrame(name=String[], n_samples=Int[], hs=Float64[], hv=Float64[], vs=Float64[])\n", " @progress for (name, inds) in labelmap(names)\n", " eg_hsvs = @view hsvs[inds, :]\n", "\n", " hs = fun(eg_hsvs[:,1], eg_hsvs[:,2])\n", " hv = fun(eg_hsvs[:,1], eg_hsvs[:,3])\n", " vs = fun(eg_hsvs[:,3], eg_hsvs[:,2])\n", "\n", " push!(dt, [name, length(inds), hs, hv, vs])\n", " end\n", " dt\n", "end\n", "spearman = pairwise_stats(corspearman, train_text, train_hsv)\n", "\n", "\n", "\n", "describe(spearman)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "train_cols = hsv2colorant(train_hsv)|>vec" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "pairwise_fields (generic function with 1 method)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function pairwise_fields(T)\n", " ns = fieldnames(T)\n", " pairwise_ns = Tuple{Symbol,Symbol}[]\n", " for (ii, n1) in enumerate(ns)\n", " for (jj, n2) in enumerate(ns)\n", " jj<=ii && continue\n", " push!(pairwise_ns, (n1,n2))\n", " end\n", " end\n", " pairwise_ns\n", "end" ] }, { "cell_type": "code", "execution_count": 235, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: Method definition pairwise_stats(Any, Any, Array{#T<:Any, 1}) in module Main at In[233]:2 overwritten at In[235]:2.\n" ] }, { "data": { "text/plain": [ "pairwise_stats (generic function with 2 methods)" ] }, "execution_count": 235, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function pairwise_stats{T}(fun, names, data::Vector{T})\n", " pairwise_ns = pairwise_fields(T)\n", " dt = DataFrame([String;Int;fill(Float64, 2*length(pairwise_ns))], \n", " [:name; :n_samples; \n", " [Symbol(n1,n2) for (n1,n2) in pairwise_ns];\n", " [Symbol(:abs_,n1,n2) for (n1,n2) in pairwise_ns]],\n", " 0)\n", " \n", " for (name, inds) in labelmap(names)\n", " sub = @view data[inds]\n", " row=Dict{Symbol, Any}(:name=>name, :n_samples=>length(inds))\n", " for (n1,n2) in pairwise_ns\n", " v1 = getfield.(sub, Scalar(n1))\n", " v2 = getfield.(sub, Scalar(n2))\n", " val = fun(v1, v2)\n", " row[Symbol(n1,n2)] = val\n", " row[Symbol(:abs_, n1,n2)] = abs(val)\n", " end\n", " push!(dt, row)\n", " end\n", " dt\n", "end\n" ] }, { "cell_type": "code", "execution_count": 263, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: Method definition Q1(Any) in module Main at In[262]:3 overwritten at In[263]:2.\n" ] }, { "data": { "text/plain": [ "Q3 (generic function with 1 method)" ] }, "execution_count": 263, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Q1(v)=quantile(v,0.25)\n", "Q3(v)=quantile(v,0.75)" ] }, { "cell_type": "code", "execution_count": 285, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING: Method definition summarise(Any, Any, Any) in module Main at In[280]:2 overwritten at In[285]:2.\n" ] }, { "data": { "text/html": [ "
name_stat | f12 | f13 | f23 | |
---|---|---|---|---|
1 | RGB_Q3 | 0.6030468173161735 | 0.4471926980645492 | 0.5655902318754856 |
2 | HSV_Q3 | 0.18613341836244784 | 0.18665944073636134 | 0.1627735443944819 |
3 | HSI_Q3 | 0.2446086788211584 | 0.23911783105162776 | 0.6301694579237653 |
4 | HSL_Q3 | 0.16551731146099502 | 0.2147074863173264 | 0.3112894699855042 |
5 | xyY_Q3 | 0.7230260376798017 | 0.5024120339688362 | 0.41654070660522274 |
6 | XYZ_Q3 | 0.9726011334766465 | 0.8166502753828205 | 0.7843590386564617 |
7 | xyY_Q3 | 0.7230260376798017 | 0.5024120339688362 | 0.41654070660522274 |
8 | Lab_Q3 | 0.5730180530926171 | 0.4597498889534288 | 0.6390289755602638 |
9 | Luv_Q3 | 0.5597935438719273 | 0.611177475128859 | 0.4379400128184171 |
10 | LCHab_Q3 | 0.5258344651603079 | 0.411030235162374 | 0.3687582570400751 |
11 | LCHuv_Q3 | 0.6123511957186991 | 0.4071940339464407 | 0.34158057511350925 |
12 | DIN99_Q3 | 0.5449474335188621 | 0.493092269212655 | 0.5235018403496664 |
13 | DIN99d_Q3 | 0.5442307692307692 | 0.4426073009880284 | 0.4802693716414654 |
14 | DIN99o_Q3 | 0.5608196488954583 | 0.40822381791085216 | 0.5211370440261909 |
15 | LMS_Q3 | 0.9680367104231742 | 0.7457534246575342 | 0.7789878283151825 |
16 | YIQ_Q3 | 0.4088309593708092 | 0.49752505716774 | 0.4064337083345767 |
17 | YCbCr_Q3 | 0.4004605183428265 | 0.43929499072356215 | 0.3377021554829844 |