{ "cells": [ { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "# Using numpy and KD-trees with netCDF data\n", "\n", "Unidata Python Workshop\n", "\n", "
\n", "\n", "There is now a Unidata Developer's [Blog entry](http://www.unidata.ucar.edu/blogs/developer/en/entry/accessing_netcdf_data_by_coordinates) accompanying this IPython notebook.\n", "\n", "The goal is to demonstrate how to quickly access netCDF data based on geospatial coordinates instead of array indices.\n", "\n", "- First we show a naive and slow way to do this, in which we also have to worry about longitude anomalies\n", "- Then we speed up access with numpy arrays\n", "- Next, we demonstrate how to eliminate longitude anomalies\n", "- Finally, we use a kd-tree data structure to significantly speed up access by coordinates for large problems" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Getting data by coordinates from a netCDF File\n", "\n", "Let's look at a netCDF file from the *Atlantic Real-Time Ocean Forecast System*. If you have cloned the [Unidata 2015 Python Workshop](https://github.com/Unidata/unidata-python-workshop), this data is already available in '../data/rtofs_glo_3dz_f006_6hrly_reg3.nc'. Otherwise you can get it from [rtofs_glo_3dz_f006_6hrly_reg3.nc](https://github.com/Unidata/tds-python-workshop/blob/master/data/rtofs_glo_3dz_f006_6hrly_reg3.nc)." ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Looking at netCDF metadata from Python\n", "\n", "In IPython, we could invoke the **ncdump** utility like this:\n", "    \n", "    filename = '../data/rtofs_glo_3dz_f006_6hrly_reg3.nc'\n", "    !ncdump -h $filename\n", "    \n", "but only *if* we know that the installed version of **ncdump** is recent enough\n", "to read compressed data from netCDF-4 classic model files.\n", "\n", "Instead, we'll use the netCDF4-python package to show information about\n", "the file in a form that's somewhat less familiar, but contains the information\n", "we need for the subsequent examples. 
This works for any netCDF file format:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "root group (NETCDF4_CLASSIC data model, file format HDF5):\n", " Conventions: CF-1.0\n", " title: HYCOM ATLb2.00\n", " institution: National Centers for Environmental Prediction\n", " source: HYCOM archive file\n", " experiment: 90.9\n", " history: archv2ncdf3z\n", " dimensions(sizes): MT(1), Y(850), X(712), Depth(10)\n", " variables(dimensions): float64 \u001b[4mMT\u001b[0m(MT), float64 \u001b[4mDate\u001b[0m(MT), float32 \u001b[4mDepth\u001b[0m(Depth), int32 \u001b[4mY\u001b[0m(Y), int32 \u001b[4mX\u001b[0m(X), float32 \u001b[4mLatitude\u001b[0m(Y,X), float32 \u001b[4mLongitude\u001b[0m(Y,X), float32 \u001b[4mu\u001b[0m(MT,Depth,Y,X), float32 \u001b[4mv\u001b[0m(MT,Depth,Y,X), float32 \u001b[4mtemperature\u001b[0m(MT,Depth,Y,X), float32 \u001b[4msalinity\u001b[0m(MT,Depth,Y,X)\n", " groups: \n", "\n", "\n", "float32 temperature(MT, Depth, Y, X)\n", " coordinates: Longitude Latitude Date\n", " standard_name: sea_water_potential_temperature\n", " units: degC\n", " _FillValue: 1.26765e+30\n", " valid_range: [ -5.07860279 11.14989948]\n", " long_name: temp [90.9H]\n", "unlimited dimensions: MT\n", "current shape = (1, 10, 850, 712)\n", "filling on\n", "\n", "float32 salinity(MT, Depth, Y, X)\n", " coordinates: Longitude Latitude Date\n", " standard_name: sea_water_salinity\n", " units: psu\n", " _FillValue: 1.26765e+30\n", " valid_range: [ 11.61832619 35.04695129]\n", " long_name: salinity [90.9H]\n", "unlimited dimensions: MT\n", "current shape = (1, 10, 850, 712)\n", "filling on\n", "\n", "float32 Latitude(Y, X)\n", " standard_name: latitude\n", " units: degrees_north\n", "unlimited dimensions: \n", "current shape = (850, 712)\n", "filling on, default _FillValue of 9.969209968386869e+36 used\n", "\n", "\n", "float32 Longitude(Y, 
X)\n", "    standard_name: longitude\n", "    units: degrees_east\n", "unlimited dimensions: \n", "current shape = (850, 712)\n", "filling on, default _FillValue of 9.969209968386869e+36 used\n", "\n" ] } ], "source": [ "import netCDF4\n", "filename = '../data/rtofs_glo_3dz_f006_6hrly_reg3.nc'\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "print(ncfile)  # shows global attributes, dimensions, and variables\n", "ncvars = ncfile.variables  # a dictionary of variables\n", "# print information about specific variables, including type, shape, and attributes\n", "for varname in ['temperature', 'salinity', 'Latitude', 'Longitude']:\n", "    print(ncvars[varname])" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "Here's a sparse picture (every 25th point on each axis) of the grid on which Latitude, Longitude, temperature, salinity, and other variables are defined:\n", "\n", "![Example lat-lon grid](http://www.unidata.ucar.edu/software/netcdf/workshops/images/20130811_rew_blog_grid.png)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Example query: sea surface temperature and salinity at 50N, 140W?\n", "\n", "- **Longitude** and **Latitude** are 2D netCDF variables of shape 850 x 712, indexed by the **Y** and **X** dimensions\n", "- That's 605200 values for each\n", "- There's no _direct_ way in this file (as in many netCDF files) to compute grid indexes from coordinates via a coordinate system and projection parameters. Instead, we have to rely on the latitude and longitude auxiliary coordinate variables, as required by the CF conventions for data not on a simple lat,lon grid.\n", "- To get the temperature at 50N, 140W, we need to find **Y** and **X** indexes **iy** and **ix** such that (**Latitude[iy, ix]**, **Longitude[iy, ix]**) is \"close\" to (50.0, -140.0)." 
] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Naive, slow way using nested loops\n", "\n", "- Initially, for simplicity, we just use Euclidean distance squared, as if the Earth were flat, latitude and longitude were $y$- and $x$-coordinates, and the distance squared between points $(lat_1,lon_1)$ and $(lat_0,lon_0)$ were $( lat_1 - lat_0 )^2 + ( lon_1 - lon_0 )^2$.\n", "- Note: these assumptions are wrong near the poles and for points on opposite sides of the longitude boundary discontinuity.\n", "- So, keeping things simple, we want to find **iy** and **ix** to minimize\n", "\n", " ``(Latitude[iy, ix] - lat0)**2 + (Longitude[iy, ix] - lon0)**2``\n", " \n", "![Flat Earth](http://www.unidata.ucar.edu/software/netcdf/workshops/images/unproj.gif)" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Reading netCDF data into numpy arrays\n", "\n", "To access netCDF data, rather than just metadata, we will also need NumPy:\n", "\n", "- A Python library for scientific programming.\n", "- Supports n-dimensional array-based calculations similar to Fortran and IDL.\n", "- Includes fast mathematical functions to act on scalars and arrays.\n", "\n", "With the Python netCDF4 package, using \"[ ... ]\" to index a netCDF variable object reads a numpy array from, or writes one to, the associated netCDF file.\n", "\n", "The functions below read latitude and longitude values into 2D numpy arrays named **latvals** and **lonvals**." ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### First version: slow and spatially challenged\n", "Here's a function that uses simple nested loops to find the indices that minimize the distance to the desired coordinates, written as if using Fortran or C rather than Python. We'll call this function in the cell following this definition ..." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "import numpy as np\n", "import netCDF4\n", "\n", "def naive_slow(latvar,lonvar,lat0,lon0):\n", " '''\n", " Find \"closest\" point in a set of (lat,lon) points to specified point\n", " latvar - 2D latitude variable from an open netCDF dataset\n", " lonvar - 2D longitude variable from an open netCDF dataset\n", " lat0,lon0 - query point\n", " Returns iy,ix such that \n", " (lonval[iy,ix] - lon0)**2 + (latval[iy,ix] - lat0)**2\n", " is minimum. This \"closeness\" measure works badly near poles and\n", " longitude boundaries.\n", " '''\n", " # Read from file into numpy arrays\n", " latvals = latvar[:]\n", " lonvals = lonvar[:]\n", " ny,nx = latvals.shape\n", " dist_sq_min = 1.0e30\n", " for iy in range(ny):\n", " for ix in range(nx):\n", " latval = latvals[iy, ix]\n", " lonval = lonvals[iy, ix]\n", " dist_sq = (latval - lat0)**2 + (lonval - lon0)**2\n", " if dist_sq < dist_sq_min:\n", " iy_min, ix_min, dist_sq_min = iy, ix, dist_sq\n", " return iy_min,ix_min" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "When we call the function above it takes several seconds to run, because it calculates distances one point at a time, for each of the 605200 $(lat, lon)$ points. Note that once indices for the point nearest to (50, -140) are found, they can be used to access temperature, salinity, and other netCDF variables that use the same dimensions." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n", "temperature: 6.46312 degC\n", "salinity: 32.6572 psu\n" ] } ], "source": [ "ncfile = netCDF4.Dataset(filename, 'r')\n", "latvar = ncfile.variables['Latitude']\n", "lonvar = ncfile.variables['Longitude']\n", "iy,ix = naive_slow(latvar, lonvar, 50.0, -140.0)\n", "print('Closest lat lon:', latvar[iy,ix], lonvar[iy,ix])\n", "tempvar = ncfile.variables['temperature']\n", "salvar = ncfile.variables['salinity']\n", "print('temperature:', tempvar[0, 0, iy, ix], tempvar.units)\n", "print('salinity:', salvar[0, 0, iy, ix], salvar.units)\n", "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### NumPy arrays instead of loops: fast, but still assumes flat earth\n", "\n", "The above function is slow, because it doesn't make good use of NumPy arrays. It's much faster to use whole array operations to eliminate loops and element-at-a-time computation. 
NumPy functions that help eliminate loops include:\n", "\n", "- The `argmin()` method that returns a 1D index of the minimum value of a NumPy array\n", "- The `unravel_index()` function that converts a 1D index back into a multidimensional index" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "import numpy as np\n", "import netCDF4\n", "\n", "def naive_fast(latvar,lonvar,lat0,lon0):\n", " # Read latitude and longitude from file into numpy arrays\n", " latvals = latvar[:]\n", " lonvals = lonvar[:]\n", " ny,nx = latvals.shape\n", " dist_sq = (latvals-lat0)**2 + (lonvals-lon0)**2\n", " minindex_flattened = dist_sq.argmin() # 1D index of min element\n", " iy_min,ix_min = np.unravel_index(minindex_flattened, latvals.shape)\n", " return iy_min,ix_min\n", "\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "latvar = ncfile.variables['Latitude']\n", "lonvar = ncfile.variables['Longitude']\n", "iy,ix = naive_fast(latvar, lonvar, 50.0, -140.0)\n", "print('Closest lat lon:', latvar[iy,ix], lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Spherical Earth with tunnel distance: fast _and_ correct\n", "\n", "Though assuming a flat Earth may work OK for this example, we'd like to not worry about whether longitudes are from 0 to 360 or -180 to 180, or whether points are close to the poles.\n", "The code below fixes this by using the square of \"tunnel distance\" between (lat,lon) points. This version is both fast and correct (for a _spherical_ Earth)." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "import numpy as np\n", "import netCDF4\n", "from math import pi\n", "from numpy import cos, sin\n", "\n", "def tunnel_fast(latvar,lonvar,lat0,lon0):\n", "    '''\n", "    Find closest point in a set of (lat,lon) points to specified point\n", "    latvar - 2D latitude variable from an open netCDF dataset\n", "    lonvar - 2D longitude variable from an open netCDF dataset\n", "    lat0,lon0 - query point\n", "    Returns iy,ix such that the square of the tunnel distance\n", "    between (latval[iy,ix],lonval[iy,ix]) and (lat0,lon0)\n", "    is minimum. The tunnel distance is the straight-line distance\n", "    through the sphere between two points on its surface.\n", "    '''\n", "    rad_factor = pi/180.0 # for trigonometry, need angles in radians\n", "    # Read latitude and longitude from file into numpy arrays\n", "    latvals = latvar[:] * rad_factor\n", "    lonvals = lonvar[:] * rad_factor\n", "    lat0_rad = lat0 * rad_factor\n", "    lon0_rad = lon0 * rad_factor\n", "    # Compute numpy arrays for all values, no loops\n", "    clat,clon = cos(latvals),cos(lonvals)\n", "    slat,slon = sin(latvals),sin(lonvals)\n", "    # Differences in 3D Cartesian coordinates on a unit sphere\n", "    delX = cos(lat0_rad)*cos(lon0_rad) - clat*clon\n", "    delY = cos(lat0_rad)*sin(lon0_rad) - clat*slon\n", "    delZ = sin(lat0_rad) - slat\n", "    dist_sq = delX**2 + delY**2 + delZ**2\n", "    minindex_1d = dist_sq.argmin()  # 1D index of minimum element\n", "    iy_min,ix_min = np.unravel_index(minindex_1d, latvals.shape)\n", "    return iy_min,ix_min\n", "\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "latvar = ncfile.variables['Latitude']\n", "lonvar = ncfile.variables['Longitude']\n", "iy,ix = tunnel_fast(latvar, lonvar, 50.0, -140.0)\n", "print('Closest lat lon:', latvar[iy,ix], lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### KD-Trees: faster data structure 
for lots of queries\n", "\n", "We can still do better, by using a data structure designed to support efficient nearest-neighbor queries: the [KD-tree](http://en.wikipedia.org/wiki/K-d_tree). It works like a multidimensional binary tree, so finding the point nearest to a query point is _much_ faster than computing all the distances to find the minimum. It takes some setup time to load all the points into the data structure, but that only has to be done once for a given set of points. \n", " \n", "For a single point query, it's still more than twice as fast as the naive slow version above, but building the KD-tree for 605,200 points takes more time than the fast numpy search through all the points, so in this case using the KD-tree for a _single_ point query is sort of pointless ..." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "import numpy as np\n", "import netCDF4\n", "from math import pi\n", "from numpy import cos, sin\n", "from scipy.spatial import cKDTree\n", "\n", "def kdtree_fast(latvar,lonvar,lat0,lon0):\n", " rad_factor = pi/180.0 # for trignometry, need angles in radians\n", " # Read latitude and longitude from file into numpy arrays\n", " latvals = latvar[:] * rad_factor\n", " lonvals = lonvar[:] * rad_factor\n", " ny,nx = latvals.shape\n", " clat,clon = cos(latvals),cos(lonvals)\n", " slat,slon = sin(latvals),sin(lonvals)\n", " # Build kd-tree from big arrays of 3D coordinates\n", " triples = list(zip(np.ravel(clat*clon), np.ravel(clat*slon), np.ravel(slat)))\n", " kdt = cKDTree(triples)\n", " lat0_rad = lat0 * rad_factor\n", " lon0_rad = lon0 * rad_factor\n", " clat0,clon0 = cos(lat0_rad),cos(lon0_rad)\n", " slat0,slon0 = sin(lat0_rad),sin(lon0_rad)\n", " dist_sq_min, minindex_1d = kdt.query([clat0*clon0, clat0*slon0, slat0])\n", " iy_min, ix_min = 
np.unravel_index(minindex_1d, latvals.shape)\n", " return iy_min,ix_min\n", " \n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "latvar = ncfile.variables['Latitude']\n", "lonvar = ncfile.variables['Longitude']\n", "iy,ix = kdtree_fast(latvar, lonvar, 50.0, -140.0)\n", "print('Closest lat lon:', latvar[iy,ix], lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Timing the functions\n", "\n", "If you're curious about actual times for the versions above, the iPython notebook \"%%timeit\" statement gets accurate timings of all of them. Below, we time just a single query point, in this case (50.0, -140.0). To get accurate timings, the \"%%timeit\" statement lets us do untimed setup first on the same line, before running the function call in a loop." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "ncfile = netCDF4.Dataset(filename,'r')\n", "latvar = ncfile.variables['Latitude']\n", "lonvar = ncfile.variables['Longitude']" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 loop, best of 3: 4.72 s per loop\n" ] } ], "source": [ "%%timeit\n", "naive_slow(latvar, lonvar, 50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100 loops, best of 3: 5.46 ms per loop\n" ] } ], "source": [ "%%timeit\n", "naive_fast(latvar, lonvar, 50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 loops, best of 3: 40.8 ms per loop\n" ] } ], "source": [ 
"%%timeit\n", "tunnel_fast(latvar, lonvar, 50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 loop, best of 3: 1.11 s per loop\n" ] } ], "source": [ "%%timeit\n", "kdtree_fast(latvar, lonvar, 50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "## Separating setup from query\n", "\n", "The above use of the KD-tree data structure is not the way it's meant to be used. Instead, it should be initialized _once_ with all the k-dimensional data for which nearest-neighbors are desired, then used repeatedly on each query, amortizing the work done to build the data structure over all the following queries. By separately timing the setup and the time required per query, the threshold for number of queries beyond which the KD-tree is faster can be determined.\n", "\n", "That's exactly what we'll do now. We split each algorithm into two functions, a setup function and a query function. The times per query go from seconds (the naive version) to milliseconds (the array-oriented numpy version) to microseconds (the turbo-charged KD-tree, once it's built).\n", " \n", "Rather than just using functions, we define a Class for each algorithm, do the setup in the class constructor, and provide a query method. 
" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "# Split naive_slow into initialization and query, so we can time them separately\n", "import numpy as np\n", "import netCDF4\n", "\n", "class Naive_slow(object):\n", " def __init__(self, ncfile, latvarname, lonvarname):\n", " self.ncfile = ncfile\n", " self.latvar = self.ncfile.variables[latvarname]\n", " self.lonvar = self.ncfile.variables[lonvarname]\n", " # Read latitude and longitude from file into numpy arrays\n", " self.latvals = self.latvar[:]\n", " self.lonvals = self.lonvar[:]\n", " self.shape = self.latvals.shape\n", "\n", " def query(self,lat0,lon0):\n", " ny,nx = self.shape\n", " dist_sq_min = 1.0e30\n", " for iy in range(ny):\n", " for ix in range(nx):\n", " latval = self.latvals[iy, ix]\n", " lonval = self.lonvals[iy, ix]\n", " dist_sq = (latval - lat0)**2 + (lonval - lon0)**2\n", " if dist_sq < dist_sq_min:\n", " iy_min, ix_min, dist_sq_min = iy, ix, dist_sq\n", " return iy_min,ix_min\n", "\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "ns = Naive_slow(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)\n", "print('Closest lat lon:', ns.latvar[iy,ix], ns.lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "# Split naive_fast into initialization and query, so we can time them separately\n", "import numpy as np\n", "import netCDF4\n", "\n", "class Naive_fast(object):\n", " def __init__(self, ncfile, latvarname, lonvarname):\n", " self.ncfile = ncfile\n", " self.latvar = self.ncfile.variables[latvarname]\n", " self.lonvar = self.ncfile.variables[lonvarname] 
\n", " # Read latitude and longitude from file into numpy arrays\n", " self.latvals = self.latvar[:]\n", " self.lonvals = self.lonvar[:]\n", " self.shape = self.latvals.shape\n", " \n", "\n", " def query(self,lat0,lon0):\n", " dist_sq = (self.latvals-lat0)**2 + (self.lonvals-lon0)**2\n", " minindex_flattened = dist_sq.argmin() # 1D index\n", " iy_min, ix_min = np.unravel_index(minindex_flattened, self.shape) # 2D indexes\n", " return iy_min,ix_min\n", "\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "ns = Naive_fast(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)\n", "print('Closest lat lon:', ns.latvar[iy,ix], ns.lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "# Split tunnel_fast into initialization and query, so we can time them separately\n", "import numpy as np\n", "import netCDF4\n", "from math import pi\n", "from numpy import cos, sin\n", "\n", "class Tunnel_fast(object):\n", " def __init__(self, ncfile, latvarname, lonvarname):\n", " self.ncfile = ncfile\n", " self.latvar = self.ncfile.variables[latvarname]\n", " self.lonvar = self.ncfile.variables[lonvarname] \n", " # Read latitude and longitude from file into numpy arrays\n", " rad_factor = pi/180.0 # for trignometry, need angles in radians\n", " self.latvals = self.latvar[:] * rad_factor\n", " self.lonvals = self.lonvar[:] * rad_factor\n", " self.shape = self.latvals.shape\n", " clat,clon,slon = cos(self.latvals),cos(self.lonvals),sin(self.lonvals)\n", " self.clat_clon = clat*clon\n", " self.clat_slon = clat*slon\n", " self.slat = sin(self.latvals)\n", " \n", " def query(self,lat0,lon0):\n", " # for trignometry, need angles in radians\n", " rad_factor = pi/180.0 \n", " lat0_rad = lat0 * rad_factor\n", " lon0_rad = lon0 * rad_factor\n", " delX = 
cos(lat0_rad)*cos(lon0_rad) - self.clat_clon\n", " delY = cos(lat0_rad)*sin(lon0_rad) - self.clat_slon\n", " delZ = sin(lat0_rad) - self.slat;\n", " dist_sq = delX**2 + delY**2 + delZ**2\n", " minindex_1d = dist_sq.argmin() # 1D index \n", " iy_min, ix_min = np.unravel_index(minindex_1d, self.shape) # 2D indexes\n", " return iy_min,ix_min\n", "\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "ns = Tunnel_fast(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)\n", "print('Closest lat lon:', ns.latvar[iy,ix], ns.lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Closest lat lon: 49.9867 -139.982\n" ] } ], "source": [ "# Split kdtree_fast into initialization and query, so we can time them separately\n", "import numpy as np\n", "import netCDF4\n", "from math import pi\n", "from numpy import cos, sin\n", "from scipy.spatial import cKDTree\n", "\n", "class Kdtree_fast(object):\n", " def __init__(self, ncfile, latvarname, lonvarname):\n", " self.ncfile = ncfile\n", " self.latvar = self.ncfile.variables[latvarname]\n", " self.lonvar = self.ncfile.variables[lonvarname] \n", " # Read latitude and longitude from file into numpy arrays\n", " rad_factor = pi/180.0 # for trignometry, need angles in radians\n", " self.latvals = self.latvar[:] * rad_factor\n", " self.lonvals = self.lonvar[:] * rad_factor\n", " self.shape = self.latvals.shape\n", " clat,clon = cos(self.latvals),cos(self.lonvals)\n", " slat,slon = sin(self.latvals),sin(self.lonvals)\n", " clat_clon = clat*clon\n", " clat_slon = clat*slon\n", " triples = list(zip(np.ravel(clat*clon), np.ravel(clat*slon), np.ravel(slat)))\n", " self.kdt = cKDTree(triples)\n", "\n", " def query(self,lat0,lon0):\n", " rad_factor = pi/180.0 \n", " lat0_rad = lat0 * rad_factor\n", " lon0_rad = lon0 * rad_factor\n", " clat0,clon0 = 
cos(lat0_rad),cos(lon0_rad)\n", " slat0,slon0 = sin(lat0_rad),sin(lon0_rad)\n", " dist_sq_min, minindex_1d = self.kdt.query([clat0*clon0,clat0*slon0,slat0])\n", " iy_min, ix_min = np.unravel_index(minindex_1d, self.shape)\n", " return iy_min,ix_min\n", "\n", "ncfile = netCDF4.Dataset(filename, 'r')\n", "ns = Kdtree_fast(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)\n", "print('Closest lat lon:', ns.latvar[iy,ix], ns.lonvar[iy,ix])\n", "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Setup times for the four algorithms" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "ncfile = netCDF4.Dataset(filename, 'r')" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100 loops, best of 3: 2.05 ms per loop\n" ] } ], "source": [ "%%timeit\n", "ns = Naive_slow(ncfile,'Latitude','Longitude')" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "100 loops, best of 3: 2.09 ms per loop\n" ] } ], "source": [ "%%timeit\n", "ns = Naive_fast(ncfile,'Latitude','Longitude')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 loops, best of 3: 27.7 ms per loop\n" ] } ], "source": [ "%%timeit\n", "ns = Tunnel_fast(ncfile,'Latitude','Longitude')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 loop, best of 3: 1.14 s per loop\n" ] } ], 
"source": [ "%%timeit\n", "ns = Kdtree_fast(ncfile,'Latitude','Longitude')" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Query times for the four algorithms" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "%%timeit ns = Naive_slow(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "%%timeit ns = Naive_fast(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "%%timeit ns = Tunnel_fast(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "%%timeit ns = Kdtree_fast(ncfile,'Latitude','Longitude')\n", "iy,ix = ns.query(50.0, -140.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true, "deletable": true, "editable": true }, "outputs": [], "source": [ "ncfile.close()" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "In the next cell, we copy the results of the %%timeit runs into Python variables. 
_(Is there a way to capture %%timeit output, so we don't have to do this manually?)_" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "ns0,nf0,tf0,kd0 = 3.76, 3.8, 27.4, 2520 # setup times in msec\n", "nsq,nfq,tfq,kdq = 7790, 2.46, 5.14, .0738 # query times in msec" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "### Summary of timings\n", "\n", "The naive_slow method is always slower than all other methods. The naive_fast method would only be worth considering if non-flatness of the Earth is irrelevant, for example in a relatively small region not close to the poles and not crossing a longitude discontinuity.\n", "\n", "Total time for running initialization followed by N queries is:\n", "\n", " - naive_slow: $ns0 + nsq * N$\n", " - naive_fast: $nf0 + nfq * N$\n", " - tunnel_fast: $tf0 + tfq * N$\n", " - kdtree_fast: $kd0 + kdq * N$" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "deletable": true, "editable": true }, "outputs": [], "source": [ "N = 10000\n", "print(N, \"queries using naive_slow:\", round((ns0 + nsq*N)/1000,1), \"seconds\")\n", "print(N, \"queries using naive_fast:\", round((nf0 + nfq*N)/1000,1), \"seconds\")\n", "print(N, \"queries using tunnel_fast:\", round((tf0 + tfq*N)/1000,1), \"seconds\")\n", "print(N, \"queries using kdtree_fast:\", round((kd0 + kdq*N)/1000,1), \"seconds\")\n", "print('')\n", "print(\"kdtree_fast outperforms naive_fast above:\", int((kd0-nf0)/(nfq-kdq)), \"queries\")\n", "print(\"kdtree_fast outperforms tunnel_fast above:\", int((kd0-tf0)/(tfq-kdq)), \"queries\")" ] }, { "cell_type": "markdown", "metadata": { "deletable": true, "editable": true }, "source": [ "The advantage of using KD-trees grows with the number of points in the search set: for M points, a KD-tree query is O(log(M)) on average, whereas each query with the other algorithms is O(M). This is the same difference as between binary search and linear search." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.3" } }, "nbformat": 4, "nbformat_minor": 0 }