{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Robust Vector Pooling\n",
"\n",
"We can extend robust pooling to work over vector arguments. Denote by $\\phi$ the penalty function. Then robust vector averaging with input $\\{x_i \\in \\mathbb{R}^m \\mid i = 1, \\ldots, n\\}$ and output $y \\in \\mathbb{R}^m$ finds the solution to the optimization problem\n",
"\n",
"$$\n",
"y \\in \\text{arg min}_u \\sum_{i=1}^{n} \\phi(\\|u - x_i\\|; \\alpha)\n",
"$$\n",
"\n",
"where $\\alpha$ is a parameter of the penalty function. We implement this operation using a `ddn.basic` node for demonstration and can be used in your own code by importing `RobustVectorAverage` from `ddn.basic.robust_nodes`. There is also an efficient version called `RobustVectorPool2d` in the `ddn.pytorch.robust_vec_pool` module for use in Pytorch applications. That version follows the mathematical derivation shown at the end of this notebook."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2021-05-08T02:04:32.849585Z",
"start_time": "2021-05-08T02:04:32.003927Z"
}
},
"outputs": [],
"source": [
"%matplotlib notebook\n",
"\n",
"import sys\n",
"sys.path.append(\"../\")\n",
"\n",
"import scipy.optimize as opt\n",
"from ddn.basic.node import *\n",
"\n",
"class RobustVectorAverage(NonUniqueDeclarativeNode):\n",
" \"\"\"\n",
" Solves for the multi-dimensional robust average,\n",
" minimize f(x, y) = \\sum_{i=1}^{n} phi(\\|y - x_i\\|; alpha)\n",
" where phi(z; alpha) is one of the following robust penalties,\n",
" 'quadratic': 1/2 z^2\n",
" 'pseudo-huber': alpha^2 (\\sqrt(1 + (z/alpha)^2 - 1)\n",
" 'huber': 1/2 z^2 for |z| <= alpha and alpha |z| - 1/2 alpha^2 otherwise\n",
" 'welsch': 1 - exp(-z^2 / 2 alpha^2)\n",
" 'trunc-quad': 1/2 z^2 for |z| <= alpha and 1/2 alpha^2 otherwise\n",
"\n",
" The input is assumed to be flattened from an (n \\times m)-matrix to an nm-vector.\n",
" \"\"\"\n",
"\n",
" restarts = 10 # number of random restarts when solving non-convex penalties\n",
"\n",
" def __init__(self, n, m, penalty='huber', alpha=1.0):\n",
" assert (alpha > 0.0)\n",
" self.alpha = alpha\n",
" self.alpha_sq = alpha ** 2\n",
" self.penalty = penalty.lower()\n",
" if (self.penalty == 'quadratic'):\n",
" self.phi = lambda z: 0.5 * np.power(z, 2.0)\n",
" elif (self.penalty == 'pseudo-huber'):\n",
" self.phi = lambda z: self.alpha_sq * (np.sqrt(1.0 + np.power(z, 2.0) / self.alpha_sq) - 1.0)\n",
" elif (self.penalty == 'huber'):\n",
" self.phi = lambda z: np.where(np.abs(z) <= alpha, 0.5 * np.power(z, 2.0), alpha * np.abs(z) - 0.5 * self.alpha_sq)\n",
" elif (self.penalty == 'welsch'):\n",
" self.phi = lambda z: 1.0 - np.exp(-0.5 * np.power(z, 2.0) / self.alpha_sq)\n",
" elif (self.penalty == 'trunc-quad'):\n",
" self.phi = lambda z: np.minimum(0.5 * np.power(z, 2.0), 0.5 * self.alpha_sq)\n",
" else:\n",
" assert False, \"unrecognized penalty function {}\".format(penalty)\n",
"\n",
" super().__init__(n*m, m) # make sure node is properly constructed\n",
" self.eps = 1.0e-4 # relax tolerance on optimality test\n",
"\n",
" def objective(self, x, y):\n",
" assert (len(x) == self.dim_x) and (len(y) == self.dim_y)\n",
" # the inclusion of 1.0e-9 prevents division by zero during automatic differentiation when a y lands exactly on a data point xi\n",
" return np.sum([self.phi(np.sqrt(np.dot(y - xi, y - xi) + 1.0e-9)) for xi in x.reshape((int(self.dim_x / self.dim_y), self.dim_y))])\n",
"\n",
" def solve(self, x):\n",
" assert(len(x) == self.dim_x)\n",
"\n",
" J = lambda y : self.objective(x, y)\n",
" dJ = lambda y : self.fY(x, y)\n",
"\n",
" y_star = np.mean(x.reshape((int(self.dim_x / self.dim_y), self.dim_y)), 0)\n",
" if (self.penalty != 'quadratic'):\n",
" result = opt.minimize(J, y_star, args=(), method='L-BFGS-B', jac=dJ, options={'maxiter': 100, 'disp': False})\n",
" if not result.success: print(result.message)\n",
" y_star, J_star = result.x, result.fun\n",
"\n",
" # run with different intial guesses for non-convex penalties\n",
" if (self.penalty == 'welsch') or (self.penalty == 'trunc-quad'):\n",
" guesses = np.random.permutation(x.reshape((int(self.dim_x / self.dim_y), self.dim_y)))\n",
" if self.dim_x > self.restarts: guesses = guesses[0:self.restarts, :]\n",
" for y_init in guesses:\n",
" result = opt.minimize(J, y_init, args=(), method='L-BFGS-B', jac=dJ, options={'maxiter': 100, 'disp': False})\n",
" if not result.success: print(result.message)\n",
" if (result.fun < J_star):\n",
" y_star, J_star = result.x, result.fun\n",
"\n",
" return y_star, None\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example\n",
"\n",
"To show an example of robust vector averaging we demonstrate updating a set of 2D points with outliers such that the (robust) average of the points moves to (0, 0). That is, we aim to solve the problem\n",
"\n",
"$$\n",
"\\begin{array}{ll}\n",
" \\text{minimize (over $x_i$)} & \\|y\\|_2^2 \\\\\n",
" \\text{subject to} & y = \\text{arg min}_u \\sum_{i=1}^{n} \\phi(\\|u - x_i\\|; \\alpha)\n",
"\\end{array}\n",
"$$\n",
"\n",
"We use three different penalty functions: `quadratic`, `pseudo-huber` and `welsch`. The animation below shows the initial position of the points and robust averages. In the three other panels are shown updates of the points at each iteration as we move the average to (0, 0). Notice that the average computed with the `quadratic` penalty function, which is not robust, moves all points including the outliers. The other two penalty functions only move inlier points; the outliers are largely unaffected. This could be useful, for example, in a neural network where we only want network parameters to be influenced by inliers during back-propagation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2021-05-08T02:04:33.039046Z",
"start_time": "2021-05-08T02:04:32.851583Z"
}
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"from ddn.basic.sample_nodes import SquaredErrorNode\n",
"from ddn.basic.composition import ComposedNode\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.animation as animation\n",
"import matplotlib.patches as patches\n",
"from IPython.display import HTML\n",
"\n",
"# set for only two random restarts\n",
"RobustVectorAverage.restarts = 2\n",
"\n",
"# setup data with outliers\n",
"np.random.seed(0)\n",
"x = 1.5 * (np.random.rand(10, 2) - 0.5)\n",
"x[-1, 0] += 6.0; x[-1, 1] -= 2.0\n",
"x[0, 0] += 3.0; x[0, 1] += 2.0\n",
"\n",
"data = [('quadratic', 'b', x.copy(), []), ('pseudo-huber', 'r', x.copy(), []), ('welsch', 'g', x.copy(), [])]\n",
"\n",
"t = np.linspace(0.0, 2.0 * np.pi)\n",
"\n",
"def plot(ax, x_init, x, data, colour=None, history=None):\n",
" y = {}\n",
" for (name, c, _, _) in data:\n",
" node = RobustVectorAverage(x.shape[0], x.shape[1], name)\n",
" y[name], _ = node.solve(x.flatten())\n",
" ax.plot(y[name][0], y[name][1], 'D', color=c, markersize=10)\n",
"\n",
" # draw original points\n",
" ax.plot(x_init[:, 0], x_init[:, 1], 'o', markeredgecolor='k', markerfacecolor='w', markeredgewidth=1.0)\n",
" if colour is not None:\n",
" ax.plot(x[:, 0], x[:, 1], 'o', markeredgecolor='k', markerfacecolor=colour, markeredgewidth=1.0)\n",
"\n",
" # draw circle\n",
" for (name, c, _, _) in data:\n",
" ax.plot(np.cos(t) + y[name][0], np.sin(t) + y[name][1], '--', color=c, linewidth=1)\n",
"\n",
" # draw learning curve\n",
" if history is not None:\n",
" ax.add_patch(patches.Rectangle((3.5, 0.5), 3.0, 2.0, fc='w', ec='k'))\n",
" if len(history) > 0:\n",
" h = 2.0 * np.array(history) / np.max(history) + 0.5\n",
" ax.plot(np.linspace(3.5, 6.5, len(history)), h, color=colour)\n",
" pass\n",
"\n",
" ax.set_xlim(-2.0, 7.0); ax.set_ylim(-3.0, 3.0)\n",
"\n",
"\n",
"def init():\n",
" plot(ax[0], x, x, data)\n",
" ax[0].legend([name for (name, c, _, h) in data] + ['original points'])\n",
" return ax\n",
"\n",
"\n",
"def animate(fnum, x_init, data):\n",
"\n",
" for i, (name, c, x, h) in enumerate(data):\n",
" ax[i+1].clear()\n",
" plot(ax[i+1], x_init, x, data, c, h)\n",
"\n",
" # gradient descent update\n",
" node = ComposedNode(RobustVectorAverage(x.shape[0], x.shape[1], name), SquaredErrorNode(x.shape[1]))\n",
" h.append(node.solve(x.flatten())[0])\n",
" dJ = node.gradient(x.flatten())\n",
" x -= 0.5 * dJ.reshape(x.shape)\n",
"\n",
" return ax[2:]\n",
"\n",
"\n",
"# create animation\n",
"fig = plt.figure()\n",
"ax = [plt.subplot(2, 2, i+1) for i in range(4)]\n",
"plt.tight_layout()\n",
"ani = animation.FuncAnimation(fig, animate, init_func=init, fargs=(x, data), interval=100, frames=50, repeat=False)\n",
"#plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2021-05-08T02:05:04.764800Z",
"start_time": "2021-05-08T02:04:33.043938Z"
}
},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# display using video or javascript\n",
"HTML(ani.to_html5_video())\n",
"#HTML(ani.to_jshtml())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mathematics\n",
"\n",
"Consider the objective function for the robust vector pooling optimization problem,\n",
"\n",
"$$\n",
"\\begin{align*}\n",
"f(\\{x_1, \\ldots, x_n\\}, u) &= \\sum_{i=1}^{n} \\phi(\\|u - x_i\\|_2; \\alpha) \\\\\n",
"&= \\sum_{i=1}^{n} \\phi(z_i; \\alpha)\n",
"\\end{align*}\n",
"$$\n",
"\n",
"where we have written $z_i = \\|u - x_i\\|_2$.\n",
"\n",
"The gradient of the minimizer $y$ with respect to each of the $x_j$ is given by\n",
"\n",
"$$\n",
"\\text{D}_{X_j} y = -H^{-1} B\n",
"$$\n",
"\n",
"where $H = \\text{D}^2_{YY} f$ and $B = \\text{D}^2_{X_jY} f$ (see Proposition 4.4 of the DDN [paper](https://arxiv.org/abs/1909.04866)). Since $f$ decomposes as a sum of penalty functions $\\phi$, it\n",
"suffices to consider $\\text{D}^2_{YY} \\phi$ and $\\text{D}^2_{XY} \\phi$.\n",
"\n",
"$$\n",
"\\begin{align*}\n",
"\\text{D}_{Y} \\phi(z_i; \\alpha) &= \\phi'(z_i; \\alpha) \\text{D}_{Y} z_i \\\\\n",
"&= \\phi'(z_i; \\alpha) \\frac{(y - x_i)^T}{z_i}\n",
"\\end{align*}\n",
"$$\n",
"\n",
"where $\\phi'$ is the first derivative of $\\phi$. So\n",
"\n",
"$$\n",
"\\begin{align*}\n",
"\\text{D}^2_{YY} \\phi(z_i; \\alpha)\n",
"&= \\frac{\\phi'(z_i; \\alpha)}{z_i} I_{m \\times m} + \\left(\\phi''(z_i; \\alpha) - \n",
" \\frac{\\phi'(z_i; \\alpha)}{z_i}\\right) \\frac{(y - x_i)(y - x_i)^T}{z_i^2}\n",
"\\\\\n",
"&= \\kappa_1(z_i) I_{m \\times m} + \\kappa_2(z_i) (y - x_i)(y - x_i)^T\n",
"\\end{align*}\n",
"$$\n",
"\n",
"where $\\kappa_1$ and $\\kappa_2$ are quantities that depend on the penalty function and $z_i$.\n",
"\n",
"By symmetry $\\text{D}^2_{X_jY} \\phi(z_j; \\alpha) = - \\text{D}^2_{YY} \\phi(z_j; \\alpha)$.\n",
"\n",
"We therefore have\n",
"\n",
"$$\n",
"\\begin{align*}\n",
"\\text{D}_{X_j} y\n",
"&= \\left( \\sum_{i=1}^{n} \\kappa_1(z_i) I_{m \\times m} + \\kappa_2(z_i) (y - x_i)(y - x_i)^T \\right)^{-1}\n",
"\\Bigg( \\kappa_1(z_j) I_{m \\times m} + \\kappa_2(z_j) (y - x_j)(y - x_j)^T\\Bigg)\n",
"\\\\\n",
"&= H^{-1} \\Bigg( \\kappa_1(z_j) I_{m \\times m} + \\kappa_2(z_j) (y - x_j)(y - x_j)^T\\Bigg)\n",
"\\end{align*}\n",
"$$\n",
"\n",
"## Implementation\n",
"\n",
"Let $v^T = \\text{D} J(y)$ be the derivative of the loss function with respect to the output, i.e., the backward going gradient. Our goal is to compute $\\text{D} J(x_i)$ for $i = 1, \\ldots, n$. We have,\n",
"\n",
"$$\n",
"\\text{D} J(x_i) = v^T H^{-1} \\Bigg( \\kappa_1(z_i) I_{m \\times m} + \\kappa_2(z_i) (y - x_i)(y - x_i)^T\\Bigg)\n",
"$$\n",
"\n",
"Let $w^T = v^T H^{-1}$ obtained by solving $v = H^{-1} w$ using Cholesky factorization and back substitution. Note that this can be computed independent of $i$. We then have\n",
"\n",
"$$\n",
"\\begin{align*}\n",
"\\text{D} J(x_i)\n",
"&= w^T \\Bigg( \\kappa_1(z_i) I_{m \\times m} + \\kappa_2(z_i) (y - x_i)(y - x_i)^T \\Bigg)\n",
"\\\\\n",
"&= \\kappa_1(z_i) w^T + \\kappa_2(z_i) w^T (y - x_i)(y - x_i)^T\n",
"\\end{align*}\n",
"$$\n",
"\n",
"Performing the inner product $w^T (y - x_i)$ before the outer product $(y - x_i)(y - x_i)^T$ results in significant memory and computational savings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}