{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Land Cover Classification\n", "\n", "In this tutorial, we'll learn how to apply a land cover classification model to imagery hosted in the [Planetary Computer Data Catalog](https://planetarycomputer.microsoft.com/catalog). In the process, we'll see how to:\n", "\n", "1. Create a Dask Cluster where each worker has a GPU\n", "2. Use the Planetary Computer's metadata query API to search for data matching certain conditions\n", "3. Stitch many files into a mosaic of images\n", "4. Apply a PyTorch model to chunks of an xarray DataArray in parallel.\n", "\n", "If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment.\n", "\n", "### Land cover background\n", "\n", "We'll work with [NAIP](https://planetarycomputer.microsoft.com/dataset/naip) data, a collection of high-resolution aerial imagery covering the continental US. We'll apply a PyTorch model trained for land cover classification to the data. The model takes in an image and classifies each pixel into a category (e.g. \"water\", \"tree canopy\", \"road\", etc.). We're using a neural network trained by data scientists from Microsoft's [AI for Good](https://www.microsoft.com/en-us/ai/ai-for-good) program. We'll use the model to analyze how land cover changed over a portion of Maryland from 2013 to 2017.\n", "\n", "### Scaling our computation\n", "\n", "This is a somewhat large computation, and we'll handle the scale in two ways:\n", "\n", "1. We'll use a cloud-native workflow, reading data directly from Blob Storage into memory on VMs running in Azure, skipping a slow local download step.\n", "2. 
We'll process images in parallel, using multiple threads to load and preprocess the data, before moving it to the GPU for prediction.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-09-20 16:11:38,787 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/proxy/8787/status\n" ] } ], "source": [ "from dask.distributed import Client\n", "from dask_cuda import LocalCUDACluster\n", "\n", "cluster = LocalCUDACluster(threads_per_worker=4)\n", "client = Client(cluster)\n", "print(f\"/proxy/{client.scheduler_info()['services']['dashboard']}/status\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure to open the Dask Dashboard, either by clicking the *Dashboard* link or by using the dask-labextension to lay out your workspace (see [Scale with Dask](https://planetarycomputer.microsoft.com/docs/quickstarts/scale-with-dask/#Open-the-dashboard)).\n", "\n", "Next, we'll load the model. It's available in a public Azure Blob Storage container. We'll download the model locally and construct the [Unet](https://smp.readthedocs.io/en/latest/models.html#unet) using `segmentation_models_pytorch`. Deep learning models can be somewhat large and difficult to serialize, so we'll make sure to [load it directly on the worker](https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask) using `Client.submit`. This returns a `Future` pointing to the model, which we'll use later on." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import azure.storage.blob\n", "from pathlib import Path\n", "import segmentation_models_pytorch\n", "import torch\n", "import warnings\n", "\n", "# ignore SyntaxWarning in pretrainedmodels\n", "warnings.filterwarnings(\"ignore\", category=SyntaxWarning)\n", "\n", "\n", "def load_model():\n", "    p = Path(\"unet_both_lc.pt\")\n", "    if not p.exists():\n", "        blob_client = azure.storage.blob.BlobClient(\n", "            account_url=\"https://naipeuwest.blob.core.windows.net/\",\n", "            container_name=\"naip-models\",\n", "            blob_name=\"unet_both_lc.pt\",\n", "        )\n", "\n", "        with p.open(\"wb\") as f:\n", "            f.write(blob_client.download_blob().readall())\n", "\n", "    model = segmentation_models_pytorch.Unet(\n", "        encoder_name=\"resnet18\",\n", "        encoder_depth=3,\n", "        encoder_weights=None,\n", "        decoder_channels=(128, 64, 64),\n", "        in_channels=4,\n", "        classes=13,\n", "    )\n", "    model.load_state_dict(torch.load(\"unet_both_lc.pt\", map_location=\"cuda:0\"))\n", "\n", "    device = torch.device(\"cuda\")\n", "    model = model.to(device)\n", "    return model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "remote_model = client.submit(load_model)\n", "print(remote_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data discovery\n", "\n", "Suppose we've been tasked with analyzing how land use changed from 2013 to 2017 for a region of Maryland. The full NAIP dataset consists of millions of images. How do we find the few hundred files that we care about?\n", "\n", "With the Planetary Computer's **metadata query API**, that's straightforward. First, we'll define our area of interest as a bounding box." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "bbox = [-77.9754638671875, 38.58037909468592, -76.37969970703125, 39.812755695478124]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we'll use `pystac_client` to query the Planetary Computer's STAC endpoint. We'll filter the results by space (to return only images touching our area of interest) and time (to return a set of images from 2013, and a second set for 2017)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import pystac_client\n", "\n", "api = pystac_client.Client.open(\"https://planetarycomputer.microsoft.com/api/stac/v1/\")\n", "search_2013 = api.search(\n", "    bbox=bbox,\n", "    datetime=\"2012-12-31T00:00:00Z/2014-01-01T00:00:00Z\",\n", "    collections=[\"naip\"],\n", ")\n", "\n", "search_2017 = api.search(\n", "    bbox=bbox,\n", "    datetime=\"2016-12-31T00:00:00Z/2018-01-01T00:00:00Z\",\n", "    collections=[\"naip\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each item in those results is a single STAC `Item`, which includes URLs to cloud-optimized GeoTIFF files stored in Azure Blob Storage.\n", "\n", "### Aligning images\n", "\n", "We have URLs to many files in Blob Storage. We want to treat all those as one big, logical dataset, so we'll use some open-source libraries to stitch them all together.\n", "\n", "[stac-vrt](https://stac-vrt.readthedocs.io/) will take a collection of STAC items and efficiently build a [GDAL VRT](https://gdal.org/drivers/raster/vrt.html)." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2013: 440 items\n", "2017: 443 items\n" ] } ], "source": [ "import stac_vrt\n", "\n", "data_2013 = search_2013.get_all_items_as_dict()[\"features\"]\n", "data_2017 = search_2017.get_all_items_as_dict()[\"features\"]\n", "\n", "print(\"2013:\", len(data_2013), \"items\")\n", "print(\"2017:\", len(data_2017), \"items\")\n", "\n", "naip_2013 = stac_vrt.build_vrt(\n", "    data_2013, block_width=512, block_height=512, data_type=\"Byte\"\n", ")\n", "mosaic_2017 = stac_vrt.build_vrt(\n", "    data_2017, block_width=512, block_height=512, data_type=\"Byte\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we have a pair of VRTs (one per year), we use [rasterio.warp](https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html) to make sure they're aligned." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import rasterio\n", "\n", "a = rasterio.open(naip_2013)\n", "naip_2017 = rasterio.vrt.WarpedVRT(\n", "    rasterio.open(mosaic_2017), transform=a.transform, height=a.height, width=a.width\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[xarray](https://xarray.pydata.org/en/stable/) provides a convenient data structure for working with large, n-dimensional, labeled datasets like this. [rioxarray](https://corteva.github.io/rioxarray/stable/) is an engine for reading datasets like this into xarray." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray (time: 2, band: 4, y: 149498, x: 145987)>\n",
       "dask.array<concatenate, shape=(2, 4, 149498, 145987), dtype=uint8, chunksize=(1, 4, 8192, 8192), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * band         (band) int64 1 2 3 4\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n",
       "    spatial_ref  int64 0\n",
       "  * time         (time) int64 2013 2017\n",
       "Attributes:\n",
       "    scale_factor:  1.0\n",
       "    add_offset:    0.0
" ], "text/plain": [ "\n", "dask.array\n", "Coordinates:\n", " * band (band) int64 1 2 3 4\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n", " spatial_ref int64 0\n", " * time (time) int64 2013 2017\n", "Attributes:\n", " scale_factor: 1.0\n", " add_offset: 0.0" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", "import rioxarray\n", "\n", "ds1 = rioxarray.open_rasterio(naip_2013, chunks=(4, 8192, 8192), lock=False)\n", "ds2 = rioxarray.open_rasterio(naip_2017, chunks=(4, 8192, 8192), lock=False)\n", "\n", "ds = xr.concat([ds1, ds2], dim=pd.Index([2013, 2017], name=\"time\"))\n", "ds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pre-processing for the neural network\n", "\n", "Now we have a big dataset that's been pixel-aligned on a grid for the two time periods.\n", "The model requires a bit of pre-processing up front. We'll define a couple of variables holding the per-band mean and standard deviation for each year." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "bands = xr.DataArray(\n", "    [1, 2, 3, 4], name=\"band\", dims=[\"band\"], coords={\"band\": [1, 2, 3, 4]}\n", ")\n", "NAIP_2013_MEANS = xr.DataArray(\n", "    np.array([117.00, 130.75, 122.50, 159.30], dtype=\"float32\"),\n", "    name=\"mean\",\n", "    coords=[bands],\n", ")\n", "NAIP_2013_STDS = xr.DataArray(\n", "    np.array([38.16, 36.68, 24.30, 66.22], dtype=\"float32\"),\n", "    name=\"std\",\n", "    coords=[bands],\n", ")\n", "NAIP_2017_MEANS = xr.DataArray(\n", "    np.array([72.84, 86.83, 76.78, 130.82], dtype=\"float32\"),\n", "    name=\"mean\",\n", "    coords=[bands],\n", ")\n", "NAIP_2017_STDS = xr.DataArray(\n", "    np.array([41.78, 34.66, 28.76, 58.95], dtype=\"float32\"),\n", "    name=\"std\",\n", "    coords=[bands],\n", ")\n", "\n", "mean = xr.concat([NAIP_2013_MEANS, NAIP_2017_MEANS], dim=\"time\")\n", "std = xr.concat([NAIP_2013_STDS, NAIP_2017_STDS], dim=\"time\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With those constants defined, we can normalize the data by subtracting the mean and dividing by the standard deviation.\n", "We'll also fix an issue the model had with partial chunks by dropping some pixels from the bottom-right corner." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray (time: 2, band: 4, y: 149472, x: 145984)>\n",
       "dask.array<getitem, shape=(2, 4, 149472, 145984), dtype=float32, chunksize=(1, 4, 8192, 8192), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * band         (band) int64 1 2 3 4\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n",
       "    spatial_ref  int64 0\n",
       "  * time         (time) int64 2013 2017
" ], "text/plain": [ "\n", "dask.array\n", "Coordinates:\n", " * band (band) int64 1 2 3 4\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n", " spatial_ref int64 0\n", " * time (time) int64 2013 2017" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Normalize by per-year mean, std\n", "normalized = (ds - mean) / std\n", "\n", "# The Unet model doesn't like partial chunks, so we chop off the\n", "# last 1-31 pixels.\n", "slices = {}\n", "for coord in [\"y\", \"x\"]:\n", "    remainder = len(ds.coords[coord]) % 32\n", "    slice_ = slice(-remainder) if remainder else slice(None)\n", "    slices[coord] = slice_\n", "\n", "normalized = normalized.isel(**slices)\n", "normalized" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Predicting land cover for each pixel\n", "\n", "At this point, we're ready to make predictions.\n", "\n", "We'll apply the model to the entire dataset, taking care not to over-saturate the GPUs. The GPUs will work on relatively small \"chips\" which fit comfortably in memory. The prediction, which comes from `model(data)`, will happen on the GPU so that it's nice and fast.\n", "\n", "Stepping up a level, we have Dask chunks. Each chunk is just a regular NumPy array. We'll break each chunk into a bunch of chips (using `dask.array.core.slices_from_chunks`) and get a prediction for each chip." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import dask.array\n", "\n", "\n", "def predict_chip(data: torch.Tensor, model) -> torch.Tensor:\n", "    # Input is on the GPU; the result is copied back to the CPU.\n", "    with torch.no_grad():\n", "        result = model(data).argmax(dim=1).to(torch.uint8)\n", "    return result.to(\"cpu\")\n", "\n", "\n", "def copy_and_predict_chunked(tile, model, token=None):\n", "    has_time = tile.ndim == 4\n", "    if has_time:\n", "        assert tile.shape[0] == 1\n", "        tile = tile[0]\n", "\n", "    slices = dask.array.core.slices_from_chunks(dask.array.empty(tile.shape).chunks)\n", "    out = np.empty(shape=tile.shape[1:], dtype=\"uint8\")\n", "    device = torch.device(\"cuda\")\n", "\n", "    for slice_ in slices:\n", "        gpu_chip = torch.as_tensor(tile[slice_][np.newaxis, ...]).to(device)\n", "        out[slice_[1:]] = predict_chip(gpu_chip, model).numpy()[0]\n", "    if has_time:\n", "        out = out[np.newaxis, ...]\n", "    return out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Stepping up yet another level, we'll apply the predictions to the entire xarray DataArray. We'll call `map_blocks` on the DataArray's underlying Dask array (`normalized.data`) to do the prediction in parallel." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2,\n",
       "                                                              y: 149472,\n",
       "                                                              x: 145984)>\n",
       "dask.array<predict, shape=(2, 149472, 145984), dtype=uint8, chunksize=(1, 8192, 8192), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n",
       "    spatial_ref  int64 0\n",
       "  * time         (time) int64 2013 2017
" ], "text/plain": [ "\n", "dask.array\n", "Coordinates:\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n", " spatial_ref int64 0\n", " * time (time) int64 2013 2017" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "meta = np.array([[]], dtype=\"uint8\")[:0]\n", "\n", "predictions_array = normalized.data.map_blocks(\n", "    copy_and_predict_chunked,\n", "    meta=meta,\n", "    drop_axis=1,\n", "    model=remote_model,\n", "    name=\"predict\",\n", ")\n", "\n", "predictions = xr.DataArray(\n", "    predictions_array,\n", "    coords=normalized.drop_vars(\"band\").coords,\n", "    dims=(\"time\", \"y\", \"x\"),\n", ")\n", "predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So there are three levels:\n", "\n", "1. xarray DataArray, backed by a Dask Array\n", "2. NumPy arrays, which are subsets of the Dask Array\n", "3. Chips, which are subsets of the NumPy arrays\n", "\n", "We can kick off a computation, either with `predictions.persist()` for the whole array or, as below, by calling `.compute()` on a small slice. This should cause some activity on your Dask Dashboard." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2, y: 200,\n",
       "                                                              x: 200)>\n",
       "array([[[1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        ...,\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1]],\n",
       "\n",
       "       [[3, 3, 3, ..., 3, 3, 3],\n",
       "        [3, 3, 3, ..., 3, 3, 3],\n",
       "        [3, 3, 3, ..., 3, 3, 3],\n",
       "        ...,\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8)\n",
       "Coordinates:\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 2.422e+05 2.422e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.418e+06 4.418e+06\n",
       "    spatial_ref  int64 0\n",
       "  * time         (time) int64 2013 2017
" ], "text/plain": [ "\n", "array([[[1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1],\n", " ...,\n", " [1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1]],\n", "\n", " [[3, 3, 3, ..., 3, 3, 3],\n", " [3, 3, 3, ..., 3, 3, 3],\n", " [3, 3, 3, ..., 3, 3, 3],\n", " ...,\n", " [1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1]]], dtype=uint8)\n", "Coordinates:\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 2.422e+05 2.422e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.418e+06 4.418e+06\n", " spatial_ref int64 0\n", " * time (time) int64 2013 2017" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions[:, :200, :200].compute()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each element of `predictions` is an integer encoding the class the PyTorch model thinks the pixel belongs to (tree canopy, building, water, etc.).\n", "\n", "Finally, we can compute the result we're interested in: which pixels (spots on the earth) changed land cover over the four years, at least according to our model." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (y: 149472,\n",
       "                                                              x: 145984)>\n",
       "dask.array<ne, shape=(149472, 145984), dtype=bool, chunksize=(8192, 8192), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n",
       "    spatial_ref  int64 0
" ], "text/plain": [ "\n", "dask.array\n", "Coordinates:\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n", " spatial_ref int64 0" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "change = predictions.sel(time=2013) != predictions.sel(time=2017)\n", "change" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's a boolean array where `True` means \"this location changed\". We'll mask out the `predictions` array with `change`. The value `other=0` means \"no change\", so `changed_predictions` has just the predictions (the integer codes) where there was a change." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2,\n",
       "                                                              y: 149472,\n",
       "                                                              x: 145984)>\n",
       "dask.array<where, shape=(2, 149472, 145984), dtype=uint8, chunksize=(1, 8192, 8192), chunktype=numpy.ndarray>\n",
       "Coordinates:\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n",
       "    spatial_ref  int64 0\n",
       "  * time         (time) int64 2013 2017
" ], "text/plain": [ "\n", "dask.array\n", "Coordinates:\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 3.88e+05 3.88e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.269e+06 4.269e+06\n", " spatial_ref int64 0\n", " * time (time) int64 2013 2017" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "changed_predictions = predictions.where(change, other=0)\n", "changed_predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, we can kick off some computation." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray 'predict-2109933971430175b909fc7cadef8ccb' (time: 2, y: 200,\n",
       "                                                              x: 200)>\n",
       "array([[[1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        [1, 1, 1, ..., 1, 1, 1],\n",
       "        ...,\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0]],\n",
       "\n",
       "       [[3, 3, 3, ..., 3, 3, 3],\n",
       "        [3, 3, 3, ..., 3, 3, 3],\n",
       "        [3, 3, 3, ..., 3, 3, 3],\n",
       "        ...,\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0],\n",
       "        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)\n",
       "Coordinates:\n",
       "  * x            (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 2.422e+05 2.422e+05\n",
       "  * y            (y) float64 4.418e+06 4.418e+06 ... 4.418e+06 4.418e+06\n",
       "    spatial_ref  int64 0\n",
       "  * time         (time) int64 2013 2017
" ], "text/plain": [ "\n", "array([[[1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1],\n", " [1, 1, 1, ..., 1, 1, 1],\n", " ...,\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0]],\n", "\n", " [[3, 3, 3, ..., 3, 3, 3],\n", " [3, 3, 3, ..., 3, 3, 3],\n", " [3, 3, 3, ..., 3, 3, 3],\n", " ...,\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0],\n", " [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)\n", "Coordinates:\n", " * x (x) float64 2.42e+05 2.42e+05 2.42e+05 ... 2.422e+05 2.422e+05\n", " * y (y) float64 4.418e+06 4.418e+06 ... 4.418e+06 4.418e+06\n", " spatial_ref int64 0\n", " * time (time) int64 2013 2017" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "changed_predictions[:, :200, :200].compute()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's do some visual spot checking of our model. This does require processing the full-resolution images, so we need to limit things to something that fits in memory now." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "middle = ds.shape[2] // 2, ds.shape[3] // 2\n", "slice_y = slice(middle[0], middle[0] + 5_000)\n", "slice_x = slice(middle[1], middle[1] + 5_000)\n", "\n", "parts = [x.isel(y=slice_y, x=slice_x) for x in [ds, predictions, changed_predictions]]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "ds_local, predictions_local, changed_predictions_local = dask.compute(*parts)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.colors\n", "from bokeh.models.tools import BoxZoomTool\n", "import panel\n", "import hvplot.xarray # noqa\n", "\n", "\n", "cmap = matplotlib.colors.ListedColormap(\n", "    np.array(\n", "        [\n", "            (0, 0, 0),\n", "            (0, 197, 255),\n", "            (0, 168, 132),\n", "            (38, 115, 0),\n", "            (76, 230, 0),\n", "            (163, 255, 115),\n", "            (255, 170, 0),\n", "            (255, 0, 0),\n", "            (156, 156, 156),\n", "            (0, 0, 0),\n", "            (115, 115, 0),\n", "            (230, 230, 0),\n", "            (255, 255, 115),\n", "            (197, 0, 255),\n", "        ]\n", "    )\n", "    / 255\n", ")\n", "\n", "\n", "def logo(plot, element):\n", "    plot.state.toolbar.logo = None\n", "\n", "\n", "zoom = BoxZoomTool(match_aspect=True)\n", "style_kwargs = dict(\n", "    width=450,\n", "    height=400,\n", "    xaxis=False,\n", "    yaxis=False,\n", ")\n", "kwargs = dict(\n", "    x=\"x\",\n", "    y=\"y\",\n", "    cmap=cmap,\n", "    rasterize=True,\n", "    aggregator=\"mode\",\n", "    colorbar=False,\n", "    tools=[\"pan\", zoom, \"wheel_zoom\", \"reset\"],\n", "    clim=(0, 12),\n", ")\n", "\n", "image_2013_plot = (\n", "    ds_local.sel(time=2013)\n", "    .hvplot.rgb(\n", "        bands=\"band\",\n", "        x=\"x\",\n", "        y=\"y\",\n", "        rasterize=True,\n", "        title=\"NAIP 2013\",\n", "        hover=False,\n", "        **style_kwargs,\n", "    )\n", "    .opts(default_tools=[], hooks=[logo])\n", ")\n", "classification_2013_plot = (\n", "    changed_predictions_local.sel(time=2013)\n", "    
.hvplot.image(title=\"Classification 2013\", **kwargs, **style_kwargs)\n", "    .opts(default_tools=[])\n", ")\n", "\n", "image_2017_plot = (\n", "    ds_local.sel(time=2017)\n", "    .hvplot.rgb(\n", "        bands=\"band\",\n", "        x=\"x\",\n", "        y=\"y\",\n", "        rasterize=True,\n", "        title=\"NAIP 2017\",\n", "        hover=False,\n", "        **style_kwargs,\n", "    )\n", "    .opts(default_tools=[], hooks=[logo])\n", ")\n", "classification_2017_plot = (\n", "    changed_predictions_local.sel(time=2017)\n", "    .hvplot.image(title=\"Classification 2017\", **kwargs, **style_kwargs)\n", "    .opts(default_tools=[])\n", ")\n", "\n", "panel.GridBox(\n", "    image_2013_plot,\n", "    classification_2013_plot,\n", "    image_2017_plot,\n", "    classification_2017_plot,\n", "    ncols=2,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That visualization uses [Panel](https://panel.holoviz.org/), a Python dashboarding library. In an interactive Jupyter Notebook you can pan and zoom around the large dataset.\n", "![](images/change.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scale further\n", "\n", "This example created a local Dask \"cluster\" on this single node. 
You can scale your computation out to a true GPU cluster with Dask Gateway by setting the `gpu=True` option when creating a cluster.\n", "\n", "```python\n", "import dask_gateway\n", "\n", "N_WORKERS = 2\n", "g = dask_gateway.Gateway()\n", "options = g.cluster_options()\n", "options[\"gpu\"] = True\n", "options[\"worker_memory\"] = 25\n", "options[\"worker_cores\"] = 3\n", "options[\"environment\"] = {\n", "    \"DASK_DISTRIBUTED__WORKERS__RESOURCES__GPU\": \"1\",\n", "}\n", "\n", "cluster = g.new_cluster(options)\n", "client = cluster.get_client()\n", "cluster.scale(N_WORKERS)\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 4 }