{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Soft clustering of audio using LDA and bag-of-audio words\n", "\n", "If you have a massive amount of audio that you would like to categorize, you might be tempted to cluster it. However, an audio sample might contain more than one kind of sound, so assigning it to a single cluster might not be warranted. Latent Dirichlet Allocation (LDA) is an unsupervised \"clustering\" algorithm, often applied to text, that describes a datapoint (called a \"document\") as a sparse combination of \"basis documents\", usually called \"topics\". Each topic, in turn, is a sparse combination of \"words\". In the audio context, we do not have a well-defined notion of a word, but we can invent one. In this notebook, an audio word is a frequency component, so an \"audio document\" is made up of \"frequency terms\". This allows us to run LDA on a set of audio samples. The dataset I use is rather large, so I will not upload it to GitHub; I can, however, share it with you in case you do not have a suitable dataset yourself. It consists of 3000 short (~0.6s) audio clips of bird calls.\n", "\n", "We start by defining a function to load the data."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "get_audio_words (generic function with 2 methods)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "using Glob, WAV, Serialization, TextAnalysis, SparseArrays, Peaks, AMD\n",
 "using Grep\n",
 "using DSP, LPVSpectral\n",
 "using AudioClustering, SpectralDistances\n",
 "using Statistics, LinearAlgebra, Plots # std/mean, diag/norm and heatmap/plot are used further down\n",
 "\n",
 "const fs = 44100 # sample rate\n",
 "const minpeak = 1e-6 # the minimum power to consider in a spectrum\n",
 "\n",
 "function get_audio_words(path, load_from_disc=true)\n",
 "\n",
 "    cd(path)\n",
 "    files = glob(\"*.wav\")\n",
 "    # The class name is the first run of lowercase letters/underscores in each file name\n",
 "    labels0 = [String(match(r\"[a-z_]+\", f).match) for f in files]\n",
 "    ulabels = unique(labels0)\n",
 "    # Map each class name to its integer index in ulabels\n",
 "    labels = sum((labels0 .== reshape(ulabels,1,:)) .* (1:length(ulabels))', dims=2)[:]\n",
 "    M = mel(fs, 512)\n",
 "    if !load_from_disc # process and save all files\n",
 "        models = mapsoundfiles(files) do sound\n",
 "            # Trim the zero padding at both ends of the clip\n",
 "            sound = vec(sound[findfirst(!iszero, sound):findlast(!iszero, sound)])\n",
 "            sound = SpectralDistances.bp_filter(sound, (50/fs, 18000/fs))\n",
 "            P = welch_pgram(sound, 512, fs=fs)\n",
 "            spectral_fingerprint(M,P)\n",
 "        end;\n",
 "        serialize(\"audiowords\", models)\n",
 "    else\n",
 "        models = deserialize(\"audiowords\")\n",
 "    end\n",
 "\n",
 "    labels,models\n",
 "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The \"words\" we'll use will be the frequencies of peaks in the spectrum of the audio clip. The function below extracts those peaks, their locations and heights."
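,
 "\n",
 "\n",
 "To make the idea concrete first, here is a minimal sketch on a made-up spectrum. It uses only base Julia; the helper `toy_peaks` and the numbers in `spectrum` are hypothetical and only illustrate \"keep the `l` tallest local maxima\". The function in the next cell ranks peaks by prominence (via `Peaks`) rather than by plain height.\n",
 "\n",
 "```julia\n",
 "# Hypothetical helper: indices and heights of the l tallest local maxima\n",
 "function toy_peaks(power, l)\n",
 "    # interior indices that are strictly larger than both neighbours\n",
 "    inds = [i for i in 2:length(power)-1 if power[i-1] < power[i] > power[i+1]]\n",
 "    order = sortperm(power[inds], rev=true) # tallest first\n",
 "    keep = inds[order[1:min(l, length(inds))]]\n",
 "    (keep, power[keep])\n",
 "end\n",
 "\n",
 "spectrum = [0.1, 0.9, 0.2, 0.4, 0.3, 2.0, 0.5] # made-up Mel-band powers\n",
 "toy_words, toy_heights = toy_peaks(spectrum, 2)\n",
 "```\n",
 "\n",
 "Here `toy_words` holds the band indices of the two tallest peaks (bands 6 and 2); these indices play the role of the \"terms\" in an audio document."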
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "spectral_fingerprint (generic function with 2 methods)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "function spectral_fingerprint(M,P, l = 8)\n",
 "    power = M*P.power # Transform to a Mel spectrum\n",
 "    power ./= std(power) # Normalize the scale of the spectrum\n",
 "    p,prom = peakprom(power, Maxima(), 7, minpeak) # Peak locations and their prominences\n",
 "    perm = sortperm(prom, rev=true) # Most prominent peaks first\n",
 "\n",
 "    p = p[perm[1:min(l, length(p))]] # Keep at most l peaks\n",
 "    (p, power[p]) # Peak locations (Mel-band indices) and peak heights\n",
 "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we are ready to load the data and calculate the features." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
 "path = \"/home/fredrikb/kokulbirds/test_padded_30birds/\"\n",
 "labels,words = get_audio_words(path,true);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to run LDA. Apart from the number of iterations, the algorithm takes three parameters: the number of topics and two parameters related to the sparsity of the topics and \"documents\". The function `lda` expects a special data structure, a `DocumentTermMatrix`, as input, so we create that first." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "n_terms = 128 # The number of terms in the dictionary is the number of mel-frequencies in the spectra\n",
 "n_docs = length(words)\n",
 "terms = collect(1:n_terms) # We enumerate the terms starting from 1\n",
 "inner_dtm = spzeros(Int, n_docs, n_terms)\n",
 "\n",
 "for di in 1:n_docs # Populate the document term matrix\n",
 "    for (ti, tc) in zip(words[di]...) # ti: peak location (term index), tc: peak height\n",
 "        inner_dtm[di,ti] = round(Int, log10.(tc/minpeak)) # The log-scaled peak height acts as the word count\n",
 "    end\n",
 "end\n",
 "\n",
 "dtm = DocumentTermMatrix(inner_dtm,terms)\n",
 "\n",
 "ntopics = 60 # We choose 60 topics, even though we know there are 30 classes in the \"corpus\"\n",
 "iteration = 1000\n",
 "α = 0.1/ntopics # Dirichlet dist. 
hyperparameter for topic distribution per document. `α<1` yields a sparse topic mixture for each document. `α>1` yields a more uniform topic mixture for each document.\n",
 "β = 0.01 # Dirichlet dist. hyperparameter for word distribution per topic. `β<1` yields a sparse word mixture for each topic. `β>1` yields a more uniform word mixture for each topic.\n",
 "\n",
 "ϕ, θ = lda(dtm, ntopics, iteration, α, β);" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Plot{Plots.GRBackend() n=1}\n", "[Figure: heatmap of θ, x-axis \"Audio sample\" (1 to 3000), y-axis \"Topic\" (1 to 60), color scale 0 to 1.0]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "heatmap(θ, xlabel=\"Audio sample\", ylabel=\"Topic\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [
 "Running the algorithm takes a while; it should finish in a minute or so. The result is two matrices:\n",
 "- `ϕ`: `ntopics × nwords` sparse matrix of probabilities, where each topic is a probability distribution over the words\n",
 "- `θ`: `ntopics × ndocs` dense matrix of probabilities, where each column is the topic distribution of one document, so `sum(θ, dims=1)` is approximately 1 for every column\n",
 "\n",
 "The heatmap of `θ` above indicates which topics are ascribed to each \"document\" (audio sample). You should be able to see clusters of bright spots in this image; this is because the audio samples were sorted by class. In a real application, where the samples are not sorted, this picture would look very scrambled, and the analysis below would be needed to find similar samples.\n",
 "\n",
 "Next, we investigate the correlation between different topics based on both which words they contain and which audio samples contain the topics."
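,
 "\n",
 "\n",
 "The correlation used here is a cosine similarity: the Gram matrix of the topic vectors is divided by the norms of those vectors. A minimal sketch on a made-up matrix checks this reading; the matrix `A` is hypothetical and merely stands in for `ϕ` or `θ`, with one row per topic.\n",
 "\n",
 "```julia\n",
 "using LinearAlgebra\n",
 "\n",
 "A = [1.0 0.0 1.0;\n",
 "     0.0 2.0 2.0] # made-up stand-in for ϕ or θ, one row per topic\n",
 "G = A * A' # Gram matrix of the rows\n",
 "R = G ./ sqrt.(diag(G) * diag(G)') # divide entry (i,j) by sqrt(G[i,i]*G[j,j])\n",
 "\n",
 "# entry (1,2) equals the cosine of the angle between the two rows\n",
 "cosine = dot(A[1,:], A[2,:]) / (norm(A[1,:]) * norm(A[2,:]))\n",
 "R[1,2] ≈ cosine\n",
 "```\n",
 "\n",
 "The diagonal of `R` is 1 by construction, so only the off-diagonal entries carry information about how similar two topics are."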
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Plot{Plots.GRBackend() n=2}\n", "[Figure: two heatmaps with topics 1 to 60 on both axes, titled \"Correlation based on sample similarities\" and \"Correlation based on word similarities\", color scale 0 to 1.0]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "function corr(x::AbstractMatrix)\n",
 "    d = diag(x)\n",
 "    y = copy(x)\n",
 "    for i = 1:length(d), j = 1:length(d)\n",
 "        y[i,j] /= sqrt(x[i,i]*x[j,j])\n",
 "    end\n",
 "    y\n",
 "end\n",
 "topic_cov_by_words = Matrix(ϕ*ϕ')\n",
 "topic_cov_by_sample = Matrix(θ*θ')\n",
 "topic_corr_by_words = corr(topic_cov_by_words)\n",
 "topic_corr_by_sample = 
corr(topic_cov_by_sample)\n",
 "\n",
 "plot(\n",
 "    heatmap(topic_corr_by_sample, title=\"Correlation based on sample similarities\", titlefont=10),\n",
 "    heatmap(topic_corr_by_words, title=\"Correlation based on word similarities\", titlefont=10)\n",
 ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Bright off-diagonal spots in the images above indicate that two topics have something in common.\n", "\n", "To find an average topic vector for each class, we average over all the samples in the class. We can do this since we have class labels associated with the samples. In practice, without labels, this might be much harder." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Plot{Plots.GRBackend() n=1}\n", "[Figure: heatmap of the class-average topic vectors, x-axis \"Topic\" (1 to 60), y-axis \"Class\" (1 to 30), color scale 0 to 0.9]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "ulabels = unique(labels)\n",
 "classvecs = map(ulabels) do label\n",
 "    classinds = findall(labels .== label)\n",
 "    vec(mean(θ[:,classinds], dims=2))\n",
 "end\n",
 "classvecs = reduce(hcat,classvecs) # n_topics × n_class\n",
 "heatmap(classvecs',yticks=(1:length(ulabels), ulabels), xlabel=\"Topic\", ylabel=\"Class\", size=(400,400), color=:blues)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also try to find similarities between classes based on which topics occur in samples within the class. To make the similarities easier to spot, we run an approximate minimum degree (AMD) ordering algorithm on the similarity matrix, so that similar classes are likely to be placed next to each other." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Plot{Plots.GRBackend() n=1}\n", "[Figure: heatmap titled \"Class similarity\", both axes list the classes in the order 21, 12, 11, 5, 1, 25, 16, 7, 8, 6, 29, 13, 28, 4, 3, 9, 10, 14, 17, 18, 19, 20, 22, 23, 24, 26, 27, 30, 2, 15, color scale 0.25 to 1.75]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "function diagonalize(C, tol=:auto; permute_y=false, doplot=false)\n",
 "    C = copy(C)\n",
 "    amdmat = size(C,1) == size(C,2) ? 
copy(C) : C'C\n",
 "    if tol === :auto # Find tolerance automatically (adhoc)\n",
 "        y = abs.(amdmat[:])\n",
 "        inds = sortperm(y)\n",
 "        cum = cumsum(y[inds])\n",
 "        cum ./= cum[end]\n",
 "        i = findfirst(c->c > 0.3, cum) # The smallest entries that together make up 30% of the total are dropped\n",
 "        tol = y[inds[i]]\n",
 "    end\n",
 "    doplot && (histogram(abs.(amdmat[:])) |> display)\n",
 "    amdmat[abs.(amdmat) .< tol] .= 0 # Sparsify before computing the ordering\n",
 "    perm = amd(sparse(amdmat))\n",
 "    yperm = permute_y ? perm : 1:size(C,1)\n",
 "    C[yperm,perm], perm, yperm\n",
 "end\n",
 "function plotcov(C, xvector, yvector; kwargs...)\n",
 "    xticks = (1:length(xvector), xvector)\n",
 "    yticks = (1:length(yvector), yvector)\n",
 "    heatmap(C; xticks=xticks, yticks=yticks, xrotation=50, title=\"Class similarity\", kwargs...)\n",
 "end\n",
 "\n",
 "classvecs_unit = mapslices(v-> v./norm(v), classvecs, dims=1) # Unit-length topic vector per class\n",
 "classcov = classvecs_unit'topic_corr_by_sample*classvecs_unit\n",
 "\n",
 "C, perm, yperm = diagonalize(classcov, permute_y=true)\n",
 "plotcov(C,ulabels[perm],ulabels[yperm], yflip=true)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The figure indicates that classes 21 and 12 are similar, and that class 3 is similar to many other classes. Class 15 is quite distinct from all the others." ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.4.0-rc2", "language": "julia", "name": "julia-1.4" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.4.0" } }, "nbformat": 4, "nbformat_minor": 2 }