{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "LL3kRdXs5zzD" }, "source": [ "##### Copyright 2018 Google LLC.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", "https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "sAgUgR5Mzzz2" }, "source": [ "# XLA in Python\n", "\n", " \n", "\n", "_Anselm Levskaya_ \n", "\n", "_The Python XLA client was designed by Roy Frostig._\n", "\n", "_JAX was written by Dougal Maclaurin, Peter Hawkins, Matthew Johnson, Roy Frostig, Alex Wiltschko, Chris Leary._ \n", "\n", "XLA is the compiler that JAX uses, and the compiler that TF uses for TPUs and will soon use for all devices, so it's worth some study. However, it's not easy to play with XLA computations directly through the raw C++ interface. JAX exposes the underlying XLA computation builder API through a Python wrapper, which makes the XLA compute model accessible for messing around and prototyping.\n", "\n", "XLA computations are built as computation graphs in HLO IR, which each backend then lowers to device-specific code (e.g. via LLVM for CPU and GPU, and LLO for TPU).\n", "\n", "As end users, we interact with the computational primitives defined by the HLO spec." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EZK5RseuvZkr" }, "source": [ "## References\n", "\n", "__XLA__: the operation semantics doc defines what's in HLO, but note that it is incomplete and omits some ops.\n", "\n", "https://www.tensorflow.org/xla/operation_semantics\n", "\n", "More details on the ops can be found in the source code:\n", "\n", "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/client/xla_builder.h\n", "\n", "__Python XLA client__: the XLA Python client for JAX, and what we're using here.\n", "\n", "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client.py\n", "\n", "https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/xla_client_test.py\n", "\n", "__JAX__: these files show how JAX interacts with the XLA compute layer for execution and JIT compilation.\n", "\n", "https://github.com/google/jax/blob/master/jax/lax.py\n", "\n", "https://github.com/google/jax/blob/master/jax/lib/xla_bridge.py\n", "\n", "https://github.com/google/jax/blob/master/jax/interpreters/xla.py" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "3XR2NGmrzBGe" }, "source": [ "## Colab Setup and Imports" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HMRkxnna8NCN" }, "source": [ "First, install jax and jaxlib to get the XLA client:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "JWCCBdpL8T5t" }, "outputs": [], "source": [ "!pip install --upgrade -q https://storage.googleapis.com/jax-wheels/cuda$(echo $CUDA_VERSION | sed -e 's/\\.//' -e 's/\\..*//')/jaxlib-0.1.13-cp36-none-linux_x86_64.whl\n", "!pip install --upgrade -q jax" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "Ogo2SBd3u18P" }, "outputs": [], "source": [ "# We import as onp to emphasize that we're using vanilla NumPy, not JAX's NumPy.\n", 
"import numpy as onp\n", "\n", "# We only need to import JAX's xla_client, not all of JAX.\n", "from jaxlib import xla_client\n", "\n", "# Plotting\n", "import matplotlib as mpl\n", "from matplotlib import pyplot as plt\n", "from matplotlib import gridspec\n", "from matplotlib import rcParams\n", "rcParams['image.interpolation'] = 'nearest'\n", "rcParams['image.cmap'] = 'viridis'\n", "rcParams['axes.grid'] = False" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0cf7swaobc5l" }, "source": [ "## Convenience Functions" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "5I50k0rhbg6W" }, "outputs": [], "source": [ "# Here we borrow convenience functions from JAX to convert numpy shapes/dtypes\n", "# to XLA-appropriate shapes/dtypes.\n", "def canonicalize_dtype(dtype):\n", "  \"\"\"We restrict ourselves to 32-bit types for this demo.\"\"\"\n", "  _dtype_to_32bit_dtype = {\n", "      str(onp.dtype('int64')): onp.dtype('int32'),\n", "      str(onp.dtype('uint64')): onp.dtype('uint32'),\n", "      str(onp.dtype('float64')): onp.dtype('float32'),\n", "      str(onp.dtype('complex128')): onp.dtype('complex64'),\n", "  }\n", "  dtype = onp.dtype(dtype)\n", "  return str(_dtype_to_32bit_dtype.get(str(dtype), dtype))\n", "\n", "def shape_of(value):\n", "  \"\"\"Given a Python or XLA value, return its canonicalized XLA Shape.\"\"\"\n", "  if hasattr(value, 'shape') and hasattr(value, 'dtype'):\n", "    return xla_client.Shape.array_shape(canonicalize_dtype(value.dtype),\n", "                                        value.shape)\n", "  elif onp.isscalar(value):\n", "    return shape_of(onp.asarray(value))\n", "  elif isinstance(value, (tuple, list)):\n", "    return xla_client.Shape.tuple_shape(tuple(shape_of(elt) for elt in value))\n", "  else:\n", "    raise TypeError('Unexpected type: {}'.format(type(value)))\n", "\n", "def to_xla_type(dtype):\n", "  \"\"\"Convert to an integer XLA element type, for use with ConvertElementType, etc.\"\"\"\n", "  if isinstance(dtype, str):\n", "    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype]\n", 
"  elif isinstance(dtype, type):\n", "    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[onp.dtype(dtype).name]\n", "  elif isinstance(dtype, onp.dtype):\n", "    return xla_client.DTYPE_TO_XLA_ELEMENT_TYPE[dtype.name]\n", "  else:\n", "    raise TypeError('Unexpected type: {}'.format(type(dtype)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "odmjXyhMuNJ5" }, "source": [ "## Simple Computations" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "height": 33 }, "colab_type": "code", "executionInfo": { "elapsed": 364, "status": "ok", "timestamp": 1549929562036, "user": { "displayName": "Anselm Levskaya", "photoUrl": "https://lh3.googleusercontent.com/-rqjtQZ4KjgQ/AAAAAAAAAAI/AAAAAAAAAAw/5BQt0zmTW5o/s64/photo.jpg", "userId": "09409386770882740563" }, "user_tz": 480 }, "id": "UYUtxVzMYIiv", "outputId": "bd8aa18e-26d9-4df4-ebc3-20026119de17" }, "outputs": [ { "data": { "text/plain": [ "array(0.14112, dtype=float32)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# make a computation builder\n", "c = xla_client.ComputationBuilder(\"simple_scalar\")\n", "\n", "# define a parameter shape and parameter\n", "param_shape = xla_client.Shape.array_shape(onp.float32, ())\n", "x = c.ParameterWithShape(param_shape)\n", "\n", "# define computation graph\n", "y = c.Sin(x)\n", "\n", "# build computation graph\n", "# Keep in mind that incorrectly constructed graphs can cause\n", "# your notebook kernel to crash!\n", "computation = c.Build()\n", "\n", "# compile graph based on shape\n", "compiled_computation = computation.Compile([param_shape,])\n", "\n", "# define a host variable with the above parameter shape\n", "host_input = onp.array(3.0, dtype=onp.float32)\n", "\n", "# place host variable on device and execute\n", "device_input = xla_client.LocalBuffer.from_pyval(host_input)\n", "device_out = compiled_computation.Execute([device_input ,])\n", "\n", 
"# retrieve the result\n", "device_out.to_py()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "height": 33 }, "colab_type": "code", "executionInfo": { "elapsed": 350, "status": "ok", "timestamp": 1549929568548, "user": { "displayName": "Anselm Levskaya", "photoUrl": "https://lh3.googleusercontent.com/-rqjtQZ4KjgQ/AAAAAAAAAAI/AAAAAAAAAAw/5BQt0zmTW5o/s64/photo.jpg", "userId": "09409386770882740563" }, "user_tz": 480 }, "id": "rIA-IVMVvQs2", "outputId": "ce88ec6f-d2ea-4ec2-80b4-ddd1afd36957" }, "outputs": [ { "data": { "text/plain": [ "array([0.14112 , 0.7568025, 0.9589243], dtype=float32)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# same as above, but with a vector type:\n", "\n", "c = xla_client.ComputationBuilder(\"simple_vector\")\n", "param_shape = xla_client.Shape.array_shape(onp.float32, (3,))\n", "x = c.ParameterWithShape(param_shape)\n", "\n", "# can also use this function to define a shape from an example:\n", "#x = c.ParameterFromNumpy(onp.array([0.0, 0.0, 0.0], dtype=onp.float32))\n", "\n", "# which is the same as using our convenience function above:\n", "#x = c.ParameterWithShape(shape_of(onp.array([0.0, 0.0, 0.0], \n", "# dtype=onp.float32)))\n", "\n", "# chain steps by reference:\n", "y = c.Sin(x)\n", "z = c.Abs(y)\n", "computation = c.Build()\n", "compiled_computation = computation.Compile([param_shape,])\n", "\n", "host_input = onp.array([3.0, 4.0, 5.0], dtype=onp.float32)\n", "\n", "device_input = xla_client.LocalBuffer.from_pyval(host_input)\n", "device_out = compiled_computation.Execute([device_input ,])\n", "\n", "# retrieve the result\n", "device_out.to_py()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "F8kWlLaVuQ1b" }, "source": [ "## Simple While Loop" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "height": 33 }, "colab_type": "code", "executionInfo": { "elapsed": 358, "status": "ok", "timestamp": 1549929569852, 
"user": { "displayName": "Anselm Levskaya", "photoUrl": "https://lh3.googleusercontent.com/-rqjtQZ4KjgQ/AAAAAAAAAAI/AAAAAAAAAAw/5BQt0zmTW5o/s64/photo.jpg", "userId": "09409386770882740563" }, "user_tz": 480 }, "id": "MDQP1qW515Ao", "outputId": "4da894b5-2b0e-455e-a720-3bdadc57d164" }, "outputs": [ { "data": { "text/plain": [ "array(0, dtype=int32)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# trivial while loop: decrement until 0\n", "in_shape = shape_of(1)\n", "\n", "# body computation:\n", "bcb = xla_client.ComputationBuilder(\"bodycomp\")\n", "x = bcb.ParameterWithShape(in_shape)\n", "const = bcb.Constant(onp.int32(1))\n", "y = bcb.Sub(x, const)\n", "body_computation = bcb.Build()\n", "\n", "# test computation:\n", "tcb = xla_client.ComputationBuilder(\"testcomp\")\n", "x = tcb.ParameterWithShape(in_shape)\n", "const = tcb.Constant(onp.int32(0))\n", "y = tcb.Gt(x, const)\n", "test_computation = tcb.Build()\n", "\n", "# while computation:\n", "wcb = xla_client.ComputationBuilder(\"whilecomp\")\n", "x = wcb.ParameterWithShape(in_shape)\n", "wcb.While(test_computation, body_computation, x)\n", "while_computation = wcb.Build()\n", "\n", "# Now compile and execute:\n", "compiled_computation = while_computation.Compile([in_shape,])\n", "\n", "host_input = onp.array(5, dtype=onp.int32)\n", "\n", "device_input = xla_client.LocalBuffer.from_pyval(host_input)\n", "device_out = compiled_computation.Execute([device_input ,])\n", "\n", "# retrieve the result\n", "device_out.to_py()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7UOnXlY8slI6" }, "source": [ "## While loops with 
tuples - Newton's Method for sqrt" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "height": 33 }, "colab_type": "code", "executionInfo": { "elapsed": 402, "status": "ok", "timestamp": 1549929572085, "user": { "displayName": "Anselm Levskaya", "photoUrl": "https://lh3.googleusercontent.com/-rqjtQZ4KjgQ/AAAAAAAAAAI/AAAAAAAAAAw/5BQt0zmTW5o/s64/photo.jpg", "userId": "09409386770882740563" }, "user_tz": 480 }, "id": "HEWz-vzd6QPR", "outputId": "6ef10855-232d-4701-a442-0e2667b2fd97" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "square root of 2.0 is 1.4142156839370728\n" ] } ], "source": [ "Xsqr = 2\n", "guess = 1.0\n", "converged_delta = 0.001\n", "maxit = 1000\n", "\n", "in_shape = shape_of((1.0, 1.0, 1))\n", "\n", "# body computation:\n", "# x_{i+1} = x_{i} - (x_i**2 - y) / (2 * x_i)\n", "bcb = xla_client.ComputationBuilder(\"bodycomp\")\n", "intuple = bcb.ParameterWithShape(in_shape)\n", "y = bcb.GetTupleElement(intuple, 0)\n", "x = bcb.GetTupleElement(intuple, 1)\n", "guard_cntr = bcb.GetTupleElement(intuple, 2)\n", "new_x = bcb.Sub(x, bcb.Div(bcb.Sub(bcb.Mul(x, x), y), bcb.Add(x, x)))\n", "result = bcb.Tuple(y, new_x, bcb.Sub(guard_cntr, bcb.Constant(onp.int32(1))))\n", "body_computation = bcb.Build()\n", "\n", "# test computation -- convergence and max iteration test\n", "tcb = xla_client.ComputationBuilder(\"testcomp\")\n", "intuple = tcb.ParameterWithShape(in_shape)\n", "y = tcb.GetTupleElement(intuple, 0)\n", "x = tcb.GetTupleElement(intuple, 1)\n", "guard_cntr = tcb.GetTupleElement(intuple, 2)\n", "criterion = tcb.Abs(tcb.Sub(tcb.Mul(x, x), y))\n", "# stop at convergence criteria or too many iterations\n", "test = tcb.And(tcb.Gt(criterion, tcb.Constant(onp.float32(converged_delta))), \n", " tcb.Gt(guard_cntr, tcb.Constant(onp.int32(0))))\n", "test_computation = tcb.Build()\n", "\n", "# while computation:\n", "wcb = xla_client.ComputationBuilder(\"whilecomp\")\n", "intuple = 
wcb.ParameterWithShape(in_shape)\n", "wcb.While(test_computation, body_computation, intuple)\n", "while_computation = wcb.Build()\n", "\n", "# Now compile and execute:\n", "compiled_computation = while_computation.Compile([in_shape,])\n", "\n", "y = onp.array(Xsqr, dtype=onp.float32)\n", "x = onp.array(guess, dtype=onp.float32)\n", "maxit = onp.array(maxit, dtype=onp.int32)\n", "\n", "device_input = xla_client.LocalBuffer.from_pyval((y, x, maxit))\n", "device_out = compiled_computation.Execute([device_input ,])\n", "\n", "host_out = device_out.to_py()\n", "print(\"square root of {y} is {x}\".format(y=y, x=host_out[1]))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "yETVIzTInFYr" }, "source": [ "## Calculate Symmetric Eigenvalues" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AiyR1e2NubKa" }, "source": [ "Let's use the XLA QR implementation to compute the eigenvalues of some symmetric matrices.\n", "\n", "This is the naive QR algorithm: it uses no shifts to accelerate convergence when eigenvalues are closely spaced, and no permutation to sort the eigenvalues by magnitude." 
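, "\n", "For intuition, here is a minimal sketch of the same unshifted QR iteration in plain NumPy. It is only a host-side reference for checking the XLA version below, not how XLA implements `QR` internally, and the helper name `qr_eigvals` is hypothetical:\n", "\n", "```python\n", "import numpy as onp\n", "\n", "def qr_eigvals(X, n_iter=200):\n", "  \"\"\"Unshifted QR iteration: A_k = Q R, then A_{k+1} = R Q.\"\"\"\n", "  A = onp.array(X, dtype=onp.float64)\n", "  for _ in range(n_iter):\n", "    Q, R = onp.linalg.qr(A)\n", "    # R Q = Q^T A Q is a similarity transform, so the eigenvalues are preserved\n", "    A = R @ Q\n", "  # For symmetric X with distinct eigenvalue magnitudes, A approaches a\n", "  # diagonal matrix whose diagonal holds the eigenvalues.\n", "  return onp.diag(A)\n", "\n", "X = onp.random.random((10, 10))\n", "X = (X + X.T) / 2.0\n", "print(onp.sort(qr_eigvals(X)))\n", "print(onp.sort(onp.linalg.eigh(X)[0]))\n", "```"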
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "height": 455 }, "colab_type": "code", "executionInfo": { "elapsed": 1262, "status": "ok", "timestamp": 1549929575801, "user": { "displayName": "Anselm Levskaya", "photoUrl": "https://lh3.googleusercontent.com/-rqjtQZ4KjgQ/AAAAAAAAAAI/AAAAAAAAAAw/5BQt0zmTW5o/s64/photo.jpg", "userId": "09409386770882740563" }, "user_tz": 480 }, "id": "wjxDPbqCcuXT", "outputId": "9683e40b-3c5f-4f3e-c971-0613b182c68c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sorted eigenvalues\n", "[-1.190547 -0.91282177 -0.32339668 -0.14050038 -0.09441247 0.08265306\n", " 0.49015656 0.731502 1.0677357 5.3513203 ]\n", "sorted eigenvalues from numpy\n", "[-1.1905469 -0.9128221 -0.32339665 -0.14050038 -0.09441243 0.08265309\n", " 0.49015662 0.7315014 1.0677353 5.351319 ]\n", "sorted error\n", "[-1.1920929e-07 3.5762787e-07 -2.9802322e-08 0.0000000e+00\n", " -3.7252903e-08 -2.9802322e-08 -5.9604645e-08 5.9604645e-07\n", " 3.5762787e-07 1.4305115e-06]\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACtdJREFUeJzt3V+snwV9x/H3h9MWaGWg6JbYNrZLjEvD/mDODMrmBbhEJ5ObZcMEM71plkxFY+JwN97uwhi9cC4N6A1ELioXxBBxiXpBliClJcO2akjtoIChcwFJIzut/e7inJmuW8956nkennO+vF9Jk55ffzx8Aufd53ee/s7TVBWSerpi7gGSpmPgUmMGLjVm4FJjBi41ZuBSYwYuNWbgr1NJTib5ZZJXkryU5F+T/G0SPyca8X/m69tfVNU1wNuAfwT+Hrh33kkak4GLqnq5qh4C/hr4myQ3zL1J4zBw/VpV/QA4Bfzp3Fs0DgPXxZ4H3jT3CI3DwHWxncB/zj1C4zBw/VqSP2Y58Efn3qJxGLhI8ltJbgMeAO6rqqfm3qRxxO8Hf31KchL4HeAccB44BtwH/HNV/WrGaRqRgUuN+RJdaszApcYMXGrMwKXGtkxx0De/aaH27N46+nF/cuL60Y8pbUavvvoSS2fPZK3nTRL4nt1b+cEju0c/7p/91UdHP6a0GT1+5J8GPc+X6FJjBi41ZuBSYwYuNWbgUmMGLjU2KPAk70/y4yRPJ7l76lGSxrFm4EkWgK8AHwD2AR9Osm/qYZLWb8gZ/F3A01V1oqqWWL4pwO3TzpI0hiGB7wSeveDjUyuP/S9J9ic5lOTQ6Z97vwBpIxjtIltVHaiqxapafMv1C2MdVtI6DAn8OeDCN5bvWnlM0gY3JPDHgbcn2ZtkG3AH8NC0sySNYc3vJquqc0k+DjwCLABfq6qjky+TtG6Dvl20qh4GHp54i6SR+U42qTEDlxozcKkxA5caM3CpsUluuviTE9dPcoPEKx59cvRjApz/kz+a5LjS3DyDS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNTXJX1alMdffTn//+1aMf8/qnfjn6MaXL5RlcaszApcYMXGrMwKXGDFxqzMClxgxcamzNwJPsTvK9JMeSHE1y12sxTNL6DXmjyzngM1V1OMk1wBNJ/qWqjk28TdI6rXkGr6oXqurwys9fAY4DO6ceJmn9Luutqkn2ADcCj/0/v7Yf2A9w5ZXXjjBN0noNvsiW5A3AN4FPVdUvLv71qjpQVYtVtbht644xN0r6DQ0KPMlWluO+v6oenHaSpLEMuYoe4F7geFV9cfpJksYy5Ax+M/AR4JYkT678+POJd0kawZoX2arqUSCvwRZJI/OdbFJjBi41ZuBSYwYuNbapbro4lSlukLj1306MfkyAs3/wu5McVz15BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGvOuqhOZ6u6nr+y5apLjXnPy1UmOq3l5BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caGxx4koUkR5J8a8pBksZzOWfwu4DjUw2RNL5BgSfZBXwQuGfaOZLGNPQM/iXgs8D5Sz0hyf4kh5IcWjp7ZpRxktZnzcCT3Aa8WFVPrPa8qjpQVYtVtbht647RBkr6zQ05g98MfCjJSeAB4JYk9026StIo1gy8qj5XVbuqag9wB/Ddqrpz8mWS1s0/B5cau6zvB6+q7wPfn2SJpNF5BpcaM3CpMQOXGjNwqTEDlxrzrqqbzFR3P331t6+c5LhXvfhfkxxXw3gGlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5c
aM3CpMQOXGjNwqTEDlxozcKkxA5ca866qAqa7++mWl6a5C+y5666a5LjdeAaXGjNwqTEDlxozcKkxA5caM3CpMQOXGhsUeJLrkhxM8qMkx5O8e+phktZv6Btdvgx8u6r+Msk2YPuEmySNZM3Ak1wLvBf4KEBVLQFL086SNIYhL9H3AqeBryc5kuSeJDsuflKS/UkOJTm0dPbM6EMlXb4hgW8B3gl8tapuBM4Ad1/8pKo6UFWLVbW4bev/6V/SDIYEfgo4VVWPrXx8kOXgJW1wawZeVT8Dnk3yjpWHbgWOTbpK0iiGXkX/BHD/yhX0E8DHppskaSyDAq+qJ4HFibdIGpnvZJMaM3CpMQOXGjNwqTEDlxrzrqqa1FR3P71i6VejH/P8toXRjzk3z+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNeZNF7UpTXGDxCvOnR/9mADnt8x3HvUMLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjU2KPAkn05yNMkPk3wjyTR/o5ykUa0ZeJKdwCeBxaq6AVgA7ph6mKT1G/oSfQtwdZItwHbg+ekmSRrLmoFX1XPAF4BngBeAl6vqOxc/L8n+JIeSHFo6e2b8pZIu25CX6G8Ebgf2Am8FdiS58+LnVdWBqlqsqsVtW3eMv1TSZRvyEv19wE+r6nRVnQUeBN4z7SxJYxgS+DPATUm2JwlwK3B82lmSxjDka/DHgIPAYeCplX/mwMS7JI1g0PeDV9Xngc9PvEXSyHwnm9SYgUuNGbjUmIFLjRm41Jh3VZVWTHX305riuMmgp3kGlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caS1WNf9DkNPDvA576ZuA/Rh8wnc20dzNthc21dyNsfVtVvWWtJ00S+FBJDlXV4mwDLtNm2ruZtsLm2ruZtvoSXWrMwKXG5g78wMz//su1mfZupq2wufZumq2zfg0uaVpzn8ElTcjApcZmCzzJ+5P8OMnTSe6ea8dakuxO8r0kx5IcTXLX3JuGSLKQ5EiSb829ZTVJrktyMMmPkhxP8u65N60myadXPg9+mOQbSa6ae9NqZgk8yQLwFeADwD7gw0n2zbFlgHPAZ6pqH3AT8HcbeOuF7gKOzz1igC8D366q3wP+kA28OclO4JPAYlXdACwAd8y7anVzncHfBTxdVSeqagl4ALh9pi2rqqoXqurwys9fYfkTcOe8q1aXZBfwQeCeubesJsm1wHuBewGqaqmqXpp31Zq2AFcn2QJsB56fec+q5gp8J/DsBR+fYoNHA5BkD3Aj8Ni8S9b0JeCzwPm5h6xhL3Aa+PrKlxP3JNkx96hLqarngC8AzwAvAC9X1XfmXbU6L7INlOQNwDeBT1XVL+becylJbgNerKon5t4ywBbgncBXq+pG4Aywka/HvJHlV5p7gbcCO5LcOe+q1c0V+HPA7gs+3rXy2IaUZCvLcd9fVQ/OvWcNNwMfSnKS5S99bkly37yTLukUcKqq/ucV0UGWg9+o3gf8tKpOV9VZ4EHgPTNvWtVcgT8OvD3J3iTbWL5Q8dBMW1aVJCx/jXi8qr449561VNXnqmpXVe1h+b/rd6tqQ55lqupnwLNJ3rHy0K3AsRknreUZ4KYk21c+L25lA18UhOWXSK+5qjqX5OPAIyxfifxaVR2dY8sANwMfAZ5K8uTKY/9QVQ/PuKmTTwD3r/xGfwL42Mx7LqmqHktyEDjM8p+uHGGDv23Vt6pKjXmRTWrMwKXGDFxqzMClxgxcaszApcYMXGrsvwFxd1oMhwT6kgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "Niter = 200\n", "matrix_shape = (10, 10)\n", "in_shape = shape_of(\n", " (onp.zeros(matrix_shape, dtype=onp.float32), 1)\n", ")\n", "# NB: in_shape is the same as the manually constructed:\n", "# xla_client.Shape.tuple_shape(\n", "# (xla_client.Shape.array_shape(onp.float32, matrix_shape), \n", "# xla_client.Shape.array_shape(onp.int32, ()))\n", "# )\n", "\n", "# body computation -- QR loop: X_i = Q R , X_{i+1} = R Q\n", "bcb = xla_client.ComputationBuilder(\"bodycomp\")\n", "intuple = bcb.ParameterWithShape(in_shape)\n", "x = bcb.GetTupleElement(intuple, 0)\n", "cntr = bcb.GetTupleElement(intuple, 1)\n", "QR = bcb.QR(x)\n", "Q = bcb.GetTupleElement(QR, 0)\n", "R = bcb.GetTupleElement(QR, 1)\n", "RQ = bcb.Dot(R, Q)\n", "bcb.Tuple(RQ, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n", "body_computation = bcb.Build()\n", "\n", "# test computation -- just a for loop condition\n", "tcb = xla_client.ComputationBuilder(\"testcomp\")\n", "intuple = tcb.ParameterWithShape(in_shape)\n", "cntr = tcb.GetTupleElement(intuple, 1)\n", "test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n", "test_computation = tcb.Build()\n", "\n", "# while computation:\n", "wcb = xla_client.ComputationBuilder(\"whilecomp\")\n", "intuple = wcb.ParameterWithShape(in_shape)\n", "wcb.While(test_computation, body_computation, intuple)\n", "while_computation = wcb.Build()\n", "\n", "# Now compile and execute:\n", "compiled_computation = while_computation.Compile([in_shape,])\n", "\n", "X = onp.random.random(matrix_shape).astype(onp.float32)\n", "X = (X + X.T) / 2.0\n", "it = onp.array(Niter, dtype=onp.int32)\n", "\n", "device_in = xla_client.LocalBuffer.from_pyval((X, it))\n", "device_out = compiled_computation.Execute([device_in,])\n", "\n", "host_out = device_out.to_py()\n", "eigh_vals = host_out[0].diagonal()\n", "\n", "plt.title('D')\n", "plt.imshow(host_out[0])\n", "print('sorted eigenvalues')\n", 
"print(onp.sort(eigh_vals))\n", "print('sorted eigenvalues from numpy')\n", "print(onp.sort(onp.linalg.eigh(X)[0]))\n", "print('sorted error')\n", "print(onp.sort(eigh_vals) - onp.sort(onp.linalg.eigh(X)[0]))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "FpggTihknAOw" }, "source": [ "## Calculate Full Symmetric Eigensystem" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Qos4ankYuj1T" }, "source": [ "We can also calculate the eigenbasis by accumulating the product of the Q factors, $Q_0 Q_1 \\cdots Q_k$, whose columns converge to the eigenvectors." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "height": 1000 }, "colab_type": "code", "executionInfo": { "elapsed": 1569, "status": "ok", "timestamp": 1549929587147, "user": { "displayName": "Anselm Levskaya", "photoUrl": "https://lh3.googleusercontent.com/-rqjtQZ4KjgQ/AAAAAAAAAAI/AAAAAAAAAAw/5BQt0zmTW5o/s64/photo.jpg", "userId": "09409386770882740563" }, "user_tz": 480 }, "id": "Kp3A-aAiZk0g", "outputId": "ebdc1ecc-c9e1-4e95-b989-9645f8648ee0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sorted eigenvalues\n", "[-0.94551486 -0.63820213 -0.57944936 -0.28589356 -0.05510262 0.16862962\n", " 0.4192178 0.4671099 0.88734317 4.990509 ]\n", "sorted eigenvalues from numpy\n", "[-0.9455159 -0.63820285 -0.5794492 -0.28589386 -0.05510259 0.16862962\n", " 0.41921794 0.46710995 0.88734376 4.9905105 ]\n", "sorted error\n", "[ 1.0132790e-06 7.1525574e-07 -1.7881393e-07 2.9802322e-07\n", " -2.9802322e-08 0.0000000e+00 -1.4901161e-07 -5.9604645e-08\n", " -5.9604645e-07 -1.4305115e-06]\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAACtZJREFUeJzt3V+snwV9x/H3h3NKoMWhIHGxbWyXGBfGsmCOBiXzArzAyeRmyTCBbN6QJVPRmDjcjbe7MEYvjFsDegORi8oFMURcol6YZZVDIYG2apraQfkzuxmQVUhb+93FOTOsW8956nkennO+vF8JSc+vPx4+gfPm+f1+5/d7mqpCUk+XzD1A0nQMXGrMwKXGDFxqzMClxgxcaszApcYM/E0qyfEkryZ5JclLSf4lyd8k8XuiEf9jvrn9eVW9BXgX8A/A3wH3zTtJYzJwUVUvV9XDwF8Cf5Xkurk3aRwGrt+qqh8DJ4A/nXuLxmHgOt/zwFVzj9A4DFzn2wn8cu4RGoeB67eSvI+VwH809xaNw8BFkt9LcivwIHB/VT019yaNI34e/M0pyXHgHcBZ4BxwGLgf+Meq+s2M0zQiA5ca8yG61JiBS40ZuNSYgUuNLU5x0LdftVB7dm8b/bg/O3r16MeUtqLXTr/E6TOnst79Jgl8z+5t/PjR3aMf95bb7hz9mNJW9K9P/9Og+/kQXWrMwKXGDFxqzMClxgxcaszApcYGBZ7kliQ/TXI0yT1Tj5I0jnUDT7IAfA34CHAt8PEk1049TNLGDTmDvx84WlXHquo0KxcFuG3aWZLGMCTwncCzr/v6xOpt/0uSu5IsJ1k++Z9eL0DaDEZ7ka2q9lXVUlUtXXP1wliHlbQBQwJ/Dnj9G8t3rd4maZMbEvhjwLuT7E1yKXA78PC0sySNYd1Pk1XV2SSfBB4FFoBvVNWhyZdJ2rBBHxetqkeARybeImlkvpNNaszApcYMXGrMwKXGDFxqbJKLLv7s6NWTXCCxHpvmz8TL+/54kuNKc/MMLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41NslVVacy1dVPf/lHV4x+zKsO/dfox5QulmdwqTEDlxozcKkxA5caM3CpMQOXGjNwqbF1A0+yO8kPkhxOcijJ3W/EMEkbN+SNLmeBz1XVwSRvAR5P8s9VdXjibZI2aN0zeFW9UFUHV3/9CnAE2Dn1MEkbd1HPwZPsAa4HDvw/v3dXkuUky2fOnhpnnaQNGRx4kiuAbwOfqapfnf/7VbWvqpaqamnb4o4xN0r6HQ0KPMk2VuJ+oKoemnaSpLEMeRU9wH3Akar68vSTJI1lyBn8RuBO4KYkT67+9WcT75I0gnV/TFZVPwLyBmyRNDLfySY1ZuBSYwYuNWbgUmNb6qKLU5niAomXHHtu9GMCnPsD3yWs4TyDS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNeVXViUx19dNLjr84yXHP7fn9SY6reXkGlxozcKkxA5caM3CpMQOXGjNwqTEDlxobHHiShSRPJPnOlIMkjedizuB3A0emGiJpfIMCT7IL+Chw77RzJI1p6Bn8K8DngXMXukOSu5IsJ1k+c/bUKOMkbcy6gSe5FfhFVT2+1v2qal9VLVXV0rbFHaMNlPS7G3IGvxH4WJLjwIPATUnun3SVpFGsG3hVfaGqdlXVHuB24PtVdcfkyyRtmD8Hlxq7qM+DV9UPgR9OskTS6DyDS40ZuNSYgUuNGbjUmIFLjXlV1S1mqqufvvqOyyc57uX//uokx9UwnsGlxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGr
MwKXGDFxqzMClxgxcaszApca8qqqA6a5++to1l01y3MtOvjbJcbvxDC41ZuBSYwYuNWbgUmMGLjVm4FJjBi41NijwJG9Nsj/JT5IcSfKBqYdJ2rihb3T5KvDdqvqLJJcC2yfcJGkk6wae5ErgQ8BfA1TVaeD0tLMkjWHIQ/S9wEngm0meSHJvkh3n3ynJXUmWkyyfOXtq9KGSLt6QwBeB9wJfr6rrgVPAPeffqar2VdVSVS1tW/w//UuawZDATwAnqurA6tf7WQle0ia3buBV9SLwbJL3rN50M3B40lWSRjH0VfRPAQ+svoJ+DPjEdJMkjWVQ4FX1JLA08RZJI/OdbFJjBi41ZuBSYwYuNWbgUmNeVVWTmurqp5e8/OvRj3nuyn6fofIMLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjXnRRW9IUF0j8zfZpclj49dlJjjuEZ3CpMQOXGjNwqTEDlxozcKkxA5caM3CpsUGBJ/lskkNJnk7yrSSXTT1M0satG3iSncCngaWqug5YAG6fepikjRv6EH0RuDzJIrAdeH66SZLGsm7gVfUc8CXgGeAF4OWq+t7590tyV5LlJMtnzp4af6mkizbkIfrbgNuAvcA7gR1J7jj/flW1r6qWqmpp2+KO8ZdKumhDHqJ/GPh5VZ2sqjPAQ8AHp50laQxDAn8GuCHJ9iQBbgaOTDtL0hiGPAc/AOwHDgJPrf49+ybeJWkEgz4AW1VfBL448RZJI/OdbFJjBi41ZuBSYwYuNWbgUmNeVVVaNdnVT8+dm+CgNehensGlxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcZSNezqjBd10OQk8G8D7vp24D9GHzCdrbR3K22FrbV3M2x9V1Vds96dJgl8qCTLVbU024CLtJX2bqWtsLX2bqWtPkSXGjNwqbG5A9838z//Ym2lvVtpK2ytvVtm66zPwSVNa+4zuKQJGbjU2GyBJ7klyU+THE1yz1w71pNkd5IfJDmc5FCSu+feNESShSRPJPnO3FvWkuStSfYn+UmSI0k+MPemtST57Or3wdNJvpXksrk3rWWWwJMsAF8DPgJcC3w8ybVzbBngLPC5qroWuAH420289fXuBo7MPWKArwLfrao/BP6ETbw5yU7g08BSVV0HLAC3z7tqbXOdwd8PHK2qY1V1GngQuG2mLWuqqheq6uDqr19h5Rtw57yr1pZkF/BR4N65t6wlyZXAh4D7AKrqdFW9NO+qdS0ClydZBLYDz8+8Z01zBb4TePZ1X59gk0cDkGQPcD1wYN4l6/oK8Hlgij95fkx7gZPAN1efTtybZMfcoy6kqp4DvgQ8A7wAvFxV35t31dp8kW2gJFcA3wY+U1W/mnvPhSS5FfhFVT0+95YBFoH3Al+vquuBU8Bmfj3mbaw80twLvBPYkeSOeVetba7AnwN2v+7rXau3bUpJtrES9wNV9dDce9ZxI/CxJMdZeepzU5L75510QSeAE1X1P4+I9rMS/Gb1YeDnVXWyqs4ADwEfnHnTmuYK/DHg3Un2JrmUlRcqHp5py5qShJXniEeq6stz71lPVX2hqnZV1R5W/r1+v6o25Vmmql4Enk3yntWbbgYOzzhpPc8ANyTZvvp9cTOb+EVBWHmI9IarqrNJPgk8ysorkd+oqkNzbBngRuBO4KkkT67e9vdV9ciMmzr5FPDA6v/ojwGfmHnPBVXVgST7gYOs/HTlCTb521Z9q6rUmC+ySY0ZuNSYgUuNGbjUmIFLjRm41JiBS439NyrdYAajSKUYAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAADO9JREFUeJzt3W1snfV5x/HfL46fEjqHsm4dcUbCoLCoFQK5ERSp6yDS+rSirawLU2jLm2jqgLRiYxStKqu0NxtCIK1DuNBKE6h0DVnF2ohSrXRt1DatCag0SWnDQ43TIMKYAwkhsZNrL+xJWbb43Mb/P7d95fuRkOKTOxeXbH9zn3N8zh1HhADktKjtBQDUQ+BAYgQOJEbgQGIEDiRG4EBiBA4kRuCnKNth+5wTbrvF9r1t7YTyCBxIjMCBxAgcSIzAgcQI/NR1VFL3Cbd1S5poYRdUQuCnrlFJK0+4bZWkX77xq6AWAj91fUXS39getL3I9lpJfyhpU8t7oSDzfvBTk+1+SZ+T9CeSTpf0lKRbIuLBVhdDUQQOJMZddCAxAgcSI3AgMQIHEltcY2j3QH/0vXWg+NxFoy4+U5J81tHiMw/v7y0+U5KO9VZ6UrTOp1b9vUeqzD2j+2DxmQeO1vmajR9YWnzm5Esv6eiBgx2/alUC73vrgC76p/XF5/Zf31N8piT13TVefObPt5xbfKYkHVxV6YVmPceqjL3g7LEqc6/+rR8Un/n9V87pfNDr8K9b1xSfufcfbm90HHfRgcQIHEiMwIHECBxIjMCBxAgcSKxR4Lbfa/tJ27tt31R7KQBldAzcdpekz0t6n6TVkq6yvbr2YgDmrskZfI2k3RHxdEQckXS/pCvqrgWghCaBL5f03HEfj03f9r/Y3mB7xPbIxPirpfYDMAfFnmSLiOGIGIqIoe5lS0qNBTAHTQLfI2nFcR8PTt8GYJ5rEviPJZ1re5XtHknrJHHdLmAB6PhusoiYtH2tpG9K6pL0xYjYUX0zAHPW6O2iEbFF0pbKuwAojFeyAYkROJAYgQOJETiQGIEDiVW56GKXj+lNPYeLz/3Tr20tPlOS/vaBjxSf+ZF13ys+U5K2/9eKzge9Dv5onb/rd1x7dpW5t7ztjOIzI+pcWnbxq+XnuuE1MjmDA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJVbmq6uSxRfrPQ+X/jfDPfbX81U8lacUjR4rP/N73Lyk+U5L+Y3i4ytx33lnnc9v33TpXKn31QG/xmWffFcVnStJTVzW8BOosRFez4ziDA4kROJAYgQOJETiQGIEDiRE4kBiBA4l1DNz2CtuP2N5pe4ftjW/EYgDmrskLXSYl3RAR222/SdKjtr8VETsr7wZgjjqewSNib0Rsn/71K5J2SVpeezEAczerx+C2V0q6UNK2/+f3NtgesT0ysf9Qme0AzEnjwG2fJukBSZ+MiJdP/P2IGI6IoYgY6h7oL7kjgNepUeC2uzUV930RsbnuSgBKafIsuiXdI2lXRNxWfyUApTQ5g18q6WpJl9l+fPq/91feC0ABHX9MFhFbJdV5Uy+AqnglG5AYgQOJETiQGIEDiVW56KIldS8qf6G5vn11nut7/hOvFZ85+OEdxWdK0u/e9Ykqc8/6ux9VmfviNWdUmduzta/4zLG1db6/Bh+eLD7zxZebXSCSMziQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kBiBA4kROJAYgQOJETiQGIEDiRE4kFidq6o61NtV/kqSh36z2ZUkZ2vwr
vL/3PEv7ri4+ExJ6jpc53PQ9+91rn468bU6Vyp9884jxWfefMO/FJ8pSff84zuLz+w6cLjRcZzBgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQaB267y/Zjtr9ecyEA5czmDL5R0q5aiwAor1HgtgclfUDS3XXXAVBS0zP47ZJulHTsZAfY3mB7xPbIxP5DRZYDMDcdA7f9QUkvRMSjMx0XEcMRMRQRQ90D5V/bDWD2mpzBL5X0IdvPSrpf0mW27626FYAiOgYeEZ+OiMGIWClpnaRvR8T66psBmDN+Dg4kNqv3g0fEdyR9p8omAIrjDA4kRuBAYgQOJEbgQGIEDiRW5aqqXT6mgd7yL1dd854673VZ+8c7i8+87QtXFp8pST5aZawGl4xXmXvkoX1V5o5e8ZbiM2/8xp8VnylJ53xlrPjM+PNm52bO4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYgQOJEbgQGIEDiRG4EBiBA4kRuBAYlWuqlrLD59dWWXuky/9RvGZB95xuPhMSRr4UW+VuQ/vPr/K3L9/8IEqcz/zhY8Wn3n+raPFZ0rS0TPPKD90rFm6nMGBxAgcSIzAgcQIHEiMwIHECBxIjMCBxBoFbnuZ7U22f2Z7l+1Lai8GYO6avtDlDkkPRcSVtnskLam4E4BCOgZue0DSuyV9XJIi4oikI3XXAlBCk7voqyTtk/Ql24/Zvtv20hMPsr3B9ojtkcPjrxVfFMDsNQl8saSLJN0ZERdKOijpphMPiojhiBiKiKHeZX2F1wTwejQJfEzSWERsm/54k6aCBzDPdQw8Ip6X9Jzt86ZvulzSzqpbASii6bPo10m6b/oZ9KclXVNvJQClNAo8Ih6XNFR5FwCF8Uo2IDECBxIjcCAxAgcSI3AgsSpXVV3kUF/XRPG55/3VC8VnStLPN55VfOb7f3978ZmS9I3XLqgyd9nW//Pq4yL+evfVVeae889PFZ85uv53is+UpDVX/qT4zJ6PNbtqL2dwIDECBxIjcCAxAgcSI3AgMQIHEiNwIDECBxIjcCAxAgcSI3AgMQIHEiNwIDECBxIjcCAxAgcSI3AgMQIHEiNwIDECBxKrctHFY7FIr072FJ/7zB1vLj5Tkt72l3uKz3ziu3Uujnj+6MtV5j77R3U+t+94zy+qzP1J97nFZ06cXv5CoZL02/0vFZ/Zs2iy0XGcwYHECBxIjMCBxAgcSIzAgcQIHEiMwIHEGgVu+1O2d9j+qe0v2+6rvRiAuesYuO3lkq6XNBQRb5fUJWld7cUAzF3Tu+iLJfXbXixpiaRf1VsJQCkdA4+IPZJulTQqaa+k/RHx8InH2d5ge8T2yOHxQ+U3BTBrTe6iny7pCkmrJJ0paant9SceFxHDETEUEUO9y/rLbwpg1prcRV8r6ZmI2BcRE5I2S3pX3bUAlNAk8FFJF9teYtuSLpe0q+5aAEpo8hh8m6RNkrZLemL6zwxX3gtAAY3eDx4Rn5X02cq7ACiMV7IBiRE4kBiBA4kROJAYgQOJVbmq6sSxRdp78NeKzx3YfFrxmZI0vqb83KMff7H4TEnq/0xvlblynbGHPlbna3bdv20pPvOrN/9B8ZmS9OCjv1d85vgL2xsdxxkcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEiMwIHECBxIjMCBxAgcSIzAgcQIHEjMEVF+qL1P0i8bHPrrkupcfrSOhbTvQtpVWlj7zoddz4qIt3Q6qErgTdkeiYih1haYpYW070LaVVpY+y6kXbmLDiRG4EBibQc+3PL/f7YW0r4LaVdpYe27YHZt9TE4gLraPoMDqIjAgcRaC9z2e20/aXu37Zva2qMT2ytsP2J7p+0dtje2vVMTtrtsP2b7623vMhPby2xvsv0z27tsX9L2TjOx/anp74Of2v6y7b62d5pJK4Hb7pL0e
Unvk7Ra0lW2V7exSwOTkm6IiNWSLpb0F/N41+NtlLSr7SUauEPSQxFxvqQLNI93tr1c0vWShiLi7ZK6JK1rd6uZtXUGXyNpd0Q8HRFHJN0v6YqWdplRROyNiO3Tv35FU9+Ay9vdama2ByV9QNLdbe8yE9sDkt4t6R5JiogjETHe7lYdLZbUb3uxpCWSftXyPjNqK/Dlkp477uMxzfNoJMn2SkkXStrW7iYd3S7pRknH2l6kg1WS9kn60vTDibttL217qZOJiD2SbpU0KmmvpP0R8XC7W82MJ9kasn2apAckfTIiXm57n5Ox/UFJL0TEo23v0sBiSRdJujMiLpR0UNJ8fj7mdE3d01wl6UxJS22vb3ermbUV+B5JK477eHD6tnnJdrem4r4vIja3vU8Hl0r6kO1nNfXQ5zLb97a70kmNSRqLiP+5R7RJU8HPV2slPRMR+yJiQtJmSe9qeacZtRX4jyWda3uV7R5NPVHxYEu7zMi2NfUYcVdE3Nb2Pp1ExKcjYjAiVmrq8/rtiJiXZ5mIeF7Sc7bPm77pckk7W1ypk1FJF9teMv19cbnm8ZOC0tRdpDdcREzavlbSNzX1TOQXI2JHG7s0cKmkqyU9Yfvx6dtujogtLe6UyXWS7pv+i/5pSde0vM9JRcQ225skbdfUT1ce0zx/2SovVQUS40k2IDECBxIjcCAxAgcSI3AgMQIHEiNwILH/Bobry1k8oM1RAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAEICAYAAAByNDmmAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC71JREFUeJzt3X+sX3V9x/Hni7YEWhZUYFtsm7VujoW5bJgrQUn4A1iCw0my+QcECPMfsmwqOhODyxKW/bkYo384Z4O6PyDiUogjjAgm6h9mGaP8SKCtkqYwKD+0bhOlim3pe3/cr0nX2Xu/5Z7Tc+87z0dyk/s99/R835D77Dnf7z3301QVkno6Y+oBJI3HwKXGDFxqzMClxgxcaszApcYMvJkk5ye5PcllU8+i6Rn4GpCkkvzWCdv+NsmdJ2zbBPwr8IfA/UkuPu5rNyR5dfbxsyTHjnv86hLPnST7k+wZak6dPgbeRJINwD3AHuBy4M+B+5L8JkBV3VVV51TVOcB7gRd/8Xi27WQuB34VeFuSd437X6GhrZ96AK1ckgD/BDwD/EUt3p741SQ/ZzHyK6rq+2/w8DcD/wKcPfv8kQFG1mli4A3Mgr7hl2z/GvC1N3rcJBuBDwDXsRj4F5L8VVUdfqPH1OnlJbqW8ifAz4GHWHxtvwG4ZtKJdEoMfG14ncW4jrcBODLy894M/HNVHa2q11h8jX/zEvtPNadOwkv0teE5YBuw97ht24Gnx3rCJFuAK4BLkvzpbPNG4Kwk51fVD1fDnFqaZ/C14avA3yTZkuSMJFcBfwzsHPE5b2IxzAuBP5h9/DZwALh+Fc2pJRj42vB3wL8B3wH+B/h74IaqemrE57wZ+Ieqevn4D+AfOfll+hRzaglxwQepL8/gUmMGLjVm4FJjBi41NsrPwc9/y7ratvXE+x1W7ul95w1+TGkteu3wjzh85FCW22+UwLdt3cB/PLh18ONefe1Ngx9TWov+/akvzLWfl+hSYwYuNWbgUmMGLjVm4FJjBi41NlfgSa5O8r0k+5LcNvZQkoaxbOBJ1gGfY3ElzouA65NcNPZgklZunjP4JcC+qto/W2zvbuDacceSNIR5At8MPH/c4wOzbf9HkluS7Eqy6+B/vT7UfJJWYLA32apqR1UtVNXCBeetG+qwklZgnsBfAI6/sXzLbJukVW6ewB8B3p5ke5IzWVwE/75xx5I0hGV/m6yqjib5EPAgsA74UlXtHn0ySSs216+LVtUDwAMjzyJpYN7JJjVm4FJjBi41ZuBSYwYuNTbKootP7ztvlAUS65EnBz8mQN71e6McV5qaZ3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqbFRVlUdy1irn/73754z+DHfsvvVwY8pnSrP4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjywaeZGuSbyXZk2R3kltPx2CSVm6eG12OAh+vqseS/ArwaJJvVNWekWeTtELLnsGr6qWqemz2+U+AvcDmsQeTtHKn9Bo8yTbgYuDhX/K1W5LsSrLryNFDw0wnaUXmDjzJOcA9wEer6scnfr2qdlTVQlUtbFi/acgZJb1BcwWeZAOLcd9VVfeOO5KkoczzLnqALwJ7q+rT448kaSjznMEvA24CrkjyxOzjj0aeS9IAlv0xWVV9B8hpmEXSwLyTTWrMwKXGDFxqzMClxtbUootjGWOBxDP2vzD4MQGOvc27hDU/z+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY
wYuNWbgUmOuqjqSsVY/PePZl0c57rFtvz7KcTUtz+BSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY3MHnmRdkseT3D/mQJKGcypn8FuBvWMNIml4cwWeZAtwDXDHuONIGtK8Z/DPAJ8Ajp1shyS3JNmVZNeRo4cGGU7SyiwbeJL3AT+oqkeX2q+qdlTVQlUtbFi/abABJb1x85zBLwPen+RZ4G7giiR3jjqVpEEsG3hVfbKqtlTVNuA64JtVdePok0laMX8OLjV2Sr8PXlXfBr49yiSSBucZXGrMwKXGDFxqzMClxgxcasxVVdeYsVY//dmvnT3Kcc/+/s9GOa7m4xlcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMVVUFjLf66WsXnDXKcc86+Noox+3GM7jUmIFLjRm41JiBS40ZuNSYgUuNGbjU2FyBJ3lTkp1Jvptkb5J3jz2YpJWb90aXzwJfr6oPJDkT2DjiTJIGsmzgSc4FLgf+DKCqDgOHxx1L0hDmuUTfDhwEvpzk8SR3JNl04k5JbkmyK8muI0cPDT6opFM3T+DrgXcCn6+qi4FDwG0n7lRVO6pqoaoWNqz/f/1LmsA8gR8ADlTVw7PHO1kMXtIqt2zgVfUy8HySC2ebrgT2jDqVpEHM+y76h4G7Zu+g7wc+ON5IkoYyV+BV9QSwMPIskgbmnWxSYwYuNWbgUmMGLjVm4FJjrqqqUY21+ukZr/x08GMeO7ff71B5BpcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMRdd1Jo0xgKJr28cJ4d1Pz06ynHn4RlcaszApcYMXGrMwKXGDFxqzMClxgxcamyuwJN8LMnuJE8l+UqSs8YeTNLKLRt4ks3AR4CFqnoHsA64buzBJK3cvJfo64Gzk6wHNgIvjjeSpKEsG3hVvQB8CngOeAl4paoeOnG/JLck2ZVk15Gjh4afVNIpm+cS/c3AtcB24K3ApiQ3nrhfVe2oqoWqWtiwftPwk0o6ZfNcol8FPFNVB6vqCHAv8J5xx5I0hHkCfw64NMnGJAGuBPaOO5akIczzGvxhYCfwGPDk7M/sGHkuSQOY6xdgq+p24PaRZ5E0MO9kkxozcKkxA5caM3CpMQOXGnNVVWlmtNVPjx0b4aA1116ewaXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxlI13+qMp3TQ5CDwn3Psej7ww8EHGM9amnctzQpra97VMOtvVNUFy+00SuDzSrKrqhYmG+AUraV519KssLbmXUuzeokuNWbgUmNTB75j4uc/VWtp3rU0K6ytedfMrJO+Bpc0rqnP4JJGZOBSY5MFnuTqJN9Lsi/JbVPNsZwkW5N8K8meJLuT3Dr1TPNIsi7J40nun3qWpSR5U5KdSb6bZG+Sd08901KSfGz2ffBUkq8kOWvqmZYySeBJ1gGfA94LXARcn+SiKWaZw1Hg41V1EXAp8JereNbj3QrsnXqIOXwW+HpV/Q7w+6zimZNsBj4CLFTVO4B1wHXTTrW0qc7glwD7qmp/VR0G7gaunWiWJVXVS1X12Ozzn7D4Dbh52qmWlmQLcA1wx9SzLCXJucDlwBcBqupwVf1o2qmWtR44O8l6YCPw4sTzLGmqwDcDzx/3+ACrPBqAJNuAi4GHp51kWZ8BPgGM8S/PD2k7cBD48uzlxB1JNk091MlU1QvAp4DngJeAV6rqoWmnWppvss0pyTnAPcBHq+rHU89zMkneB/ygqh6depY5rAfeCXy+qi4GDgGr+f2YN7N4pbkdeCuwKcmN0061tKkCfwHYetzjLbNtq1KSDSzGfVdV3Tv1P
Mu4DHh/kmdZfOlzRZI7px3ppA4AB6rqF1dEO1kMfrW6Cnimqg5W1RHgXuA9E8+0pKkCfwR4e5LtSc5k8Y2K+yaaZUlJwuJrxL1V9emp51lOVX2yqrZU1TYW/79+s6pW5Vmmql4Gnk9y4WzTlcCeCUdaznPApUk2zr4vrmQVvykIi5dIp11VHU3yIeBBFt+J/FJV7Z5iljlcBtwEPJnkidm2v66qByacqZMPA3fN/qLfD3xw4nlOqqoeTrITeIzFn648ziq/bdVbVaXGfJNNaszApcYMXGrMwKXGDFxqzMClxgxcaux/AWPSvgMr5OSGAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "Niter = 100\n", "matrix_shape = (10, 10)\n", "# Loop state: (matrix X, accumulated orthogonal factor O, iteration counter)\n", "in_shape = shape_of(\n", "  (onp.zeros(matrix_shape, dtype=onp.float32),\n", "   onp.eye(matrix_shape[0], dtype=onp.float32),\n", "   1)\n", ")\n", "\n", "# Body computation: one QR iteration, X_i = Q_i R_i and X_{i+1} = R_i Q_i;\n", "# O accumulates the product of the Q_i and the counter is decremented\n", "bcb = xla_client.ComputationBuilder(\"bodycomp\")\n", "intuple = bcb.ParameterWithShape(in_shape)\n", "X = bcb.GetTupleElement(intuple, 0)\n", "O = bcb.GetTupleElement(intuple, 1)\n", "cntr = bcb.GetTupleElement(intuple, 2)\n", "QR = bcb.QR(X)\n", "Q = bcb.GetTupleElement(QR, 0)\n", "R = bcb.GetTupleElement(QR, 1)\n", "RQ = bcb.Dot(R, Q)\n", "Onew = bcb.Dot(O, Q)\n", "bcb.Tuple(RQ, Onew, bcb.Sub(cntr, bcb.Constant(onp.int32(1))))\n", "body_computation = bcb.Build()\n", "\n", "# Test computation: keep looping while the counter is positive\n", "tcb = xla_client.ComputationBuilder(\"testcomp\")\n", "intuple = tcb.ParameterWithShape(in_shape)\n", "cntr = tcb.GetTupleElement(intuple, 2)\n", "test = tcb.Gt(cntr, tcb.Constant(onp.int32(0)))\n", "test_computation = tcb.Build()\n", "\n", "# While computation:\n", "wcb = xla_client.ComputationBuilder(\"whilecomp\")\n", "intuple = wcb.ParameterWithShape(in_shape)\n", "wcb.While(test_computation, body_computation, intuple)\n", "while_computation = wcb.Build()\n", "\n", "# Now compile and execute:\n", "compiled_computation = while_computation.Compile([in_shape,])\n", "\n", "X = onp.random.random(matrix_shape).astype(onp.float32)\n", "# Symmetrize X so that the (unshifted) QR iteration converges towards a diagonal matrix\n", "X = (X + X.T) / 2.0\n", "Omat = onp.eye(matrix_shape[0], dtype=onp.float32)\n", "it = onp.array(Niter, dtype=onp.int32)\n", "\n", "device_in = xla_client.LocalBuffer.from_pyval((X, Omat, it))\n", "device_out = compiled_computation.Execute([device_in,])\n", "\n", "# The diagonal of the final X approximates the eigenvalues,\n", "# and the columns of the accumulated O approximate the eigenvectors\n", "host_out = device_out.to_py()\n", "eigh_vals = host_out[0].diagonal()\n", "eigh_mat = host_out[1]\n", "\n", "plt.title('D')\n", "plt.imshow(host_out[0])\n", "plt.figure()\n", "plt.title('U')\n", "plt.imshow(eigh_mat)\n",
"plt.figure()\n", "plt.title('U^T A U')\n", "plt.imshow(onp.dot(onp.dot(eigh_mat.T, X), eigh_mat))\n", "\n", "numpy_vals = onp.sort(onp.linalg.eigh(X)[0])\n", "print('sorted eigenvalues')\n", "print(onp.sort(eigh_vals))\n", "print('sorted eigenvalues from numpy')\n", "print(numpy_vals)\n", "print('sorted error')\n", "print(onp.sort(eigh_vals) - numpy_vals)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Ee3LMzOvlCuK" }, "source": [ "## Convolutions\n", "\n", "I keep hearing from the AGI folks that we can use convolutions to build artificial life. Let's try it out: below, a 3x3 convolution sums each cell's neighbourhood, and a `While` loop steps Conway's Game of Life forward for a fixed number of generations." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "height": 132 }, "colab_type": "code", "id": "J8QkirDalBse", "outputId": "73c53980-8dbd-497b-fe56-7e606a29c19f" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABFoAAABwCAYAAAAuRhTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAABBBJREFUeJzt3MFt20AQQFHKUBWpwk0EqSBVpoIgTbiKlBH6FMDRxZL4lxI3753sk4jB8vKxnNO6rgsAAAAA2708+gEAAAAAZiG0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIic9/yxry/f1z1/bza//vw4/f3bLLf5OMtlMc+tnM2OWXa85y1ns+NstpzNjll2vOctZ7Njlp3L9/wjN1oAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACInB/9AKWfv9/++f/bl9cHPQkAAADwP3KjBQAAACAitAAAAABEhBYAAACAyFQ7Wi53stjZcr3PZmWWLfO8n9mNY7YAALCdGy0AAAAAEaEFAAAAICK0AAAAAEQOvaPlcp8A9/tsJwu3setinFvPqtlfz26mbey62o9ZbmN+45gtAMviRgsAAABARmgBAAAAiAgtAAAAAJFD72jx3et+zPo2dt6MYydLx7lsee879lyMZdfVOHYz3c+eq32Z5/3MbpyZZutGCwAAAEBEaAEAAACICC0AAAAAkUPvaGGcI38P94zMs2OWHbPcl3lfz76bsexk6TibHe99a6ZdF8/GnqtxZtrN5EYLAAAAQERoAQAAAIgILQAAAAARO1oAYIAjfUf87MyyZZ4ds9yPWd/Gzptx7GTpzHwu3WgBAAAAiAgtAAAAABGhBQAAACBiRwsAAPBU7LlomWfHLDszz9KNFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACICC0AAAAAEaEFAAAAICK0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIgILQAAAAARoQUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAip3VdH/0MAAAAAFNwowUAAAAgIrQAAAAARIQWAAAAgIjQAgAAABARWgAAAAAiQgsAAABARGgBAAAAiAgtAAAAABGhBQAAACAitAAAAABEhBYAAACAiNACAAAAEBFaAAAAACJCCwAAAEBEaAEAAACICC0AAAAAEaEFAAAAICK0AAAAAESEFgAAAICI0AIAAAAQEVoAAAAAIkILAAAAQERoAQAAAIi8A8vjuqwsx0TPAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "Niter = 13\n", "# Conv expects (batch, feature, height, width); here a single 20x20 board\n", "matrix_shape = (1, 1, 20, 20)\n", "# Loop state: (board, iteration counter)\n", "in_shape = shape_of(\n", "  (onp.zeros(matrix_shape, dtype=onp.int32), 1)\n", ")\n", "\n", "# Body computation: one Game of Life update, then decrement the counter\n", "bcb = xla_client.ComputationBuilder(\"bodycomp\")\n", "intuple = bcb.ParameterWithShape(in_shape)\n", "x = bcb.GetTupleElement(intuple, 0)\n", "cntr = bcb.GetTupleElement(intuple, 1)\n", "# convs require floating-point types\n", "xf = bcb.ConvertElementType(x, to_xla_type('float32'))\n", "# 3x3 all-ones kernel: each output is the sum over a cell's 3x3 neighbourhood\n", "stamp = bcb.Constant(onp.ones((1, 1, 3, 3), dtype=onp.float32))\n", "# stride 1 with SAME (zero) padding, so cells beyond the border count as dead\n", "convd = bcb.Conv(xf, stamp, onp.array([1, 1]), xla_client.PaddingType.SAME)\n", "# cast back to int32 so the sums can be compared with the S32 constants below\n", "convd = bcb.ConvertElementType(convd, to_xla_type('int32'))\n", "bool_x = bcb.Eq(x, bcb.ConstantS32Scalar(1))\n", "# core update rule\n", "res = bcb.Or(\n", "  # birth rule: a dead cell with exactly 3 live neighbours comes alive\n", "  bcb.And(bcb.Not(bool_x), bcb.Eq(convd, bcb.ConstantS32Scalar(3))),\n", "  # survival rule: a live cell with 2 or 3 live neighbours stays alive;\n", "  # the 3x3 sum includes the cell itself, so those counts show up as 3 or 4\n", "  bcb.And(bool_x, bcb.Or(\n", "    bcb.Eq(convd, bcb.ConstantS32Scalar(4)),\n", "    bcb.Eq(convd, bcb.ConstantS32Scalar(3)))\n", "  )\n", ")\n", "# Convert the boolean result back to int32 so the loop state keeps the same type\n", "int_res = bcb.ConvertElementType(res, to_xla_type('int32'))\n", "bcb.Tuple(int_res, bcb.Sub(cntr, bcb.ConstantS32Scalar(1)))\n", "body_computation = bcb.Build()\n", "\n", "# Test computation: keep looping while the counter is positive\n", "tcb = xla_client.ComputationBuilder(\"testcomp\")\n", "intuple = tcb.ParameterWithShape(in_shape)\n", "cntr = tcb.GetTupleElement(intuple, 1)\n", "test = tcb.Gt(cntr, tcb.ConstantS32Scalar(0))\n", "test_computation = tcb.Build()\n", "\n", "# While computation:\n", "wcb = xla_client.ComputationBuilder(\"whilecomp\")\n", "intuple = wcb.ParameterWithShape(in_shape)\n", "wcb.While(test_computation, body_computation, intuple)\n", "while_computation = wcb.Build()\n", "\n", "# Now compile and execute:\n", "compiled_computation = while_computation.Compile([in_shape,])\n", "\n", "# Set up initial state: a glider occupying rows and columns 5-7\n", "X = onp.zeros(matrix_shape, dtype=onp.int32)\n", "X[0, 0, 5:8, 5:8] = onp.array([[0, 1, 0], [0, 0, 1], [1, 1, 1]])\n", "\n", "# Evolve: for frame it, run the While loop for it generations starting\n", "# from the initial board (frame 0 is the initial state itself)\n", "movie = onp.zeros((Niter,) + matrix_shape[-2:], dtype=onp.int32)\n", "for it in range(Niter):\n", "  itr = onp.array(it, dtype=onp.int32)\n", "  device_in = xla_client.LocalBuffer.from_pyval((X, itr))\n", "  device_out = compiled_computation.Execute([device_in,])\n", "  movie[it] = device_out.to_py()[0][0, 0]\n", "\n", "# Plot\n", "fig = plt.figure(figsize=(15, 2))\n", "gs = gridspec.GridSpec(1, Niter)\n", "for i in range(Niter):\n", "  ax1 = plt.subplot(gs[:, i])\n", "  ax1.axis('off')\n", "  ax1.imshow(movie[i])\n", "plt.subplots_adjust(left=0.0, right=1.0, top=1.0, bottom=0.0, hspace=0.0, wspace=0.05)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9-0PJlqv237S" }, "source": [ "# Fin\n", "\n", "There's much more to XLA, but this hopefully highlights how easy it is to play with via the Python client!" ] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "XLA in Python", "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2" } }, "nbformat": 4, "nbformat_minor": 2 }