{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Online Convolutional Dictionary Learning with Spatial Mask\n", "==========================================================\n", "\n", "This example demonstrates the use of [dictlrn.onlinecdl.OnlineConvBPDNMaskDictLearn](http://sporco.rtfd.org/en/latest/modules/sporco.dictlrn.onlinecdl.html#sporco.dictlrn.onlinecdl.OnlineConvBPDNMaskDictLearn) for learning a convolutional dictionary from a set of training images. The dictionary is learned using the online dictionary learning algorithm proposed in [[33]](http://sporco.rtfd.org/en/latest/zreferences.html#id33)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import print_function\n", "from builtins import input\n", "\n", "import pyfftw # See https://github.com/pyFFTW/pyFFTW/issues/40\n", "import numpy as np\n", "\n", "from sporco.dictlrn import onlinecdl\n", "from sporco import util\n", "from sporco import signal\n", "from sporco import cuda\n", "from sporco import plot\n", "plot.config_notebook_plotting()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load training images." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "exim = util.ExampleImages(scaled=True, zoom=0.5, gray=True)\n", "S1 = exim.image('barbara.png', idxexp=np.s_[10:522, 100:612])\n", "S2 = exim.image('kodim23.png', idxexp=np.s_[:, 60:572])\n", "S3 = exim.image('monarch.png', idxexp=np.s_[:, 160:672])\n", "S4 = exim.image('sail.png', idxexp=np.s_[:, 210:722])\n", "S5 = exim.image('tulips.png', idxexp=np.s_[:, 30:542])\n", "S = np.dstack((S1, S2, S3, S4, S5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Highpass filter training images." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [], "source": [ "npd = 16\n", "fltlmbd = 5\n", "sl, sh = signal.tikhonov_filter(S, fltlmbd, npd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create random mask and apply to highpass filtered training image set." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(12345)\n", "frc = 0.25\n", "W = signal.rndmask(S.shape, frc, dtype=np.float32)\n", "shw = W * sh" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Construct initial dictionary." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [], "source": [ "D0 = np.random.randn(8, 8, 32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set regularization parameter and options for dictionary learning solver." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [], "source": [ "lmbda = 0.1\n", "opt = onlinecdl.OnlineConvBPDNMaskDictLearn.Options({\n", " 'Verbose': True, 'ZeroMean': False, 'eta_a': 10.0,\n", " 'eta_b': 20.0, 'DataType': np.float32,\n", " 'CBPDN': {'rho': 3.0, 'AutoRho': {'Enabled': False},\n", " 'RelaxParam': 1.8, 'RelStopTol': 1e-4, 'MaxMainIter': 100,\n", " 'FastSolve': False, 'DataType': np.float32}})\n", "if cuda.device_count() > 0:\n", " opt['CUDA_CBPDN'] = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create solver object and solve." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Itn X r X s X ρ D cnstr D dlt D η \n", "----------------------------------------------------------------\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 0 0.00e+00 0.00e+00 0.00e+00 3.58e+01 2.12e+00 5.00e-01\n", " 1 0.00e+00 0.00e+00 0.00e+00 2.82e+01 1.48e+00 4.76e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 2 0.00e+00 0.00e+00 0.00e+00 2.52e+01 9.82e-01 4.55e-01\n", " 3 0.00e+00 0.00e+00 0.00e+00 2.39e+01 6.81e-01 4.35e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 4 0.00e+00 0.00e+00 0.00e+00 1.10e+01 7.68e-01 4.17e-01\n", " 5 0.00e+00 0.00e+00 0.00e+00 2.45e+01 9.28e-01 4.00e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 6 0.00e+00 0.00e+00 0.00e+00 2.29e+01 1.44e+00 3.85e-01\n", " 7 0.00e+00 0.00e+00 0.00e+00 9.67e+00 6.27e-01 3.70e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 8 0.00e+00 0.00e+00 0.00e+00 1.95e+01 5.49e-01 3.57e-01\n", " 9 0.00e+00 0.00e+00 0.00e+00 1.89e+01 3.76e-01 3.45e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 10 0.00e+00 0.00e+00 0.00e+00 8.65e+00 4.74e-01 3.33e-01\n", " 11 0.00e+00 0.00e+00 0.00e+00 1.97e+01 7.32e-01 3.23e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 12 0.00e+00 0.00e+00 0.00e+00 1.70e+01 4.52e-01 3.12e-01\n", " 13 0.00e+00 0.00e+00 0.00e+00 2.02e+01 7.10e-01 3.03e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 14 0.00e+00 0.00e+00 0.00e+00 1.94e+01 4.86e-01 2.94e-01\n", " 15 0.00e+00 0.00e+00 0.00e+00 1.89e+01 3.88e-01 2.86e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 16 0.00e+00 0.00e+00 0.00e+00 1.84e+01 3.19e-01 2.78e-01\n", " 17 0.00e+00 0.00e+00 0.00e+00 1.44e+01 6.04e-01 2.70e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 18 0.00e+00 
0.00e+00 0.00e+00 1.54e+01 6.15e-01 2.63e-01\n", " 19 0.00e+00 0.00e+00 0.00e+00 1.49e+01 1.00e+00 2.56e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 20 0.00e+00 0.00e+00 0.00e+00 1.68e+01 5.41e-01 2.50e-01\n", " 21 0.00e+00 0.00e+00 0.00e+00 1.44e+01 5.49e-01 2.44e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 22 0.00e+00 0.00e+00 0.00e+00 1.40e+01 3.85e-01 2.38e-01\n", " 23 0.00e+00 0.00e+00 0.00e+00 1.37e+01 3.21e-01 2.33e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 24 0.00e+00 0.00e+00 0.00e+00 5.79e+00 4.06e-01 2.27e-01\n", " 25 0.00e+00 0.00e+00 0.00e+00 1.17e+01 4.26e-01 2.22e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 26 0.00e+00 0.00e+00 0.00e+00 1.45e+01 4.68e-01 2.17e-01\n", " 27 0.00e+00 0.00e+00 0.00e+00 1.20e+01 8.06e-01 2.13e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 28 0.00e+00 0.00e+00 0.00e+00 1.10e+01 3.56e-01 2.08e-01\n", " 29 0.00e+00 0.00e+00 0.00e+00 1.12e+01 5.25e-01 2.04e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 30 0.00e+00 0.00e+00 0.00e+00 1.06e+01 3.13e-01 2.00e-01\n", " 31 0.00e+00 0.00e+00 0.00e+00 4.95e+00 3.17e-01 1.96e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 32 0.00e+00 0.00e+00 0.00e+00 1.14e+01 4.91e-01 1.92e-01\n", " 33 0.00e+00 0.00e+00 0.00e+00 4.83e+00 2.83e-01 1.89e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 34 0.00e+00 0.00e+00 0.00e+00 4.76e+00 2.34e-01 1.85e-01\n", " 35 0.00e+00 0.00e+00 0.00e+00 1.00e+01 4.77e-01 1.82e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 36 0.00e+00 0.00e+00 0.00e+00 4.61e+00 2.16e-01 1.79e-01\n", " 37 0.00e+00 0.00e+00 0.00e+00 4.55e+00 1.90e-01 1.75e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 38 0.00e+00 0.00e+00 0.00e+00 1.03e+01 4.70e-01 1.72e-01\n", " 39 0.00e+00 0.00e+00 0.00e+00 4.44e+00 2.09e-01 1.69e-01\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ " 40 0.00e+00 0.00e+00 0.00e+00 9.97e+00 3.55e-01 1.67e-01\n", " 41 0.00e+00 0.00e+00 0.00e+00 1.09e+01 3.94e-01 1.64e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 42 0.00e+00 0.00e+00 0.00e+00 8.92e+00 4.62e-01 1.61e-01\n", " 43 0.00e+00 0.00e+00 0.00e+00 1.06e+01 3.34e-01 1.59e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 44 0.00e+00 0.00e+00 0.00e+00 1.04e+01 2.74e-01 1.56e-01\n", " 45 0.00e+00 0.00e+00 0.00e+00 9.12e+00 4.01e-01 1.54e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 46 0.00e+00 0.00e+00 0.00e+00 7.94e+00 3.17e-01 1.52e-01\n", " 47 0.00e+00 0.00e+00 0.00e+00 7.89e+00 2.45e-01 1.49e-01\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " 48 0.00e+00 0.00e+00 0.00e+00 8.73e+00 3.32e-01 1.47e-01\n", " 49 0.00e+00 0.00e+00 0.00e+00 7.68e+00 2.40e-01 1.45e-01\n", "----------------------------------------------------------------\n", "OnlineConvBPDNMaskDictLearn solve time: 12.34s\n" ] } ], "source": [ "d = onlinecdl.OnlineConvBPDNMaskDictLearn(D0, lmbda, opt)\n", "\n", "# Number of online updates; deliberately not named `iter` so the builtin\n", "# iter() is not shadowed in the kernel namespace for later cells\n", "num_iter = 50\n", "d.display_start()\n", "for _ in range(num_iter):\n", "    # Each update uses one randomly chosen training image and its mask\n", "    img_index = np.random.randint(0, sh.shape[-1])\n", "    d.solve(shw[..., [img_index]], W[..., [img_index]])\n", "\n", "d.display_end()\n", "D1 = d.getdict()\n", "print(\"OnlineConvBPDNMaskDictLearn solve time: %.2fs\" %\n", "      d.timer.elapsed('solve'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Display initial and final dictionaries." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "D1 = D1.squeeze()\n", "fig = plot.figure(figsize=(14, 7))\n", "plot.subplot(1, 2, 1)\n", "plot.imview(util.tiledict(D0), title='D0', fig=fig)\n", "plot.subplot(1, 2, 2)\n", "plot.imview(util.tiledict(D1), title='D1', fig=fig)\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Get iteration statistics from solver object and plot the `DeltaD` and `Eta` values for each iteration." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "its = d.getitstat()\n", "fig = plot.figure(figsize=(7, 7))\n", "plot.plot(np.vstack((its.DeltaD, its.Eta)).T, xlbl='Iterations',\n", " lgnd=('Delta D', 'Eta'), fig=fig)\n", "fig.show()" ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }