{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tree crown detection using DeepForest\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Context\n", "### Purpose\n", "Detect tree crowns using a state-of-the-art deep learning model for object detection.\n", "\n", "### Modelling approach\n", "A prebuilt deep learning model, named *DeepForest* {cite:p}`Weinstein2020_MEE`, is used to predict individual tree crowns from an airborne RGB image. *DeepForest* was trained on data from the National Ecological Observatory Network (NEON). *DeepForest* was initially implemented in Python 3.7 using TensorFlow v1.14 and later moved to PyTorch. Further details can be found in the [package documentation](https://deepforest.readthedocs.io/en/latest/).\n", "\n", "### Highlights\n", "* Fetch a NEON sample image from a Zenodo repository.\n", "* Retrieve and plot the reference annotations (bounding boxes) for the target image.\n", "* Load and use a pretrained *DeepForest* model to generate full-image or tile-wise predictions.\n", "* Indicate the pros and cons of full-image and tile-wise prediction.\n", "\n", ":::{note}\n", "The author acknowledges [DeepForest](https://deepforest.readthedocs.io/en/latest/) contributors. 
Some code snippets were extracted from DeepForest [GitHub public repository](https://github.com/weecology/DeepForest).\n", ":::" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "import glob\n", "import os\n", "import urllib.request\n", "import numpy as np\n", "\n", "import intake\n", "import matplotlib.pyplot as plt\n", "import xmltodict\n", "import cv2\n", "\n", "import torch\n", "\n", "from shapely.geometry import box\n", "import pandas as pd\n", "from geopandas import GeoDataFrame\n", "import xarray as xr\n", "import panel as pn\n", "import holoviews as hv\n", "import hvplot.pandas\n", "import hvplot.xarray\n", "from skimage.exposure import equalize_hist\n", "\n", "import pooch\n", "\n", "import warnings\n", "warnings.filterwarnings(action='ignore')\n", "\n", "hv.extension('bokeh', width=100)\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set project structure" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "notebook_folder = './notebook'\n", "if not os.path.exists(notebook_folder):\n", " os.makedirs(notebook_folder)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fetch an RGB image from Zenodo\n", "\n", "Fetch a sample image from a publicly accessible location." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pooch.retrieve(\n", " url=\"doi:10.5281/zenodo.3459803/2018_MLBS_3_541000_4140000_image_crop.tif\",\n", " known_hash=\"md5:01a7cf23b368ff9e006fda8fe9ca4c8c\",\n", " path=notebook_folder,\n", " fname=\"2018_MLBS_3_541000_4140000_image_crop.tif\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# set catalogue location\n", "catalog_file = os.path.join(notebook_folder, 'catalog.yaml')\n", "\n", "with open(catalog_file, 'w') as f:\n", " f.write('''\n", "sources:\n", " NEONTREE_rgb:\n", " driver: xarray_image\n", " description: 'NeonTreeEvaluation RGB images (collection)'\n", " args:\n", " urlpath: \"{{ CATALOG_DIR }}/2018_MLBS_3_541000_4140000_image_crop.tif\"\n", " ''')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load an intake catalog for the downloaded data." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cat_tc = intake.open_catalog(catalog_file)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load sample image\n", "\n", "Here we use `intake` to load the image through `dask`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tc_rgb = cat_tc[\"NEONTREE_rgb\"].to_dask()\n", "tc_rgb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load and prepare labels" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# functions to load xml and extract bounding boxes\n", "\n", "# function to create ordered dictionary of .xml annotation files\n", "def loadxml(imagename):\n", " imagename = imagename.replace('.tif','')\n", " fullurl = \"https://raw.githubusercontent.com/weecology/NeonTreeEvaluation/master/annotations/\" + imagename + \".xml\"\n", " file = urllib.request.urlopen(fullurl)\n", " data = file.read()\n", " file.close()\n", " data = xmltodict.parse(data)\n", " return data\n", "\n", "# function to extract bounding boxes\n", "def extractbb(i):\n", " bb = [f['bndbox'] for f in allxml[i]['annotation']['object']]\n", " return bb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "filenames = glob.glob(os.path.join(notebook_folder, '*.tif'))\n", "filesn = [os.path.basename(i) for i in filenames]\n", "\n", "allxml = [loadxml(i) for i in filesn]\n", "bball = [extractbb(i) for i in range(0,len(allxml))]\n", "print(len(bball))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualise image and labels" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# function to plot images\n", "def cv2_imshow(a, **kwargs):\n", " a = a.clip(0, 255).astype('uint8')\n", " # cv2 stores colors as BGR; convert to RGB\n", " if a.ndim == 3:\n", " if a.shape[2] == 4:\n", " a = cv2.cvtColor(a, cv2.COLOR_BGRA2RGBA)\n", " else:\n", " a = cv2.cvtColor(a, cv2.COLOR_BGR2RGB)\n", "\n", " return plt.imshow(a, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "image = tc_rgb\n", "\n", "# copy the image and select its reference bounding boxes\n", 
"image2 = image.values.copy()\n", "target_bbox = bball[0]\n", "print(type(target_bbox))\n", "print(target_bbox[0:2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "for row in target_bbox:\n", " cv2.rectangle(image2, (int(row[\"xmin\"]), int(row[\"ymin\"])), (int(row[\"xmax\"]), int(row[\"ymax\"])), (0,255,255), thickness=2, lineType=cv2.LINE_AA)\n", "\n", "plot_reference = plt.figure(figsize=(15,15))\n", "cv2_imshow(np.flip(image2,2))\n", "plt.title('Reference labels',fontsize='xx-large')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load *DeepForest* pretrained model\n", "\n", "Now we're going to load and use a pretrained model from the `deepforest` package." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deepforest import main\n", "\n", "# load deep forest model\n", "model = main.deepforest()\n", "model.use_release()\n", "model.current_device = torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pred_boxes = model.predict_image(image=image.values)\n", "print(pred_boxes.head(5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "image3 = image.values.copy() \n", "\n", "for index, row in pred_boxes.iterrows():\n", " cv2.rectangle(image3, (int(row[\"xmin\"]), int(row[\"ymin\"])), (int(row[\"xmax\"]), int(row[\"ymax\"])), (0,255,255), thickness=2, lineType=cv2.LINE_AA)\n", "\n", "plot_fullimage = plt.figure(figsize=(15,15))\n", "cv2_imshow(np.flip(image3,2))\n", "plt.title('Full-image predictions',fontsize='xx-large')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparison of full-image predictions and reference labels\n", "\n", "Let's compare the labels and predictions over the tested image." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "plot_referandfullimage = plt.figure(figsize=(15,15))\n", "ax1 = plt.subplot(1, 2, 1), cv2_imshow(np.flip(image2,2))\n", "ax1[0].set_title('Reference labels',fontsize='xx-large')\n", "ax2 = plt.subplot(1, 2, 2), cv2_imshow(np.flip(image3,2))\n", "ax2[0].set_title('Full-image predictions', fontsize='xx-large')\n", "plt.show() # To show figure" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Interpretation:**\n", "\n", "* The pretrained model does not appear to perform well on the tested image.\n", "* The low performance might be explained by the pretrained model having been trained on 10 cm resolution images.\n", "\n", "## Tile-based prediction\n", "\n", "To optimise the predictions, DeepForest can be run [tile-wise](https://deepforest.readthedocs.io/en/latest/better.html).\n", "\n", "The following cells show how to define the optimal window size, i.e. the tile size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from deepforest import preprocess\n", "\n", "#Create windows of 400px\n", "windows = preprocess.compute_windows(image.values, patch_size=400,patch_overlap=0)\n", "print(f'We have {len(windows)} windows in the image')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "#Loop through a few sample windows, crop and predict\n", "plot_tilewindows, axes = plt.subplots(nrows=2,ncols=2, figsize=(15,15))\n", "axes = axes.flatten()\n", "for index2 in range(4):\n", " crop = image.values[windows[index2].indices()]\n", " #predict in bgr channel order, color predictions in red.\n", " boxes = model.predict_image(image=np.flip(crop[...,::-1],2), return_plot = True)\n", "\n", " #but plot in rgb channel order\n", " axes[index2].imshow(boxes[...,::-1])\n", " axes[index2].set_title(f'Prediction in Window {index2 + 1} out of {len(windows)}', 
fontsize='xx-large')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once a suitable tile size is defined, we can run the prediction over all tiles in a batch using the `predict_tile` function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "tile = model.predict_tile(image=image.values,return_plot=False,patch_overlap=0,iou_threshold=0.05,patch_size=400)\n", "\n", "# plot predicted bbox\n", "image_tile = image.values.copy()\n", "\n", "for index, row in tile.iterrows():\n", " cv2.rectangle(image_tile, (int(row[\"xmin\"]), int(row[\"ymin\"])), (int(row[\"xmax\"]), int(row[\"ymax\"])), (0, 255, 255), thickness=2, lineType=cv2.LINE_AA)\n", "\n", "plot_tilewise = plt.figure(figsize=(15,15))\n", "ax1 = plt.subplot(1, 2, 1), cv2_imshow(np.flip(image2,2))\n", "ax1[0].set_title('Reference labels',fontsize='xx-large')\n", "ax2 = plt.subplot(1, 2, 2), cv2_imshow(np.flip(image_tile,2))\n", "ax2[0].set_title('Tile-wise predictions', fontsize='xx-large')\n", "plt.show() # To show figure" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Interpretation**\n", "\n", "* The tile-based prediction provides more reasonable results than predicting over the whole image.\n", "* While the prediction looks closer to the reference labels, there seem to be some tile-edge artefacts. These require further investigation, i.e. inspecting the `deepforest` tile-wise prediction function to understand how predictions from different tiles are combined." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Interactive plots\n", "\n", "The plot below summarises the static plots above by interactively comparing the bounding boxes and scores of full-image and tile-wise predictions. To zoom in on the reference NEON RGB image at its original resolution, change `rasterize=True` to `rasterize=False`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "## function to convert bbox in dictionary to geopandas\n", "def bbox_to_geopandas(bbox_df):\n", " geometry = [box(x1, y1, x2, y2) for x1,y1,x2,y2 in zip(bbox_df.xmin, bbox_df.ymin, bbox_df.xmax, bbox_df.ymax)]\n", " poly_geo = GeoDataFrame(bbox_df, geometry=geometry)\n", " return poly_geo\n", "\n", "## prepare reference and prediction bbox\n", "### convert data types for reference bbox dictionary\n", "reference = pd.DataFrame.from_dict(target_bbox)\n", "reference[['xmin', 'ymin', 'xmax', 'ymax']] = reference[['xmin', 'ymin', 'xmax', 'ymax']].astype(int)\n", "\n", "poly_reference = bbox_to_geopandas(reference)\n", "poly_prediction_image = bbox_to_geopandas(pred_boxes)\n", "poly_prediction_tile = bbox_to_geopandas(tile)\n", "\n", "## settings for hvplot objects\n", "settings_vector = dict(fill_color=None, width=400, height=400, clim=(0,1), fontsize={'title': '110%'})\n", "settings_image = dict(x='x', y='y', data_aspect=1, xaxis=False, yaxis=None)\n", "\n", "## create hvplot objects\n", "plot_RGB = tc_rgb.hvplot.rgb(**settings_image, bands='channel', hover=False, rasterize=True)\n", "plot_vector_reference = poly_reference.hvplot(hover_cols=False, legend=False).opts(title='Reference labels', alpha=1, **settings_vector)\n", "plot_vector_image = poly_prediction_image.hvplot(hover_cols=['score'], legend=False).opts(title='Full-image predictions', alpha=0.5, **settings_vector)\n", "plot_vector_tile = poly_prediction_tile.hvplot(hover_cols=['score'], legend=False).opts(title='Tile-wise predictions', alpha=0.5, **settings_vector)\n", "\n", "plot_comparison = pn.Row(pn.Column(plot_RGB * plot_vector_reference, \n", " plot_RGB * plot_vector_image),\n", " pn.Column(pn.Spacer(width=400, height=400),\n", " plot_RGB * plot_vector_tile), scroll=True)\n", "\n", "plot_comparison.embed()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", 
"\n", "This notebook has demonstrated the use of:\n", "\n", "* `pooch` and `intake` packages to fetch data from a Zenodo repository containing training data files of the [NeonTreeEvaluation Benchmark](https://zenodo.org/record/3459803#.YhI54xPP30o).\n", "* `deepforest` package to easily load and run a pretrained model for tree crown detection from very-high-resolution RGB imagery.\n", "* `predict_tile` function in `deepforest` to run tile-wise prediction, which considerably improves tree crown predictions, although the user should define a suitable tile size.\n", "* `cv2` to generate static plots comparing the reference labels against the bounding boxes of two prediction strategies, full-image and tile-wise predictions.\n", "* `hvplot` and `panel` to interactively compare both prediction strategies against reference labels." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Citing this Notebook" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Please see [CITATION.cff](https://github.com/eds-book/15d986da-2d7c-44fb-af71-700494485def/blob/main/CITATION.cff) for the full citation information. The citation file can be exported to APA or BibTeX formats (learn more [here](https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-citation-files))." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional information\n", "\n", "**Review**: This notebook has been reviewed by one or more members of the Environmental Data Science book community. The open review is available [here](https://github.com/alan-turing-institute/environmental-ds-book/pull/5).\n", "\n", "**License**: The code in this notebook is licensed under the MIT License. The Environmental Data Science book is licensed under the Creative Commons Attribution 4.0 license. 
See further details [here](https://github.com/alan-turing-institute/environmental-ds-book/blob/main/LICENSE).\n", "\n", "**Contact**: If you have any suggestions or want to report an issue with this notebook, feel free to [create an issue](https://github.com/alan-turing-institute/environmental-ds-book/issues/new/choose) or send a direct message to [environmental.ds.book@gmail.com](mailto:environmental.ds.book@gmail.com)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "from datetime import date\n", "\n", "print('Notebook repository version: v2025.6.0')\n", "print(f'Last tested: {date.today()}')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "end-to-end_treecrown.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 4 }