{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Custom plotting library fully in Python!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from math import pi\n", "\n", "import numpy as np\n", "\n", "import branca\n", "\n", "from ipywidgets import VBox, IntSlider\n", "\n", "from ipycanvas import Canvas, MultiCanvas, hold_canvas" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Plot(MultiCanvas):\n", " def __init__(self, x, y, color=None, scheme=branca.colormap.linear.RdBu_11):\n", " super(Plot, self).__init__(3, width=800, height=600, sync_image_data=True)\n", "\n", " self.color = color\n", " self.scheme = scheme\n", "\n", " self.background_color = \"#f7f7f7\"\n", "\n", " self.init_plot(x, y)\n", "\n", " def init_plot(self, x, y, color=None, scheme=None):\n", " self.x = x\n", " self.y = y\n", " self.color = color if color is not None else self.color\n", " self.scheme = scheme if scheme is not None else self.scheme\n", "\n", " padding = 0.1\n", " padding_x = padding * self.width\n", " padding_y = padding * self.height\n", "\n", " # TODO Fix drawarea max: It should be (canvas.size - padding)\n", " self.drawarea = (\n", " drawarea_min_x,\n", " drawarea_min_y,\n", " drawarea_max_x,\n", " drawarea_max_y,\n", " ) = (\n", " padding_x,\n", " padding_y,\n", " self.width - 2 * padding_x,\n", " self.height - 2 * padding_y,\n", " )\n", "\n", " min_x, min_y, max_x, max_y = np.min(x), np.min(y), np.max(x), np.max(y)\n", "\n", " dx = max_x - min_x\n", " dy = max_y - min_y\n", "\n", " # Turns a data coordinate into pixel coordinate\n", " self.scale_x = lambda x: drawarea_max_x * (x - min_x) / dx + drawarea_min_x\n", " self.scale_y = (\n", " lambda y: drawarea_max_y * (1 - (y - min_y) / dy) + drawarea_min_y\n", " )\n", "\n", " # Turns a pixel coordinate into data coordinate\n", " self.unscale_x = lambda sx: (sx - drawarea_min_x) * dx / drawarea_max_x + min_x\n", " self.unscale_y = (\n", " lambda sy: (1 - ((sy - drawarea_min_y) / drawarea_max_y)) * dy + min_y\n", " )\n", "\n", " self.colormap = None\n", " if self.color is not None:\n", " self.colormap = self.scheme.scale(np.min(self.color), np.max(self.color))\n", "\n", " def draw_background(self):\n", " drawarea_min_x, drawarea_min_y, drawarea_max_x, drawarea_max_y = self.drawarea\n", "\n", " background = self[0]\n", "\n", " # Draw background\n", " background.fill_style = self.background_color\n", " background.global_alpha = 0.3\n", " background.fill_rect(\n", " drawarea_min_x, drawarea_min_y, drawarea_max_x, drawarea_max_y\n", " )\n", " background.global_alpha = 1\n", "\n", " # Draw grid and ticks\n", " n_lines = 10\n", " background.fill_style = \"black\"\n", " background.stroke_style = \"#8c8c8c\"\n", " background.line_width = 1\n", "\n", " for i in range(n_lines):\n", " j = i / (n_lines - 1)\n", " line_x = drawarea_max_x * j + drawarea_min_x\n", " line_y = drawarea_max_y * j + drawarea_min_y\n", "\n", " # Line on the y axis\n", " background.stroke_line(\n", " line_x, drawarea_min_y, line_x, drawarea_max_y + drawarea_min_y\n", " )\n", "\n", " # Line on the x axis\n", " background.stroke_line(\n", " drawarea_min_x, line_y, drawarea_max_x + drawarea_min_x, line_y\n", " )\n", "\n", " # Draw y tick\n", " background.text_align = \"right\"\n", " background.text_baseline = \"middle\"\n", " background.fill_text(\n", " \"{0:.2e}\".format(self.unscale_y(line_y)), drawarea_min_x * 0.95, line_y\n", " )\n", "\n", " # Draw x tick\n", " background.text_align = \"center\"\n", " background.text_baseline = \"top\"\n", " background.fill_text(\n", " \"{0:.2e}\".format(self.unscale_x(line_x)),\n", " line_x,\n", " drawarea_max_y + drawarea_min_y + drawarea_min_y * 0.05,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class ScatterPlot(Plot):\n", " def __init__(\n", " self,\n", " x,\n", " y,\n", " size,\n", " color,\n", " scheme=branca.colormap.linear.RdBu_11,\n", " stroke_color=\"black\",\n", " ):\n", " super(ScatterPlot, self).__init__(x, y, color, scheme)\n", "\n", " self.dragging = False\n", " self.sizes = size\n", " self.stroke_color = stroke_color\n", "\n", " self.n_marks = min(x.shape[0], y.shape[0], size.shape[0], color.shape[0])\n", "\n", " # Index of the dragged point\n", " self.i_mark = -1\n", "\n", " self[2].on_mouse_down(self.mouse_down_handler)\n", " self[2].on_mouse_move(self.mouse_move_handler)\n", " self[2].on_mouse_up(self.mouse_up_handler)\n", "\n", " self.draw()\n", "\n", " def draw(self):\n", " with hold_canvas():\n", " self.clear()\n", " plot_layer = self[1]\n", "\n", " plot_layer.save()\n", "\n", " self.draw_background()\n", "\n", " # Draw scatter\n", " plot_layer.stroke_style = self.stroke_color\n", "\n", " for idx in range(self.n_marks):\n", " plot_layer.fill_style = self.colormap(self.color[idx])\n", "\n", " mark_x = self.scale_x(self.x[idx])\n", " mark_y = self.scale_y(self.y[idx])\n", " mark_size = self.sizes[idx]\n", "\n", " plot_layer.fill_circle(mark_x, mark_y, mark_size)\n", " plot_layer.stroke_circle(mark_x, mark_y, mark_size)\n", "\n", " plot_layer.restore()\n", "\n", " def mouse_down_handler(self, pixel_x, pixel_y):\n", " plot_layer = self[1]\n", "\n", " for idx in range(self.n_marks):\n", " mark_x = self.x[idx]\n", " mark_y = self.y[idx]\n", " mark_size = self.sizes[idx]\n", "\n", " if (\n", " pixel_x > self.scale_x(mark_x) - mark_size\n", " and pixel_x < self.scale_x(mark_x) + mark_size\n", " and pixel_y > self.scale_y(mark_y) - mark_size\n", " and pixel_y < self.scale_y(mark_y) + mark_size\n", " ):\n", " self.i_mark = idx\n", " self.dragging = True\n", "\n", " with hold_canvas():\n", " plot_layer.fill_style = self.background_color\n", " plot_layer.stroke_style = self.colormap(self.color[self.i_mark])\n", "\n", " plot_layer.fill_circle(\n", " self.scale_x(mark_x), self.scale_y(mark_y), mark_size\n", " )\n", " plot_layer.stroke_circle(\n", " self.scale_x(mark_x), self.scale_y(mark_y), mark_size\n", " )\n", " break\n", "\n", " def mouse_move_handler(self, pixel_x, pixel_y):\n", " if self.dragging and self.i_mark != -1:\n", " interaction_layer = self[2]\n", "\n", " unscaled_x = self.unscale_x(pixel_x)\n", " unscaled_y = self.unscale_y(pixel_y)\n", "\n", " with hold_canvas():\n", " interaction_layer.clear()\n", " interaction_layer.fill_style = self.colormap(self.color[self.i_mark])\n", " interaction_layer.stroke_style = self.stroke_color\n", "\n", " self.x[self.i_mark] = unscaled_x\n", " self.y[self.i_mark] = unscaled_y\n", "\n", " interaction_layer.fill_circle(pixel_x, pixel_y, self.sizes[self.i_mark])\n", " interaction_layer.stroke_circle(\n", " pixel_x, pixel_y, self.sizes[self.i_mark]\n", " )\n", "\n", " def mouse_up_handler(self, pixel_x, pixel_y):\n", " self.dragging = False\n", "\n", " self.draw()\n", "\n", " interaction_layer = self[2]\n", " interaction_layer.clear()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LinePlot(Plot):\n", " def __init__(self, x, y, line_color=\"#749cb8\", line_width=2):\n", " super(LinePlot, self).__init__(x, y)\n", "\n", " self.line_color = line_color\n", " self.line_width = line_width\n", "\n", " self.draw()\n", "\n", " def update(self, x, y, line_color=None, line_width=None):\n", " self.init_plot(x, y)\n", "\n", " self.line_color = line_color if line_color is not None else self.line_color\n", " self.line_width = line_width if line_width is not None else self.line_width\n", "\n", " self.draw()\n", "\n", " def draw(self):\n", " with hold_canvas():\n", " self.clear()\n", " plot_layer = self[1]\n", " plot_layer.save()\n", "\n", " self.draw_background()\n", "\n", " # Draw lines\n", " n_points = min(self.x.shape[0], self.y.shape[0])\n", "\n", " plot_layer.stroke_style = self.line_color\n", " plot_layer.line_width = self.line_width\n", " plot_layer.line_join = \"bevel\"\n", " plot_layer.line_cap = \"round\"\n", "\n", " plot_layer.stroke_lines(\n", " np.stack((self.scale_x(self.x), self.scale_y(self.y)), axis=1)\n", " )\n", "\n", " plot_layer.restore()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class HeatmapPlot(Plot):\n", " def __init__(self, x, y, color, scheme=branca.colormap.linear.RdBu_11):\n", " super(HeatmapPlot, self).__init__(x, y, color, scheme)\n", "\n", " self.draw()\n", "\n", " def draw(self):\n", " outof_x_bound = lambda idx: True if idx >= x.shape[0] or idx < 0 else False\n", " outof_y_bound = lambda idx: True if idx >= y.shape[0] or idx < 0 else False\n", "\n", " with hold_canvas():\n", " self.clear()\n", " plot_layer = self[1]\n", " plot_layer.save()\n", "\n", " self.draw_background()\n", "\n", " # Draw heatmap\n", " n_marks = min(self.x.shape[0], self.y.shape[0])\n", "\n", " for x_idx in range(1, self.color.shape[0] - 1):\n", " for y_idx in range(1, self.color.shape[1] - 1):\n", " plot_layer.fill_style = self.colormap(self.color[x_idx][y_idx])\n", "\n", " rect_center = (\n", " self.scale_x(self.x[x_idx]),\n", " self.scale_y(self.y[y_idx]),\n", " )\n", " neighbours_x = (\n", " self.scale_x(self.x[x_idx - 1]),\n", " self.scale_x(self.x[x_idx + 1]),\n", " )\n", " neighbours_y = (\n", " self.scale_y(self.y[y_idx - 1]),\n", " self.scale_y(self.y[y_idx + 1]),\n", " )\n", "\n", " rect_top_left_corner = (\n", " (neighbours_x[0] + rect_center[0]) / 2,\n", " (neighbours_y[0] + rect_center[1]) / 2,\n", " )\n", " rect_low_right_corner = (\n", " (neighbours_x[1] + rect_center[0]) / 2,\n", " (neighbours_y[1] + rect_center[1]) / 2,\n", " )\n", "\n", " width = rect_low_right_corner[0] - rect_top_left_corner[0] + 0.5\n", " height = rect_low_right_corner[1] - rect_top_left_corner[1] - 0.5\n", "\n", " plot_layer.fill_rect(\n", " rect_top_left_corner[0], rect_top_left_corner[1], width, height\n", " )\n", "\n", " plot_layer.restore()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Scatter plot" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_points = 1_000" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Scatter marks are draggable! Move the mouse while clicking on them..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.random.rand(n_points)\n", "y = np.random.rand(n_points)\n", "sizes = np.random.randint(2, 8, n_points)\n", "colors = np.random.rand(n_points) * 10 - 2\n", "\n", "plot = ScatterPlot(\n", " x, y, sizes, colors, branca.colormap.linear.viridis, stroke_color=\"white\"\n", ")\n", "plot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### You can retrieve the entire ``Canvas`` or a subpart of it using the ``get_image_data`` method" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "arr = plot.get_image_data(200, 300, 50, 100)\n", "arr.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot[1].stroke_style = \"red\"\n", "plot[1].line_width = 2\n", "plot[1].stroke_rect(200, 300, 50, 100)\n", "\n", "c = Canvas(width=50, height=100)\n", "c.put_image_data(arr, 0, 0)\n", "c" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Or you can save it to a file using ``to_file``" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot.to_file(\"my_scatter.png\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from ipywidgets import Image\n", "\n", "Image.from_file(\"my_scatter.png\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Line plot" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.linspace(0, 20, 500)\n", "y = np.sin(x)\n", "\n", "LinePlot(x, y, line_width=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "slider = IntSlider(description=\"Pow:\", min=1, max=10, step=1)\n", "\n", "x = np.linspace(-20, 20, 500)\n", "y = np.power(x, slider.value)\n", "\n", "power_plot = LinePlot(x, y, line_color=\"#32a852\", line_width=3)\n", "\n", "\n", "def on_slider_change(change):\n", " y = np.power(x, slider.value)\n", "\n", " power_plot.update(x, y)\n", "\n", "\n", "slider.observe(on_slider_change, \"value\")\n", "\n", "VBox((power_plot, slider))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n = 2_000\n", "x = np.linspace(0, 100, n)\n", "y = np.cumsum(np.random.randn(n))\n", "\n", "LinePlot(x, y, line_width=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Heatmap" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.linspace(-5, 5, 100)\n", "y = np.linspace(-5, 5, 100)\n", "x_grid, y_grid = np.meshgrid(x, y)\n", "color = np.sin(x_grid + y_grid**2) + np.cos(x_grid**2 + y_grid**2)\n", "\n", "HeatmapPlot(x, y, color, scheme=branca.colormap.linear.RdYlBu_05)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" } }, "nbformat": 4, "nbformat_minor": 4 }