{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "# 05 - FWI with total variation (TV) minimization as constraints" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "\n", "Beyond the simple box constraints implemented in the previous tutorials, there is a large body of research on regularization techniques that improve the quality of the final result. In this tutorial we look at how a more advanced FWI technique, [total variation denoising](https://en.wikipedia.org/wiki/Total_variation_denoising) applied as a constraint, can be implemented using [Devito](http://www.opesci.org/devito-public) and [scipy.optimize.minimize](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html), with the denoising step itself provided by [scikit-image](http://scikit-image.org/). This is a variant of the FWI with TV constraints algorithm described in [Peters and Herrmann 2017](https://doi.org/10.1190/tle36010094.1).\n", "\n", "[Dask](https://dask.pydata.org/en/latest/#dask) is also used here to speed up the examples.\n", "\n", "This tutorial uses the same synthetic datasets and model setup as the previous two tutorials, so check back if you get lost on parts of the code specific to Devito, scipy.optimize or Dask." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up (synthetic) data\n", "We are going to set up the same synthetic test case as in the previous tutorial (refer back for details)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#NBVAL_IGNORE_OUTPUT\n", "from examples.seismic import Model, demo_model\n", "\n", "# Define the grid parameters\n", "def get_grid():\n", "    shape = (101, 101)    # Number of grid points (nx, nz)\n", "    spacing = (10., 10.)  # Grid spacing in m. The domain size is now 1km by 1km\n", "    origin = (0., 0.)     
# Need origin to define relative source and receiver locations\n", "\n", "    return shape, spacing, origin\n", "\n", "# Define the test phantom; in this case we are using a simple circle\n", "# so we can easily see what is going on.\n", "def get_true_model():\n", "    shape, spacing, origin = get_grid()\n", "    return demo_model('circle-isotropic', vp=3.0, vp_background=2.5,\n", "                      origin=origin, shape=shape, spacing=spacing, nbpml=40)\n", "\n", "# The initial guess for the subsurface model.\n", "def get_initial_model():\n", "    shape, spacing, origin = get_grid()\n", "\n", "    return demo_model('circle-isotropic', vp=2.5, vp_background=2.5,\n", "                      origin=origin, shape=shape, spacing=spacing, nbpml=40)\n", "\n", "\n", "from examples.seismic.acoustic import AcousticWaveSolver\n", "from examples.seismic import TimeAxis, RickerSource, Receiver\n", "\n", "# This is used by the worker to get the current model.\n", "def get_current_model(param):\n", "    \"\"\" Returns the current model.\n", "    \"\"\"\n", "    model = get_initial_model()\n", "    model.m.data[:] = np.reshape(np.load(param['model']), model.m.data.shape)\n", "    return model\n", "\n", "# Inverse crime alert! Here the worker is creating the 'observed' data\n", "# using the real model. 
For a real case the worker would be reading\n", "# seismic data from disk.\n", "def get_data(param):\n", "    \"\"\" Returns source and receiver data for a single shot labeled 'shot_id'.\n", "    \"\"\"\n", "    true_model = get_true_model()\n", "    dt = true_model.critical_dt  # Time step from model grid spacing\n", "\n", "    # Set up source data and geometry.\n", "    src = RickerSource(name='src', grid=true_model.grid, f0=param['f0'],\n", "                       time_range=TimeAxis(start=param['t0'], stop=param['tn'], step=dt))\n", "    src.coordinates.data[0, :] = [30, param['shot_id']*1000./(param['nshots']-1)]\n", "\n", "    # Set up receiver data and geometry.\n", "    nreceivers = 101  # Number of receiver locations per shot\n", "    rec = Receiver(name='rec', grid=true_model.grid, npoint=nreceivers,\n", "                   time_range=src.time_range)\n", "    rec.coordinates.data[:, 1] = np.linspace(0, true_model.domain_size[0], num=nreceivers)\n", "    rec.coordinates.data[:, 0] = 980.  # 20m from the right end\n", "\n", "    # Set up solver - using true_model so that we have the same dt,\n", "    # otherwise we should use pandas to resample the time series data. 
\n", "    solver = AcousticWaveSolver(true_model, src, rec, space_order=4)\n", "\n", "    # Generate synthetic receiver data from true model\n", "    true_d, _, _ = solver.forward(src=src, m=true_model.m)\n", "\n", "    return src, true_d, solver\n", "\n", "# Define a type to store the functional and gradient.\n", "class fg_pair:\n", "    def __init__(self, f, g):\n", "        self.f = f\n", "        self.g = g\n", "\n", "    def __add__(self, other):\n", "        f = self.f + other.f\n", "        g = self.g + other.g\n", "\n", "        return fg_pair(f, g)\n", "\n", "    def __radd__(self, other):\n", "        # Lets the built-in sum() start from its default initial value of 0\n", "        if other == 0:\n", "            return self\n", "        else:\n", "            return self.__add__(other)\n", "\n", "from devito import Function\n", "\n", "# Create FWI gradient kernel for a single shot\n", "def fwi_gradient_i(param):\n", "    from devito import clear_cache\n", "\n", "    # Need to clear the worker's cache.\n", "    clear_cache()\n", "\n", "    # Communicating the model via a file.\n", "    model0 = get_current_model(param)\n", "    src, rec, solver = get_data(param)\n", "\n", "    # Create symbols to hold the gradient and the misfit between\n", "    # the 'measured' and simulated data.\n", "    grad = Function(name=\"grad\", grid=model0.grid)\n", "    residual = Receiver(name='rec', grid=model0.grid, time_range=rec.time_range,\n", "                        coordinates=rec.coordinates.data)\n", "\n", "    # Compute simulated data and full forward wavefield u0\n", "    d, u0, _ = solver.forward(src=src, m=model0.m, save=True)\n", "\n", "    # Compute the data misfit (residual) and objective function\n", "    residual.data[:] = d.data[:] - rec.data[:]\n", "    f = .5*np.linalg.norm(residual.data.flatten())**2\n", "\n", "    # Compute gradient using the adjoint-state method. 
Note that this\n", "    # backpropagates the data misfit through the model.\n", "    solver.gradient(rec=residual, u=u0, m=model0.m, grad=grad)\n", "\n", "    # Copying here to avoid a (probably overzealous) destructor deleting\n", "    # the gradient before Dask has had a chance to communicate it.\n", "    g = np.array(grad.data[:])\n", "\n", "    # Return the objective functional and gradient.\n", "    return fg_pair(f, g)\n", "\n", "import numpy as np\n", "from distributed import LocalCluster, Client\n", "\n", "# Dumps the model to disk; workers will pick this up when they need it.\n", "def dump_model(param, model):\n", "    np.save(param['model'], model.astype(np.float32))\n", "\n", "def fwi_gradient(model, param):\n", "    # Dump a copy of the current model for the workers\n", "    # to pick up when they are ready.\n", "    param['model'] = \"model_0.npy\"\n", "    dump_model(param, model)\n", "\n", "    # Define work list\n", "    work = [dict(param) for i in range(param['nshots'])]\n", "    for i in range(param['nshots']):\n", "        work[i]['shot_id'] = i\n", "\n", "    # Distribute worklist to workers.\n", "    fgi = client.map(fwi_gradient_i, work)\n", "\n", "    # Perform data reduction.\n", "    fg = client.submit(sum, fgi).result()\n", "\n", "    # L-BFGS in scipy expects a flat array in 64-bit floats.\n", "    return fg.f, fg.g.flatten().astype(np.float64)\n", "\n", "# Start Dask cluster\n", "cluster = LocalCluster(n_workers=5, death_timeout=600)\n", "client = Client(cluster)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## FWI with L-BFGS-B\n", "Equipped with a function to calculate the functional and gradient, we are finally ready to define the optimization function."
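, "\n", "\n", "Schematically, and as our reading of the code in this notebook rather than a formal statement of the Peters and Herrmann (2017) algorithm, L-BFGS-B minimizes the least-squares data misfit summed over the $n_s$ shots, as computed per shot by `fwi_gradient_i` and reduced by `fwi_gradient`,\n", "\n", "$$\n", "f(\\mathbf{m}) = \\frac{1}{2}\\sum_{s=1}^{n_s} \\left\\| \\mathbf{d}^{\\mathrm{pred}}_s(\\mathbf{m}) - \\mathbf{d}^{\\mathrm{obs}}_s \\right\\|_2^2,\n", "$$\n", "\n", "while the callback replaces each new iterate by its TV-denoised version and then clips it to the velocity bounds,\n", "\n", "$$\n", "\\mathbf{m} \\leftarrow \\mathrm{clip}\\left( \\arg\\min_{\\mathbf{u}} \\left[ \\frac{1}{2}\\|\\mathbf{u}-\\mathbf{m}\\|_2^2 + w\\,\\mathrm{TV}(\\mathbf{u}) \\right],\\; \\frac{1}{3.5^2},\\; \\frac{1}{2^2} \\right).\n", "$$\n", "\n", "Here $\\mathbf{m} = 1/v_p^2$ is the squared slowness in s$^2$/km$^2$, and $\\mathbf{d}^{\\mathrm{pred}}_s$ and $\\mathbf{d}^{\\mathrm{obs}}_s$ are the simulated and 'observed' receiver data for shot $s$. The symbol $w$ stands in for the `weight` argument of `denoise_tv_chambolle` (`5.0e-3` in the callback). This reading assumes scikit-image's convention that `weight` is the inverse of the data-fidelity parameter. Under that reading, TV enters through this per-iteration denoising step rather than as a term in $f$."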
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from scipy import optimize\n", "from skimage.restoration import denoise_tv_chambolle as denoise\n", "\n", "# Define bounding box constraints on the solution.\n", "def apply_box_constraint(m):\n", "    # Maximum possible 'realistic' velocity is 3.5 km/sec\n", "    # Minimum possible 'realistic' velocity is 2 km/sec\n", "    return np.clip(m, 1/3.5**2, 1/2**2)\n", "\n", "# Many optimization methods in scipy.optimize.minimize accept a callback\n", "# function that can operate on the solution after every iteration. Here\n", "# we use this to apply constraints and to monitor the true relative\n", "# solution error.\n", "relative_error = []\n", "\n", "def fwi_tv_callbacks(x):\n", "    # Apply the TV constraint by denoising the current model in place\n", "    # (181 x 181 is the 101 x 101 grid plus the 40-point PML on each side),\n", "    # then clip the result to the box constraints.\n", "    # Caveat: this relies on the callback receiving the optimizer's own\n", "    # iterate; some SciPy versions pass a copy instead, in which case\n", "    # these in-place updates do not feed back into L-BFGS-B.\n", "    x[:] = denoise(x.reshape(181, 181), weight=5.0e-3).flatten()\n", "    x[:] = apply_box_constraint(x)\n", "\n", "    # Calculate true relative error\n", "    true_x = get_true_model().m.data.flatten()\n", "    relative_error.append(np.linalg.norm((x-true_x)/true_x))\n", "\n", "def fwi(model, param, ftol=1e-6, maxiter=20):\n", "    result = optimize.minimize(fwi_gradient,\n", "                               model.m.data.flatten().astype(np.float64),\n", "                               args=(param, ), method='L-BFGS-B', jac=True,\n", "                               callback=fwi_tv_callbacks,\n", "                               options={'ftol':ftol,\n", "                                        'maxiter':maxiter,\n", "                                        'disp':True})\n", "\n", "    return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now apply our FWI function and have a look at the result." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " fun: 199.0333099294088\n", " hess_inv: <32761x32761 LbfgsInvHessProduct with dtype=float64>\n", " jac: array([-7.24996760e-15, -4.09159536e-14, -1.42357399e-13, ...,\n", " -1.36085370e-13, -3.84566307e-14, -6.74617866e-15])\n", " message: b'STOP: TOTAL NO. 
of ITERATIONS EXCEEDS LIMIT'\n", " nfev: 8\n", " nit: 6\n", " status: 1\n", " success: False\n", " x: array([0.16, 0.16, 0.16, ..., 0.16, 0.16, 0.16])\n" ] } ], "source": [ "#NBVAL_SKIP\n", "\n", "# Change to the WARNING log level to reduce log output\n", "# as compared to the default DEBUG\n", "from devito import configuration\n", "configuration['log_level'] = 'WARNING'\n", "\n", "# Set up inversion parameters.\n", "param = {'t0': 0.,\n", "         'tn': 1000.,  # Simulation lasts 1 second (1000 ms)\n", "         'f0': 0.010,  # Source peak frequency is 10Hz (0.010 kHz)\n", "         'nshots': 8}  # Number of shots to create gradient from\n", "\n", "model0 = get_initial_model()\n", "\n", "# Apply FWI with TV.\n", "result = fwi(model0, param, ftol=1e-6, maxiter=5)\n", "\n", "# Print out results of optimizer.\n", "print(result)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "#NBVAL_SKIP\n", "\n", "# Show what the update does to the model\n", "from examples.seismic import plot_image, plot_velocity\n", "\n", "model0.m.data[:] = result.x.astype(np.float32).reshape(model0.m.data.shape)\n", "model0.vp = np.sqrt(1. / model0.m.data[40:-40, 40:-40])\n", "plot_velocity(model0)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#NBVAL_SKIP\n", "\n", "# Plot percentage error\n", "plot_image(100*np.abs(model0.vp-get_true_model().vp.data)/get_true_model().vp.data, vmax=15, cmap=\"hot\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "#NBVAL_SKIP\n", "import matplotlib.pyplot as plt\n", "\n", "# Plot the true relative model error at each iteration\n", "plt.figure()\n", "plt.loglog(relative_error)\n", "plt.xlabel('Iteration number')\n", "plt.ylabel('True relative error')\n", "plt.title('Convergence')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is part of the tutorial \"Optimised Symbolic Finite Difference Computation with Devito\" presented at the Intel® HPC Developer Conference 2017." ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "widgets": { "state": {}, "version": "1.1.2" } }, "nbformat": 4, "nbformat_minor": 1 }