{ "cells": [ { "cell_type": "markdown", "source": [ "\n", "\n", "\n", "| Author | Date | Vacancy | Reading time |\n", "| ----------- | ----------- | ---- | ---- |\n", "| Michael Feil | July 14th 2021 | Federated Learning in Healthcare, TU Munich | 20min |\n", "\n", "\n", "## Table of Contents:\n", "* [**Part 0**: Introduction](#ref0)\n", "* [**Part 1**: Create non-iid client distributions](#ref1)\n", "* [**Part 2**. Apply gradient descent for one client](#ref2)\n", "* [2.1: Right on track: gradient descent ftw!](#ref2.1)\n", "* [2.2: Taking the detour: an intro to gradient descent in python](#ref2.2)\n", "* [2.3: Gradient descent for both clients and the global distribution](#ref2.3)\n", "* [**Part 3**: FedAvg with non-iid client data](#ref3)\n", "* [3.1: Taking the detour: An programming intro to FedAvg](#ref3.1)\n", "* [3.2: Back on track: Federated Averaging](#ref3.2)\n", "* [3.3: How to mess with Federated Averaging](#ref3.3)\n", "* [**Part 4**: Our salvation: Posterior Averaging?](#ref4)\n", "* [4.1: FedPA: Beyond the simplified implementation ](#ref4.1)\n", "* [**Part 5**: Summary](#ref5)\n", "* [**Part 6**: References](#ref6)\n", "\n", "Hint: For interactivity, better open this blog post on Binder.org. [Simply press on this logo: ![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/michaelfeil/fl-in-healtcare-viz/HEAD?filepath=colab-viz-fedavg-vs-gd.ipynb)" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# **Part 0**: Introduction \r\n", "Think about the following scenario: \r\n", "\r\n", "You have personal images on your smartphone. So do your friends. (of course they do, we live in 2021!) As a Machine Learning engineer your goal is to create a machine learning model, which does a good job at the classification of your images from your phone and automatically clusters them in various folders, e.g. \"cats\" and \"beergarden munich\". You also plan to help some of your friends sorting their images. \r\n", "\r\n", "In the past (or better say before 2017), your proposal would have been, that everyone of your friends would have uploaded their images to your cloud-based server. Then all of images would get aggregated, and you would train and test a ML model. There was just one drawback: None of our friends wanted to share their sensitive image data! - But now you have the answer: **Federated Learning (FL)**!\r\n", "\r\n", "The introduction of **Federated Averaging (FedAvg)** by McMahan et al.[[2](#z2)] in 2017 enables us now, to train a model without uploading any images to any kind of central server. Isn't that great!?! Let's look at the images below.\r\n", "\r\n", "\r\n", "\r\n", "
\r\n", " \r\n", " missing image\r\n", "
Figure1 © Image adapted from Peter Kairouz [3]
\r\n", "
\r\n", "\r\n", "\r\n", "As in the [Figure1](#Figure1) above, in the case of **cross-device** FL with FedAvg, our cloud-based central server will send an initial state of the model parameters to thousands of mobile devices at once. The clients then update their version of the model parameters using **stochastic gradient descent (SGD)**. By doing more SGD steps, we train more each **communication round**, and therefore improve the **communication efficiency**. The updated model parameters get communicated back to the central server, where the client parameters get aggregated using **weighted averaging**. The process then starts again into the next communication round with the slightly improved parameters of our model. \r\n" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "\n", "\n", "### Executive Summary\n", "In the following sections, we will discuss a suitable heterogenous data distribution for the clients [Part 1](#ref1), code and visualize gradient descent and derive the global posterior in [Part 2](#ref2). In [Part 3](#ref3), we will find some unexpected hiccups using FedAvg[[2](#z2)] and solve them by introducing a FedPA[[1](#z1)] in [Part 4](#ref4). " ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "# import some utilities\r\n", "from computation_fedavg_gd import grad_descent, partial_derivative\r\n", "from computation_fedavg_gd import MultivariateGaussian\r\n", "from computation_fedavg_gd import PDFsManipulate\r\n", "\r\n", "# import the usual suspects (excluding pandas as pd)\r\n", "import matplotlib\r\n", "import matplotlib.pyplot as plt\r\n", "import numpy as np" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "# **Part 1**: Create non-iid client distributions \n", "\n", "First step: For our scenario from the introduction, we model two phones (clients) by selecting two suitable client distributions. The distributions symbolize our target function on each client, which we try to optimize. \n", "For this notebook, we plot with the following assumptions: \n", "- the model has just 2 tunable weights $\\theta$\n", "- there are 2 clients.\n", "- The $\\theta$ over each target function $J$ has a (Multivariate) Gaussian distribution ($\\mu$,$\\Sigma$).\n", "- since the data distribution of our clients differs, we assume a non independent and identically distributed $\\theta$ (**non-iid**) (i.e. we pick different parameters $\\mu$, $\\Sigma$ for each client).\n", "\n", "\n", "\n", "Further reading: The Multivariate Gaussian is defined in ```computation_fedavg_gd.py``` according to this formula:\n", "\n", "$$\n", "p(\\mathbf{x} \\mid \\mu, \\Sigma)=\\frac{1}{\\sqrt{(2 \\pi)^{d}|\\Sigma|}} \\exp \\left(-\\frac{1}{2}(\\mathbf{x}-\\mu)^{T} \\Sigma^{-1}(\\mathbf{x}-\\mu)\\right)\n", "$$" ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "# If you are running this notebook blog interactivly on binder.org\r\n", "# after reading through everything, you can change the parameters here.\r\n", "mu1, sigma1 = np.array([0, 1]), np.array([[1.0, -0.9], [-0.5, 1.5]])\r\n", "mu2, sigma2 = np.array([0, -1.2]), np.array([[1.5, 0.6], [0.9, 1.0]])\r\n", "\r\n", "multivariate1 = MultivariateGaussian(\r\n", " mu=mu1, sigma=sigma1\r\n", ")\r\n", "multivariate2 = MultivariateGaussian(\r\n", " mu=mu2, sigma=sigma2\r\n", ")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "# calculate at our target function over a sampling grid\r\n", "from plot_fedavg_gd import SAMPLE_GRID_VECTOR\r\n", "# sample over a grid from -3,5 to 3,5 in 2D\r\n", "J1 = multivariate1.evaluate(theta_vec=SAMPLE_GRID_VECTOR)\r\n", "J2 = multivariate2.evaluate(theta_vec=SAMPLE_GRID_VECTOR)\r\n", "print(f\"evaluated gaussian over a 2D sampling grid {SAMPLE_GRID_VECTOR.shape} for a third target vector {J1.shape}\")" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "evaluated gaussian over a 2D sampling grid (71, 71, 2) for a third target vector (71, 71)\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Lets create some fancy 3D-plots for our two client distributions. Note how the two distributions have different optima in our non-iid setting." ], "metadata": {} }, { "cell_type": "code", "execution_count": 19, "source": [ "%matplotlib notebook \r\n", "# lets visualize our distributions\r\n", "\r\n", "def client_function_draw(eval_sample, cmap = matplotlib.cm.Greens):\r\n", " # adjust some 3D settings\r\n", " fig = plt.figure()\r\n", " ax = fig.gca(projection=\"3d\")\r\n", " ax.view_init(elev=20.0, azim=-125)\r\n", " ax.dist = 7.5\r\n", " ax.set_zticks(np.linspace(0, 0.3, 7))\r\n", " # labels\r\n", " ax.set_xlabel(r\"${\\theta}_0$\")\r\n", " ax.set_ylabel(r\"${\\theta}_1$\")\r\n", " ax.set_zlabel(r\"J(${\\theta}_0$, ${\\theta}_1$)\")\r\n", "\r\n", " # plot a 3D curve\r\n", " ax.plot_wireframe(\r\n", " SAMPLE_GRID_VECTOR[:,:,0],\r\n", " SAMPLE_GRID_VECTOR[:,:,1],\r\n", " eval_sample,\r\n", " rstride=3,\r\n", " cstride=3,\r\n", " linewidth=0.5,\r\n", " antialiased=True,\r\n", " cmap=matplotlib.cm.Greens,\r\n", " zorder=1,\r\n", " )\r\n", " # make a 2D contour plot in cmap color on the bottom\r\n", " ax.contourf(\r\n", " SAMPLE_GRID_VECTOR[:,:,0],\r\n", " SAMPLE_GRID_VECTOR[:,:,1],\r\n", " eval_sample,\r\n", " zdir=\"z\",\r\n", " offset=-np.max(eval_sample)/20,\r\n", " cmap=cmap,\r\n", " zorder=0,\r\n", " )\r\n", " return ax\r\n", "\r\n", "print(\"**client 1: **\")\r\n", "fig = client_function_draw(J1)\r\n", "print(\"**client 2: **\")\r\n", "fig = client_function_draw(J2, cmap=\"Oranges\")" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "**client 1: **\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "application/javascript": "/* Put everything inside the global mpl namespace */\nwindow.mpl = {};\n\n\nmpl.get_websocket_type = function() {\n if (typeof(WebSocket) !== 'undefined') {\n return WebSocket;\n } else if (typeof(MozWebSocket) !== 'undefined') {\n return MozWebSocket;\n } else {\n alert('Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.');\n };\n}\n\nmpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = (this.ws.binaryType != undefined);\n\n if (!this.supports_binary) {\n var warnings = document.getElementById(\"mpl-warnings\");\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent = (\n \"This browser does not support binary websocket messages. \" +\n \"Performance may be slow.\");\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = $('
');\n this._root_extra_style(this.root)\n this.root.attr('style', 'display: inline-block');\n\n $(parent_element).append(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n fig.send_message(\"send_image_mode\", {});\n if (mpl.ratio != 1) {\n fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n }\n fig.send_message(\"refresh\", {});\n }\n\n this.imageObj.onload = function() {\n if (fig.image_mode == 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function() {\n fig.ws.close();\n }\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n}\n\nmpl.figure.prototype._init_header = function() {\n var titlebar = $(\n '
');\n var titletext = $(\n '
');\n titlebar.append(titletext)\n this.root.append(titlebar);\n this.header = titletext[0];\n}\n\n\n\nmpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n\n}\n\n\nmpl.figure.prototype._root_extra_style = function(canvas_div) {\n\n}\n\nmpl.figure.prototype._init_canvas = function() {\n var fig = this;\n\n var canvas_div = $('
');\n\n canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n\n function canvas_keyboard_event(event) {\n return fig.key_event(event, event['data']);\n }\n\n canvas_div.keydown('key_press', canvas_keyboard_event);\n canvas_div.keyup('key_release', canvas_keyboard_event);\n this.canvas_div = canvas_div\n this._canvas_extra_style(canvas_div)\n this.root.append(canvas_div);\n\n var canvas = $('');\n canvas.addClass('mpl-canvas');\n canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n\n this.canvas = canvas[0];\n this.context = canvas[0].getContext(\"2d\");\n\n var backingStore = this.context.backingStorePixelRatio ||\n\tthis.context.webkitBackingStorePixelRatio ||\n\tthis.context.mozBackingStorePixelRatio ||\n\tthis.context.msBackingStorePixelRatio ||\n\tthis.context.oBackingStorePixelRatio ||\n\tthis.context.backingStorePixelRatio || 1;\n\n mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband = $('');\n rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n\n var pass_mouse_events = true;\n\n canvas_div.resizable({\n start: function(event, ui) {\n pass_mouse_events = false;\n },\n resize: function(event, ui) {\n fig.request_resize(ui.size.width, ui.size.height);\n },\n stop: function(event, ui) {\n pass_mouse_events = true;\n fig.request_resize(ui.size.width, ui.size.height);\n },\n });\n\n function mouse_event_fn(event) {\n if (pass_mouse_events)\n return fig.mouse_event(event, event['data']);\n }\n\n rubberband.mousedown('button_press', mouse_event_fn);\n rubberband.mouseup('button_release', mouse_event_fn);\n // Throttle sequential mouse events to 1 every 20ms.\n rubberband.mousemove('motion_notify', mouse_event_fn);\n\n rubberband.mouseenter('figure_enter', mouse_event_fn);\n rubberband.mouseleave('figure_leave', mouse_event_fn);\n\n canvas_div.on(\"wheel\", function (event) {\n event = event.originalEvent;\n event['data'] = 'scroll'\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n mouse_event_fn(event);\n });\n\n canvas_div.append(canvas);\n canvas_div.append(rubberband);\n\n this.rubberband = rubberband;\n this.rubberband_canvas = rubberband[0];\n this.rubberband_context = rubberband[0].getContext(\"2d\");\n this.rubberband_context.strokeStyle = \"#000000\";\n\n this._resize_canvas = function(width, height) {\n // Keep the size of the canvas, canvas container, and rubber band\n // canvas in synch.\n canvas_div.css('width', width)\n canvas_div.css('height', height)\n\n canvas.attr('width', width * mpl.ratio);\n canvas.attr('height', height * mpl.ratio);\n canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n\n rubberband.attr('width', width);\n rubberband.attr('height', height);\n }\n\n // Set the figure to an initial 600x600px, this will subsequently be updated\n // upon first draw.\n this._resize_canvas(600, 600);\n\n // Disable right mouse context menu.\n $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n return false;\n });\n\n function set_focus () {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n}\n\nmpl.figure.prototype._init_toolbar = function() {\n var fig = this;\n\n var nav_element = $('
');\n nav_element.attr('style', 'width: 100%');\n this.root.append(nav_element);\n\n // Define a callback function for later on.\n function toolbar_event(event) {\n return fig.toolbar_button_onclick(event['data']);\n }\n function toolbar_mouse_event(event) {\n return fig.toolbar_button_onmouseover(event['data']);\n }\n\n for(var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n // put a spacer in here.\n continue;\n }\n var button = $('');\n button.click(method_name, toolbar_event);\n button.mouseover(tooltip, toolbar_mouse_event);\n nav_element.append(button);\n }\n\n // Add the status bar.\n var status_bar = $('');\n nav_element.append(status_bar);\n this.message = status_bar[0];\n\n // Add the close button to the window.\n var buttongrp = $('
');\n var button = $('');\n button.click(function (evt) { fig.handle_close(fig, {}); } );\n button.mouseover('Stop Interaction', toolbar_mouse_event);\n buttongrp.append(button);\n var titlebar = this.root.find($('.ui-dialog-titlebar'));\n titlebar.prepend(buttongrp);\n}\n\nmpl.figure.prototype._root_extra_style = function(el){\n var fig = this\n el.on(\"remove\", function(){\n\tfig.close_ws(fig, {});\n });\n}\n\nmpl.figure.prototype._canvas_extra_style = function(el){\n // this is important to make the div 'focusable\n el.attr('tabindex', 0)\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n }\n else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n\n}\n\nmpl.figure.prototype._key_event_extra = function(event, name) {\n var manager = IPython.notebook.keyboard_manager;\n if (!manager)\n manager = IPython.keyboard_manager;\n\n // Check for shift+enter\n if (event.shiftKey && event.which == 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n}\n\nmpl.figure.prototype.handle_save = function(fig, msg) {\n fig.ondownload(fig, null);\n}\n\n\nmpl.find_output_cell = function(html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i=0; i= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] == html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n}\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel != null) {\n IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n}\n" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": {} } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# **Part 2**. Apply gradient descent for one client \n", "Second step: Awesome, we have the parameter spaces for our two image recognition models. Let's do some training, every client simply by himself, with its own local data. If you don't know how gradient descent is implemented here, a [quick detour to 2.2.](#ref2.2) might help! :)\n", "\n", "## 2.1 Right on track: gradient descent ftw! \n", "Let's see how this would look like, with Gradient Descent applied. " ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "# lets start with some random values for theta. come back later, and tune them, if you like.\n", "THETA = np.array([-2.0, 0.1])" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 20, "source": [ "%matplotlib notebook \n", "\n", "# apply gradient descent\n", "history_client1 = grad_descent(\n", " function_descent=multivariate1.evaluate,\n", " theta=THETA,\n", " eps=1e-07, # stop condition\n", " nb_max_iter=1000, # max iterations,\n", " verbose=1,\n", ")\n", "ax = client_function_draw(J1)\n", "\n", "# plot the gradient descent\n", "ax.scatter(\n", " history_client1[:, 0][::10],\n", " history_client1[:, 1][::10],\n", " history_client1[:, 2][::10],\n", " label=\"GD descent\",\n", " c=\"red\",\n", " lw=5,\n", " zorder=100,\n", ")\n", "fig = ax.legend()" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "iter: 0, theta: [-1.9, gradient: 3.313123563351253e-05\n", "iter: 50, theta: [-1.6, gradient: 0.00018967341849882979\n", "iter: 100, theta: [-0.9, gradient: 0.0025225717318137636\n", "iter: 150, theta: [-0.2, gradient: 0.00034660933191360543\n", "iter: 200, theta: [-0.0, gradient: 1.3755523300240657e-05\n", "iter: 250, theta: [-0.0, gradient: 3.125739587883647e-06\n", "iter: 300, theta: [-0.0, gradient: 8.610496716188187e-07\n", "iter: 350, theta: [-0.0, gradient: 2.3839203919240326e-07\n", "stop condition reached after 384 iterations\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "application/javascript": "/* Put everything inside the global mpl namespace */\nwindow.mpl = {};\n\n\nmpl.get_websocket_type = function() {\n if (typeof(WebSocket) !== 'undefined') {\n return WebSocket;\n } else if (typeof(MozWebSocket) !== 'undefined') {\n return MozWebSocket;\n } else {\n alert('Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.');\n };\n}\n\nmpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = (this.ws.binaryType != undefined);\n\n if (!this.supports_binary) {\n var warnings = document.getElementById(\"mpl-warnings\");\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent = (\n \"This browser does not support binary websocket messages. \" +\n \"Performance may be slow.\");\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = $('
');\n this._root_extra_style(this.root)\n this.root.attr('style', 'display: inline-block');\n\n $(parent_element).append(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n fig.send_message(\"send_image_mode\", {});\n if (mpl.ratio != 1) {\n fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n }\n fig.send_message(\"refresh\", {});\n }\n\n this.imageObj.onload = function() {\n if (fig.image_mode == 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function() {\n fig.ws.close();\n }\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n}\n\nmpl.figure.prototype._init_header = function() {\n var titlebar = $(\n '
');\n var titletext = $(\n '
');\n titlebar.append(titletext)\n this.root.append(titlebar);\n this.header = titletext[0];\n}\n\n\n\nmpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n\n}\n\n\nmpl.figure.prototype._root_extra_style = function(canvas_div) {\n\n}\n\nmpl.figure.prototype._init_canvas = function() {\n var fig = this;\n\n var canvas_div = $('
');\n\n canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n\n function canvas_keyboard_event(event) {\n return fig.key_event(event, event['data']);\n }\n\n canvas_div.keydown('key_press', canvas_keyboard_event);\n canvas_div.keyup('key_release', canvas_keyboard_event);\n this.canvas_div = canvas_div\n this._canvas_extra_style(canvas_div)\n this.root.append(canvas_div);\n\n var canvas = $('');\n canvas.addClass('mpl-canvas');\n canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n\n this.canvas = canvas[0];\n this.context = canvas[0].getContext(\"2d\");\n\n var backingStore = this.context.backingStorePixelRatio ||\n\tthis.context.webkitBackingStorePixelRatio ||\n\tthis.context.mozBackingStorePixelRatio ||\n\tthis.context.msBackingStorePixelRatio ||\n\tthis.context.oBackingStorePixelRatio ||\n\tthis.context.backingStorePixelRatio || 1;\n\n mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband = $('');\n rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n\n var pass_mouse_events = true;\n\n canvas_div.resizable({\n start: function(event, ui) {\n pass_mouse_events = false;\n },\n resize: function(event, ui) {\n fig.request_resize(ui.size.width, ui.size.height);\n },\n stop: function(event, ui) {\n pass_mouse_events = true;\n fig.request_resize(ui.size.width, ui.size.height);\n },\n });\n\n function mouse_event_fn(event) {\n if (pass_mouse_events)\n return fig.mouse_event(event, event['data']);\n }\n\n rubberband.mousedown('button_press', mouse_event_fn);\n rubberband.mouseup('button_release', mouse_event_fn);\n // Throttle sequential mouse events to 1 every 20ms.\n rubberband.mousemove('motion_notify', mouse_event_fn);\n\n rubberband.mouseenter('figure_enter', mouse_event_fn);\n rubberband.mouseleave('figure_leave', mouse_event_fn);\n\n canvas_div.on(\"wheel\", function (event) {\n event = event.originalEvent;\n event['data'] = 'scroll'\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n mouse_event_fn(event);\n });\n\n canvas_div.append(canvas);\n canvas_div.append(rubberband);\n\n this.rubberband = rubberband;\n this.rubberband_canvas = rubberband[0];\n this.rubberband_context = rubberband[0].getContext(\"2d\");\n this.rubberband_context.strokeStyle = \"#000000\";\n\n this._resize_canvas = function(width, height) {\n // Keep the size of the canvas, canvas container, and rubber band\n // canvas in synch.\n canvas_div.css('width', width)\n canvas_div.css('height', height)\n\n canvas.attr('width', width * mpl.ratio);\n canvas.attr('height', height * mpl.ratio);\n canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n\n rubberband.attr('width', width);\n rubberband.attr('height', height);\n }\n\n // Set the figure to an initial 600x600px, this will subsequently be updated\n // upon first draw.\n this._resize_canvas(600, 600);\n\n // Disable right mouse context menu.\n $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n return false;\n });\n\n function set_focus () {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n}\n\nmpl.figure.prototype._init_toolbar = function() {\n var fig = this;\n\n var nav_element = $('
');\n nav_element.attr('style', 'width: 100%');\n this.root.append(nav_element);\n\n // Define a callback function for later on.\n function toolbar_event(event) {\n return fig.toolbar_button_onclick(event['data']);\n }\n function toolbar_mouse_event(event) {\n return fig.toolbar_button_onmouseover(event['data']);\n }\n\n for(var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n // put a spacer in here.\n continue;\n }\n var button = $('');\n button.click(method_name, toolbar_event);\n button.mouseover(tooltip, toolbar_mouse_event);\n nav_element.append(button);\n }\n\n // Add the status bar.\n var status_bar = $('');\n nav_element.append(status_bar);\n this.message = status_bar[0];\n\n // Add the close button to the window.\n var buttongrp = $('
');\n var button = $('');\n button.click(function (evt) { fig.handle_close(fig, {}); } );\n button.mouseover('Stop Interaction', toolbar_mouse_event);\n buttongrp.append(button);\n var titlebar = this.root.find($('.ui-dialog-titlebar'));\n titlebar.prepend(buttongrp);\n}\n\nmpl.figure.prototype._root_extra_style = function(el){\n var fig = this\n el.on(\"remove\", function(){\n\tfig.close_ws(fig, {});\n });\n}\n\nmpl.figure.prototype._canvas_extra_style = function(el){\n // this is important to make the div 'focusable\n el.attr('tabindex', 0)\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n }\n else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n\n}\n\nmpl.figure.prototype._key_event_extra = function(event, name) {\n var manager = IPython.notebook.keyboard_manager;\n if (!manager)\n manager = IPython.keyboard_manager;\n\n // Check for shift+enter\n if (event.shiftKey && event.which == 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n}\n\nmpl.figure.prototype.handle_save = function(fig, msg) {\n fig.ondownload(fig, null);\n}\n\n\nmpl.find_output_cell = function(html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i=0; i= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] == html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n}\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel != null) {\n IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n}\n" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "**client 2**\n", "create 3D plot using SGD on [('GD client 2', 'orange')] over distribution [('client 2 target', 'Oranges')]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "application/javascript": "/* Put everything inside the global mpl namespace */\nwindow.mpl = {};\n\n\nmpl.get_websocket_type = function() {\n if (typeof(WebSocket) !== 'undefined') {\n return WebSocket;\n } else if (typeof(MozWebSocket) !== 'undefined') {\n return MozWebSocket;\n } else {\n alert('Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.');\n };\n}\n\nmpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = (this.ws.binaryType != undefined);\n\n if (!this.supports_binary) {\n var warnings = document.getElementById(\"mpl-warnings\");\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent = (\n \"This browser does not support binary websocket messages. \" +\n \"Performance may be slow.\");\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = $('
');\n this._root_extra_style(this.root)\n this.root.attr('style', 'display: inline-block');\n\n $(parent_element).append(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n fig.send_message(\"send_image_mode\", {});\n if (mpl.ratio != 1) {\n fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n }\n fig.send_message(\"refresh\", {});\n }\n\n this.imageObj.onload = function() {\n if (fig.image_mode == 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function() {\n fig.ws.close();\n }\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n}\n\nmpl.figure.prototype._init_header = function() {\n var titlebar = $(\n '
');\n var titletext = $(\n '
');\n titlebar.append(titletext)\n this.root.append(titlebar);\n this.header = titletext[0];\n}\n\n\n\nmpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n\n}\n\n\nmpl.figure.prototype._root_extra_style = function(canvas_div) {\n\n}\n\nmpl.figure.prototype._init_canvas = function() {\n var fig = this;\n\n var canvas_div = $('
');\n\n canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n\n function canvas_keyboard_event(event) {\n return fig.key_event(event, event['data']);\n }\n\n canvas_div.keydown('key_press', canvas_keyboard_event);\n canvas_div.keyup('key_release', canvas_keyboard_event);\n this.canvas_div = canvas_div\n this._canvas_extra_style(canvas_div)\n this.root.append(canvas_div);\n\n var canvas = $('');\n canvas.addClass('mpl-canvas');\n canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n\n this.canvas = canvas[0];\n this.context = canvas[0].getContext(\"2d\");\n\n var backingStore = this.context.backingStorePixelRatio ||\n\tthis.context.webkitBackingStorePixelRatio ||\n\tthis.context.mozBackingStorePixelRatio ||\n\tthis.context.msBackingStorePixelRatio ||\n\tthis.context.oBackingStorePixelRatio ||\n\tthis.context.backingStorePixelRatio || 1;\n\n mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband = $('');\n rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n\n var pass_mouse_events = true;\n\n canvas_div.resizable({\n start: function(event, ui) {\n pass_mouse_events = false;\n },\n resize: function(event, ui) {\n fig.request_resize(ui.size.width, ui.size.height);\n },\n stop: function(event, ui) {\n pass_mouse_events = true;\n fig.request_resize(ui.size.width, ui.size.height);\n },\n });\n\n function mouse_event_fn(event) {\n if (pass_mouse_events)\n return fig.mouse_event(event, event['data']);\n }\n\n rubberband.mousedown('button_press', mouse_event_fn);\n rubberband.mouseup('button_release', mouse_event_fn);\n // Throttle sequential mouse events to 1 every 20ms.\n rubberband.mousemove('motion_notify', mouse_event_fn);\n\n rubberband.mouseenter('figure_enter', mouse_event_fn);\n rubberband.mouseleave('figure_leave', mouse_event_fn);\n\n canvas_div.on(\"wheel\", function (event) {\n event = event.originalEvent;\n event['data'] = 'scroll'\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n mouse_event_fn(event);\n });\n\n canvas_div.append(canvas);\n canvas_div.append(rubberband);\n\n this.rubberband = rubberband;\n this.rubberband_canvas = rubberband[0];\n this.rubberband_context = rubberband[0].getContext(\"2d\");\n this.rubberband_context.strokeStyle = \"#000000\";\n\n this._resize_canvas = function(width, height) {\n // Keep the size of the canvas, canvas container, and rubber band\n // canvas in synch.\n canvas_div.css('width', width)\n canvas_div.css('height', height)\n\n canvas.attr('width', width * mpl.ratio);\n canvas.attr('height', height * mpl.ratio);\n canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n\n rubberband.attr('width', width);\n rubberband.attr('height', height);\n }\n\n // Set the figure to an initial 600x600px, this will subsequently be updated\n // upon first draw.\n this._resize_canvas(600, 600);\n\n // Disable right mouse context menu.\n $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n return false;\n });\n\n function set_focus () {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n}\n\nmpl.figure.prototype._init_toolbar = function() {\n var fig = this;\n\n var nav_element = $('
');\n nav_element.attr('style', 'width: 100%');\n this.root.append(nav_element);\n\n // Define a callback function for later on.\n function toolbar_event(event) {\n return fig.toolbar_button_onclick(event['data']);\n }\n function toolbar_mouse_event(event) {\n return fig.toolbar_button_onmouseover(event['data']);\n }\n\n for(var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n // put a spacer in here.\n continue;\n }\n var button = $('');\n button.click(method_name, toolbar_event);\n button.mouseover(tooltip, toolbar_mouse_event);\n nav_element.append(button);\n }\n\n // Add the status bar.\n var status_bar = $('');\n nav_element.append(status_bar);\n this.message = status_bar[0];\n\n // Add the close button to the window.\n var buttongrp = $('
');\n var button = $('');\n button.click(function (evt) { fig.handle_close(fig, {}); } );\n button.mouseover('Stop Interaction', toolbar_mouse_event);\n buttongrp.append(button);\n var titlebar = this.root.find($('.ui-dialog-titlebar'));\n titlebar.prepend(buttongrp);\n}\n\nmpl.figure.prototype._root_extra_style = function(el){\n var fig = this\n el.on(\"remove\", function(){\n\tfig.close_ws(fig, {});\n });\n}\n\nmpl.figure.prototype._canvas_extra_style = function(el){\n // this is important to make the div 'focusable\n el.attr('tabindex', 0)\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n }\n else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n\n}\n\nmpl.figure.prototype._key_event_extra = function(event, name) {\n var manager = IPython.notebook.keyboard_manager;\n if (!manager)\n manager = IPython.keyboard_manager;\n\n // Check for shift+enter\n if (event.shiftKey && event.which == 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n}\n\nmpl.figure.prototype.handle_save = function(fig, msg) {\n fig.ondownload(fig, null);\n}\n\n\nmpl.find_output_cell = function(html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i=0; i= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] == html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n}\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel != null) {\n IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n}\n" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": {} } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "By having the $\\Sigma_{global}$ and $\\mu_{global}$, we can reconstruct the global posterior, finding the same argmax as with the numerical calculations.\n", "\n", "Once we have the posterior, we can run even GD on it, or directly go to the argmax at $\\mu$ of the f_posterior_exact:" ], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "#lets do, as if we could aggregate our global function\n", "fig = mpl_multivariate_3d_gd(\n", " name_labels=[\"GD on global posterior\", \"GD client 2\", \"GD client 1\"],\n", " colors=[\"purple\", \"orange\", \"green\"],\n", " functions=[f_posterior_exact.evaluate, multivariate1.evaluate, multivariate2.evaluate],\n", " target_function=[ J1, J2, J_exact],\n", " cmap_target=[ \"Greens\", \"Oranges\", \"Purples\"],\n", " label_target=[ \"client 1 target\", \"client 2 target\", \"global posterior\"],\n", " theta=THETA,\n", " )" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "create 3D plot using SGD on [('GD on global posterior', 'purple'), ('GD client 2', 'orange'), ('GD client 1', 'green')] over distribution [('client 1 target', 'Greens'), ('client 2 target', 'Oranges'), ('global posterior', 'Purples')]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "application/javascript": "/* Put everything inside the global mpl namespace */\nwindow.mpl = {};\n\n\nmpl.get_websocket_type = function() {\n if (typeof(WebSocket) !== 'undefined') {\n return WebSocket;\n } else if (typeof(MozWebSocket) !== 'undefined') {\n return MozWebSocket;\n } else {\n alert('Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.');\n };\n}\n\nmpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = (this.ws.binaryType != undefined);\n\n if (!this.supports_binary) {\n var warnings = document.getElementById(\"mpl-warnings\");\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent = (\n \"This browser does not support binary websocket messages. \" +\n \"Performance may be slow.\");\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = $('
');\n this._root_extra_style(this.root)\n this.root.attr('style', 'display: inline-block');\n\n $(parent_element).append(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n fig.send_message(\"send_image_mode\", {});\n if (mpl.ratio != 1) {\n fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n }\n fig.send_message(\"refresh\", {});\n }\n\n this.imageObj.onload = function() {\n if (fig.image_mode == 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function() {\n fig.ws.close();\n }\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n}\n\nmpl.figure.prototype._init_header = function() {\n var titlebar = $(\n '
');\n var titletext = $(\n '
');\n titlebar.append(titletext)\n this.root.append(titlebar);\n this.header = titletext[0];\n}\n\n\n\nmpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n\n}\n\n\nmpl.figure.prototype._root_extra_style = function(canvas_div) {\n\n}\n\nmpl.figure.prototype._init_canvas = function() {\n var fig = this;\n\n var canvas_div = $('
');\n\n canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n\n function canvas_keyboard_event(event) {\n return fig.key_event(event, event['data']);\n }\n\n canvas_div.keydown('key_press', canvas_keyboard_event);\n canvas_div.keyup('key_release', canvas_keyboard_event);\n this.canvas_div = canvas_div\n this._canvas_extra_style(canvas_div)\n this.root.append(canvas_div);\n\n var canvas = $('');\n canvas.addClass('mpl-canvas');\n canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n\n this.canvas = canvas[0];\n this.context = canvas[0].getContext(\"2d\");\n\n var backingStore = this.context.backingStorePixelRatio ||\n\tthis.context.webkitBackingStorePixelRatio ||\n\tthis.context.mozBackingStorePixelRatio ||\n\tthis.context.msBackingStorePixelRatio ||\n\tthis.context.oBackingStorePixelRatio ||\n\tthis.context.backingStorePixelRatio || 1;\n\n mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband = $('');\n rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n\n var pass_mouse_events = true;\n\n canvas_div.resizable({\n start: function(event, ui) {\n pass_mouse_events = false;\n },\n resize: function(event, ui) {\n fig.request_resize(ui.size.width, ui.size.height);\n },\n stop: function(event, ui) {\n pass_mouse_events = true;\n fig.request_resize(ui.size.width, ui.size.height);\n },\n });\n\n function mouse_event_fn(event) {\n if (pass_mouse_events)\n return fig.mouse_event(event, event['data']);\n }\n\n rubberband.mousedown('button_press', mouse_event_fn);\n rubberband.mouseup('button_release', mouse_event_fn);\n // Throttle sequential mouse events to 1 every 20ms.\n rubberband.mousemove('motion_notify', mouse_event_fn);\n\n rubberband.mouseenter('figure_enter', mouse_event_fn);\n rubberband.mouseleave('figure_leave', mouse_event_fn);\n\n canvas_div.on(\"wheel\", function (event) {\n event = event.originalEvent;\n event['data'] = 'scroll'\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n mouse_event_fn(event);\n });\n\n canvas_div.append(canvas);\n canvas_div.append(rubberband);\n\n this.rubberband = rubberband;\n this.rubberband_canvas = rubberband[0];\n this.rubberband_context = rubberband[0].getContext(\"2d\");\n this.rubberband_context.strokeStyle = \"#000000\";\n\n this._resize_canvas = function(width, height) {\n // Keep the size of the canvas, canvas container, and rubber band\n // canvas in synch.\n canvas_div.css('width', width)\n canvas_div.css('height', height)\n\n canvas.attr('width', width * mpl.ratio);\n canvas.attr('height', height * mpl.ratio);\n canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n\n rubberband.attr('width', width);\n rubberband.attr('height', height);\n }\n\n // Set the figure to an initial 600x600px, this will subsequently be updated\n // upon first draw.\n this._resize_canvas(600, 600);\n\n // Disable right mouse context menu.\n $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n return false;\n });\n\n function set_focus () {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n}\n\nmpl.figure.prototype._init_toolbar = function() {\n var fig = this;\n\n var nav_element = $('
');\n nav_element.attr('style', 'width: 100%');\n this.root.append(nav_element);\n\n // Define a callback function for later on.\n function toolbar_event(event) {\n return fig.toolbar_button_onclick(event['data']);\n }\n function toolbar_mouse_event(event) {\n return fig.toolbar_button_onmouseover(event['data']);\n }\n\n for(var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n // put a spacer in here.\n continue;\n }\n var button = $('');\n button.click(method_name, toolbar_event);\n button.mouseover(tooltip, toolbar_mouse_event);\n nav_element.append(button);\n }\n\n // Add the status bar.\n var status_bar = $('');\n nav_element.append(status_bar);\n this.message = status_bar[0];\n\n // Add the close button to the window.\n var buttongrp = $('
');\n var button = $('');\n button.click(function (evt) { fig.handle_close(fig, {}); } );\n button.mouseover('Stop Interaction', toolbar_mouse_event);\n buttongrp.append(button);\n var titlebar = this.root.find($('.ui-dialog-titlebar'));\n titlebar.prepend(buttongrp);\n}\n\nmpl.figure.prototype._root_extra_style = function(el){\n var fig = this\n el.on(\"remove\", function(){\n\tfig.close_ws(fig, {});\n });\n}\n\nmpl.figure.prototype._canvas_extra_style = function(el){\n // this is important to make the div 'focusable\n el.attr('tabindex', 0)\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n }\n else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n\n}\n\nmpl.figure.prototype._key_event_extra = function(event, name) {\n var manager = IPython.notebook.keyboard_manager;\n if (!manager)\n manager = IPython.keyboard_manager;\n\n // Check for shift+enter\n if (event.shiftKey && event.which == 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n}\n\nmpl.figure.prototype.handle_save = function(fig, msg) {\n fig.ondownload(fig, null);\n}\n\n\nmpl.find_output_cell = function(html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i=0; i= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] == html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n}\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel != null) {\n IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n}\n" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "ϴ_final: [ 0.64033788 -0.31200206] \n", "ϴ_optimum: [ 0.67346939 -0.2433281 ]\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "__Congratiulations! You made it. FedAvg reaches the global optimum.__\n", "Through iteratively averaging our client updates, we trained one global model, which is now able to classify the images on client1 and client2 data optimally. Despite it is not the posterior modes (maximum) of any of the local distributions, we still found the best solution according to our a global cost function. \n", "__But.. is that really always true?__ Does FedAvg always find the optimum of the global cost function, when we have non-iid client distributions?\n", "\n", "## 3.3 How to mess with Federated Averaging \n", "So far, FedAvg has been pretty successful. But remember how our client distributions have an **heterogenous distribution** (non-iid)? \n", "\n", "Below, we will see, that FedAvg is not guaranteed to converge to the global optimum. Despite of the fact that our problem is very basic and a very minimal implementation of GD finds the optimum, FedAvg fails here. This happens, as we try to improve our **communication efficiency** by increasing our step size in the training:\n", "\n", "```python\n", "communication_rounds = 20\n", "gd_steps_local = 120\n", "```\n", "\n", "In the code below, observe how we fail to find the optimal solution using FedAvg. " ], "metadata": {} }, { "cell_type": "code", "execution_count": 16, "source": [ "server_history , _ = fedavg(\n", " function_clients=[multivariate1.evaluate, multivariate2.evaluate], # clients\n", " function_eval=f_posterior_exact.evaluate, \n", " communication_rounds = 20,\n", " gd_steps_local = 120,\n", " theta=THETA,\n", ")\n", "\n", "ax = client_function_draw(J_exact, cmap=\"Purples\") # plot global function\n", "\n", "ax.scatter(\n", " server_history[:, 0],\n", " server_history[:, 1],\n", " server_history[:, 2],\n", " label=\"FedAvg\",\n", " c=\"red\",\n", " lw=5,\n", " zorder=100,\n", ")\n", "fig = ax.legend()\n", "print(\"\\u03F4_final: \",server_history[-1,:-1], \"\\n\\u03F4_optimum:\", mu_global_posterior)" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "application/javascript": "/* Put everything inside the global mpl namespace */\nwindow.mpl = {};\n\n\nmpl.get_websocket_type = function() {\n if (typeof(WebSocket) !== 'undefined') {\n return WebSocket;\n } else if (typeof(MozWebSocket) !== 'undefined') {\n return MozWebSocket;\n } else {\n alert('Your browser does not have WebSocket support. ' +\n 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n 'Firefox 4 and 5 are also supported but you ' +\n 'have to enable WebSockets in about:config.');\n };\n}\n\nmpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n this.id = figure_id;\n\n this.ws = websocket;\n\n this.supports_binary = (this.ws.binaryType != undefined);\n\n if (!this.supports_binary) {\n var warnings = document.getElementById(\"mpl-warnings\");\n if (warnings) {\n warnings.style.display = 'block';\n warnings.textContent = (\n \"This browser does not support binary websocket messages. \" +\n \"Performance may be slow.\");\n }\n }\n\n this.imageObj = new Image();\n\n this.context = undefined;\n this.message = undefined;\n this.canvas = undefined;\n this.rubberband_canvas = undefined;\n this.rubberband_context = undefined;\n this.format_dropdown = undefined;\n\n this.image_mode = 'full';\n\n this.root = $('
');\n this._root_extra_style(this.root)\n this.root.attr('style', 'display: inline-block');\n\n $(parent_element).append(this.root);\n\n this._init_header(this);\n this._init_canvas(this);\n this._init_toolbar(this);\n\n var fig = this;\n\n this.waiting = false;\n\n this.ws.onopen = function () {\n fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n fig.send_message(\"send_image_mode\", {});\n if (mpl.ratio != 1) {\n fig.send_message(\"set_dpi_ratio\", {'dpi_ratio': mpl.ratio});\n }\n fig.send_message(\"refresh\", {});\n }\n\n this.imageObj.onload = function() {\n if (fig.image_mode == 'full') {\n // Full images could contain transparency (where diff images\n // almost always do), so we need to clear the canvas so that\n // there is no ghosting.\n fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n }\n fig.context.drawImage(fig.imageObj, 0, 0);\n };\n\n this.imageObj.onunload = function() {\n fig.ws.close();\n }\n\n this.ws.onmessage = this._make_on_message_function(this);\n\n this.ondownload = ondownload;\n}\n\nmpl.figure.prototype._init_header = function() {\n var titlebar = $(\n '
');\n var titletext = $(\n '
');\n titlebar.append(titletext)\n this.root.append(titlebar);\n this.header = titletext[0];\n}\n\n\n\nmpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n\n}\n\n\nmpl.figure.prototype._root_extra_style = function(canvas_div) {\n\n}\n\nmpl.figure.prototype._init_canvas = function() {\n var fig = this;\n\n var canvas_div = $('
');\n\n canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n\n function canvas_keyboard_event(event) {\n return fig.key_event(event, event['data']);\n }\n\n canvas_div.keydown('key_press', canvas_keyboard_event);\n canvas_div.keyup('key_release', canvas_keyboard_event);\n this.canvas_div = canvas_div\n this._canvas_extra_style(canvas_div)\n this.root.append(canvas_div);\n\n var canvas = $('');\n canvas.addClass('mpl-canvas');\n canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n\n this.canvas = canvas[0];\n this.context = canvas[0].getContext(\"2d\");\n\n var backingStore = this.context.backingStorePixelRatio ||\n\tthis.context.webkitBackingStorePixelRatio ||\n\tthis.context.mozBackingStorePixelRatio ||\n\tthis.context.msBackingStorePixelRatio ||\n\tthis.context.oBackingStorePixelRatio ||\n\tthis.context.backingStorePixelRatio || 1;\n\n mpl.ratio = (window.devicePixelRatio || 1) / backingStore;\n\n var rubberband = $('');\n rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n\n var pass_mouse_events = true;\n\n canvas_div.resizable({\n start: function(event, ui) {\n pass_mouse_events = false;\n },\n resize: function(event, ui) {\n fig.request_resize(ui.size.width, ui.size.height);\n },\n stop: function(event, ui) {\n pass_mouse_events = true;\n fig.request_resize(ui.size.width, ui.size.height);\n },\n });\n\n function mouse_event_fn(event) {\n if (pass_mouse_events)\n return fig.mouse_event(event, event['data']);\n }\n\n rubberband.mousedown('button_press', mouse_event_fn);\n rubberband.mouseup('button_release', mouse_event_fn);\n // Throttle sequential mouse events to 1 every 20ms.\n rubberband.mousemove('motion_notify', mouse_event_fn);\n\n rubberband.mouseenter('figure_enter', mouse_event_fn);\n rubberband.mouseleave('figure_leave', mouse_event_fn);\n\n canvas_div.on(\"wheel\", function (event) {\n event = event.originalEvent;\n event['data'] = 'scroll'\n if (event.deltaY < 0) {\n event.step = 1;\n } else {\n event.step = -1;\n }\n mouse_event_fn(event);\n });\n\n canvas_div.append(canvas);\n canvas_div.append(rubberband);\n\n this.rubberband = rubberband;\n this.rubberband_canvas = rubberband[0];\n this.rubberband_context = rubberband[0].getContext(\"2d\");\n this.rubberband_context.strokeStyle = \"#000000\";\n\n this._resize_canvas = function(width, height) {\n // Keep the size of the canvas, canvas container, and rubber band\n // canvas in synch.\n canvas_div.css('width', width)\n canvas_div.css('height', height)\n\n canvas.attr('width', width * mpl.ratio);\n canvas.attr('height', height * mpl.ratio);\n canvas.attr('style', 'width: ' + width + 'px; height: ' + height + 'px;');\n\n rubberband.attr('width', width);\n rubberband.attr('height', height);\n }\n\n // Set the figure to an initial 600x600px, this will subsequently be updated\n // upon first draw.\n this._resize_canvas(600, 600);\n\n // Disable right mouse context menu.\n $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n return false;\n });\n\n function set_focus () {\n canvas.focus();\n canvas_div.focus();\n }\n\n window.setTimeout(set_focus, 100);\n}\n\nmpl.figure.prototype._init_toolbar = function() {\n var fig = this;\n\n var nav_element = $('
');\n nav_element.attr('style', 'width: 100%');\n this.root.append(nav_element);\n\n // Define a callback function for later on.\n function toolbar_event(event) {\n return fig.toolbar_button_onclick(event['data']);\n }\n function toolbar_mouse_event(event) {\n return fig.toolbar_button_onmouseover(event['data']);\n }\n\n for(var toolbar_ind in mpl.toolbar_items) {\n var name = mpl.toolbar_items[toolbar_ind][0];\n var tooltip = mpl.toolbar_items[toolbar_ind][1];\n var image = mpl.toolbar_items[toolbar_ind][2];\n var method_name = mpl.toolbar_items[toolbar_ind][3];\n\n if (!name) {\n // put a spacer in here.\n continue;\n }\n var button = $('');\n button.click(method_name, toolbar_event);\n button.mouseover(tooltip, toolbar_mouse_event);\n nav_element.append(button);\n }\n\n // Add the status bar.\n var status_bar = $('');\n nav_element.append(status_bar);\n this.message = status_bar[0];\n\n // Add the close button to the window.\n var buttongrp = $('
');\n var button = $('');\n button.click(function (evt) { fig.handle_close(fig, {}); } );\n button.mouseover('Stop Interaction', toolbar_mouse_event);\n buttongrp.append(button);\n var titlebar = this.root.find($('.ui-dialog-titlebar'));\n titlebar.prepend(buttongrp);\n}\n\nmpl.figure.prototype._root_extra_style = function(el){\n var fig = this\n el.on(\"remove\", function(){\n\tfig.close_ws(fig, {});\n });\n}\n\nmpl.figure.prototype._canvas_extra_style = function(el){\n // this is important to make the div 'focusable\n el.attr('tabindex', 0)\n // reach out to IPython and tell the keyboard manager to turn it's self\n // off when our div gets focus\n\n // location in version 3\n if (IPython.notebook.keyboard_manager) {\n IPython.notebook.keyboard_manager.register_events(el);\n }\n else {\n // location in version 2\n IPython.keyboard_manager.register_events(el);\n }\n\n}\n\nmpl.figure.prototype._key_event_extra = function(event, name) {\n var manager = IPython.notebook.keyboard_manager;\n if (!manager)\n manager = IPython.keyboard_manager;\n\n // Check for shift+enter\n if (event.shiftKey && event.which == 13) {\n this.canvas_div.blur();\n // select the cell after this one\n var index = IPython.notebook.find_cell_index(this.cell_info[0]);\n IPython.notebook.select(index + 1);\n }\n}\n\nmpl.figure.prototype.handle_save = function(fig, msg) {\n fig.ondownload(fig, null);\n}\n\n\nmpl.find_output_cell = function(html_output) {\n // Return the cell and output element which can be found *uniquely* in the notebook.\n // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n // IPython event is triggered only after the cells have been serialised, which for\n // our purposes (turning an active figure into a static one), is too late.\n var cells = IPython.notebook.get_cells();\n var ncells = cells.length;\n for (var i=0; i= 3 moved mimebundle to data attribute of output\n data = data.data;\n }\n if (data['text/html'] == html_output) {\n return [cell, data, j];\n }\n }\n }\n }\n}\n\n// Register the function which deals with the matplotlib target/channel.\n// The kernel may be null if the page has been refreshed.\nif (IPython.notebook.kernel != null) {\n IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n}\n" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "ϴ_final: [ 0.65818276 -0.20683045] \n", "ϴ_optimum: [ 0.67346939 -0.2433281 ]\n" ] } ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Finally, we made it! Our FedPA \"Light\" converged to the optimum!\r\n", "\r\n", "## 4.1 FedPA: Beyond the simplified implementation \r\n", "\r\n", "In this chapter, we already used Posterior Averaging to find our optimal $\\theta$. As you can see in [Figure2 below](#Figure2), instead of using GD on the clients, we estimated our covariance and our $\\mu$ of the clients posterior distribution, and then send back the computed $\\Delta$ .\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "\r\n", "
\r\n", " \r\n", " missing gif fedavg-and-fedpa-illustration\r\n", "
Figure2 © GIF adapted from Maruan Al-Shedivat [4]
\r\n", "
\r\n", "\r\n", "Reading the blog about FedPA from Maruan Al-Shedivat [1] was very valuable for me. I can warmly recommend to read and work through both, the [blog](https://blog.ml.cmu.edu/2021/02/19/an-inferential-perspective-on-federated-learning/) and the [code](https://github.com/alshedivat/fedpa). This also motivated me in the decision to create this Notebook Blog, instead of 'just' a blog similar the authors' original." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# **Part 5**: Summary \n", "\n", "In this blog, you have learned to:\n", "- create two non-iid distributions for your initial problem with the images. \n", "- You then derived the global posterior and used gradient descent to find the optimal paramters. \n", "- Using Federated Learning, you applied FedAvg to this data setting and understood why this only found a suboptimal $\\theta$. \n", "- Ultimatley, using some gaussian estimation of the client posterior from the clients, you applied Posterior Averaging. \n", "\n", "Thank you! If you have any kind of Feedback, [feel free to contact me!](mailto:feedback-blog-posterior-averaging@michaelfeil.eu).\n" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# **Part 6**: References \n", "\n", "[1]: [Federated Learning via Posterior Averaging: A New Perspective and Practical Algorithms](https://arxiv.org/abs/2010.05273), Al-Shedivat et al., Paper, 2021 \n", "\n", "[2]: [Communication efficient learning of deep networks from decentralized data](https://arxiv.org/abs/1602.05629), McMahan et al., Paper, 2017\n", "\n", "[3]: [OM PriCon2020: OpenMined: 'Advances and Open Problems in Federated Learning' - Peter Kairouz](https://www.youtube.com/watch?v=YnzoxibzkdI), Peter Kairouz, Presentation, 2020\n", "\n", "[4]: [An inferential perspective on federated learning](https://blog.ml.cmu.edu/2021/02/19/an-inferential-perspective-on-federated-learning/), Al-Shedivat, Blog, 2021 \n", "\n", "[5]: [The Matrix Cookbook](http://www2.compute.dtu.dk/pubdb/pubs/3274-full.html), K. B. Petersen and M. S. Pedersen, 2012 \n", " \n", "[6]: [Federated Learning via Posterior Averaging](https://github.com/alshedivat/fedpa), Al-Shedivat et al., Code, 2021 \n" ], "metadata": {} } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 2 }