{ "cells": [ { "cell_type": "markdown", "id": "86423c05", "metadata": {}, "source": [ "# Low rank (and stack of low rank) matrices: forward and adjoints\n", "\n", "1. **U V as linear operator (with v as model)**\n", "\n", "$$\n", "\\mathbf{y}=\\mathbf{R}\\mathbf{U}\\mathbf{V}^T = R_u(\\mathbf{v})\n", "$$\n", "\n", "1. **U V as linear operator (with u as model)**\n", "\n", "$$\n", "\\mathbf{y}=\\mathbf{R}\\mathbf{U}\\mathbf{V}^T = R_v(\\mathbf{u})\n", "$$\n", "\n", "where $\\mathbf{R}$ is any generic additional linear operator" ] }, { "cell_type": "code", "execution_count": null, "id": "87e07fc0", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "import numpy as np\n", "\n", "from pylops.basicoperators import *\n", "from pylops.utils.dottest import dottest\n", "\n", "from pyproximal.proximal import *\n", "from pyproximal.utils.bilinear import BilinearOperator\n", "\n", "np.random.seed(0)" ] }, { "cell_type": "code", "execution_count": 2, "id": "052bbba8", "metadata": {}, "outputs": [], "source": [ "class LowRankFactorizedMatrix(BilinearOperator):\n", " def __init__(self, X, Y, d, Op=None):\n", " self.n, self.k = X.shape\n", " self.m = Y.shape[1]\n", "\n", " self.x = X\n", " self.y = Y\n", " self.d = d\n", " self.Op = Op\n", " self.shapex = (self.n * self.m, self.n * self.k)\n", " self.shapey = (self.n * self.m, self.m * self.k)\n", "\n", " def __call__(self, x, y=None):\n", " if y is None:\n", " x, y = x[:self.n * self.k], x[self.n * self.k:]\n", " xold = self.x.copy()\n", " self.updatex(x)\n", " res = self.d - self._matvecy(y)\n", " self.updatex(xold)\n", " return np.linalg.norm(res)**2 / 2.\n", "\n", " def _matvecx(self, x):\n", " X = x.reshape(self.n, self.k)\n", " X = X @ self.y.reshape(self.k, self.m)\n", " if self.Op is not None:\n", " X = self.Op @ X.ravel()\n", " return X.ravel()\n", "\n", " def _matvecy(self, y):\n", " Y = y.reshape(self.k, self.m)\n", " X = self.x.reshape(self.n, self.k) @ Y\n", " if self.Op is not None:\n", " X = self.Op @ X.ravel()\n", " return X.ravel()\n", "\n", " def matvec(self, x):\n", " if x.size == self.shapex[1]:\n", " y = self._matvecx(x)\n", " else:\n", " y = self._matvecy(x)\n", " return y\n", " \n", " def _rmatvecx(self, x):\n", " if self.Op is not None:\n", " x = self.Op.H @ x\n", " X = x.reshape(self.n, self.m)\n", " X = X @ np.conj(self.y.reshape(self.k, self.m).T)\n", " return X.ravel()\n", "\n", " def _rmatvecy(self, x):\n", " if self.Op is not None:\n", " x = self.Op.H @ x\n", " Y = x.reshape(self.n, self.m)\n", " X = (np.conj(Y.T) @ self.x.reshape(self.n, self.k)).T\n", " return X.ravel()\n", "\n", " def rmatvec(self, x, which=\"x\"):\n", " if which == \"x\":\n", " y = self._rmatvecx(x)\n", " else:\n", " y = self._rmatvecy(x)\n", " return y\n", " \n", " def gradx(self, x):\n", " pass\n", "\n", " def grady(self, y):\n", " pass\n", "\n", " def grad(self, x_or_y):\n", " pass\n", "\n", " def lx(self, x):\n", " pass\n", "\n", " def ly(self, y):\n", " pass" ] }, { "cell_type": "code", "execution_count": 3, "id": "2ecc34f0", "metadata": {}, "outputs": [], "source": [ "# Restriction operator\n", "n, m, k = 4, 5, 2\n", "sub = 0.4\n", "nsub = int(n*m*sub)\n", "iava = np.random.permutation(np.arange(n*m))[:nsub]\n", "\n", "Rop = Restriction(n*m, iava)" ] }, { "cell_type": "code", "execution_count": 4, "id": "060dc28d", "metadata": {}, "outputs": [], "source": [ "# model\n", "U = np.random.normal(0., 1., (n, k))\n", "V = np.random.normal(0., 1., (m, k))\n", "\n", "X = U @ V.T\n", "\n", "# data\n", "y = Rop * X.ravel()\n", "\n", "# Masked data\n", "Y = (Rop.H * Rop * X.ravel()).reshape(n, m)" ] }, { "cell_type": "code", "execution_count": 5, "id": "068a8307", "metadata": {}, "outputs": [], "source": [ "X = U @ V.T\n", "X1 = (V @ U.T).T" ] }, { "cell_type": "markdown", "id": "fcf83c94", "metadata": {}, "source": [ "## U V^T as linear operator (with V as model)\n", "\n", "$$\n", "\\mathbf{y}=\\mathbf{R}\\mathbf{U}\\mathbf{V}^T = R_u(\\mathbf{v})\n", "$$" ] }, { "cell_type": "code", "execution_count": 6, "id": "a12681e4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<20x10 MatrixMult with dtype=float64> <10x10 Transpose with dtype=float64>\n", "[[0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]]\n", "[0. 0. 0. 0. 0. 0. 0. 0.]\n" ] } ], "source": [ "Uop = MatrixMult(U, otherdims=(m,))\n", "Top = Transpose((m,k), (1,0))\n", "Uop1 = Uop * Top\n", "print(Uop, Top)\n", "X1 = Uop1 * V.ravel()\n", "X1 = X1.reshape(n,m)\n", "print(X-X1)\n", "\n", "# data\n", "Ruop = Rop * Uop * Top\n", "y1 = Ruop * V.ravel()\n", "print(y-y1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "1fd315b8", "metadata": {}, "outputs": [], "source": [ "v1 = Ruop.H @ y1" ] }, { "cell_type": "markdown", "id": "bdeee9e7", "metadata": {}, "source": [ "## U V^T as linear operator (with U as model)\n", "\n", "$$\n", "\\mathbf{y}=\\mathbf{R}\\mathbf{U}\\mathbf{V}^T = \\mathbf{R}(\\mathbf{V}\\mathbf{U}^T)^T = R_v(\\mathbf{u})\n", "$$" ] }, { "cell_type": "code", "execution_count": 8, "id": "15061bb3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]\n", " [0. 0. 0. 0. 0.]]\n", "[0. 0. 0. 0. 0. 0. 0. 0.]\n" ] } ], "source": [ "Vop = MatrixMult(V, otherdims=(n,))\n", "Top = Transpose((n,k), (1,0))\n", "T1op = Transpose((n,m), (1,0))\n", "Vop1 = T1op.T * Vop * Top\n", "\n", "X1 = Vop1 * U.ravel()\n", "X1 = X1.reshape(n,m)\n", "print(X-X1)\n", "\n", "# data\n", "Ruop = Rop * T1op.T * Vop * Top\n", "y1 = Ruop * U.ravel()\n", "print(y-y1)" ] }, { "cell_type": "code", "execution_count": 9, "id": "ef1a972e", "metadata": {}, "outputs": [], "source": [ "u1 = Ruop.H @ y1" ] }, { "cell_type": "markdown", "id": "52d9d7ec", "metadata": {}, "source": [ "Let's now use our function" ] }, { "cell_type": "code", "execution_count": 10, "id": "36ceb09e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0., 0., 0., 0., 0., 0., 0., 0.]),\n", " array([0., 0., 0., 0., 0., 0., 0., 0.]))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "LOp = LowRankFactorizedMatrix(U, V.T, y, Op=Rop)\n", "\n", "y-LOp._matvecx(U.ravel()), y-LOp._matvecy(V.T.ravel())" ] }, { "cell_type": "code", "execution_count": 11, "id": "78ffb4dd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0.],\n", " [0., 0.],\n", " [0., 0.],\n", " [0., 0.]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "u1 - LOp._rmatvecx(y).reshape(n, k)" ] }, { "cell_type": "code", "execution_count": 12, "id": "98857c1e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0.]])" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v1.T - LOp._rmatvecy(y).reshape(k, m)" ] }, { "cell_type": "code", "execution_count": 13, "id": "e62b4bd5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.True_" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Fop = FunctionOperator(LOp._matvecx, LOp._rmatvecx, len(iava), n*k)\n", "dottest(Fop)" ] }, { "cell_type": "code", "execution_count": 14, "id": "66797acf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.True_" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Fop = FunctionOperator(LOp._matvecy, LOp._rmatvecy, len(iava), k*m)\n", "dottest(Fop)" ] }, { "cell_type": "markdown", "id": "bf2721e5", "metadata": {}, "source": [ "## Stack of matrices" ] }, { "cell_type": "markdown", "id": "35e49813", "metadata": {}, "source": [ "We do the same now but we assume a stack of matrices, where for each of them we have\n", "\n", "$$\n", "\\mathbf{y}_i=\\mathbf{U}_i\\mathbf{V}_i^T = R_{u_i}(\\mathbf{v}_i)\n", "$$\n", "\n", "and \n", "\n", "$$\n", "\\mathbf{y}=\\mathbf{R} [\\mathbf{y}_1^T, \\mathbf{y}_2^T, ..., \\mathbf{y}_N^T]^T\n", "$$" ] }, { "cell_type": "code", "execution_count": 15, "id": "992c189b", "metadata": {}, "outputs": [], "source": [ "class LowRankFactorizedStackMatrix(BilinearOperator):\n", " r\"\"\"Low-Rank Factorized Stack of Matrix operator.\n", "\n", " Parameters\n", " ----------\n", " X : :obj:`numpy.ndarray`\n", " Left-matrix of size :math:`r \\times n \\times k`\n", " Y : :obj:`numpy.ndarray`\n", " Right-matrix of size :math:`r \\times k \\times m`\n", " d : :obj:`numpy.ndarray`\n", " Data vector\n", " Op : :obj:`pylops.LinearOperator`, optional\n", " Linear operator\n", "\n", " \"\"\"\n", " def __init__(self, X, Y, d, Op=None):\n", " self.r, self.n, self.k = X.shape\n", " self.m = Y.shape[2]\n", "\n", " self.x = X\n", " self.y = Y\n", " self.d = d\n", " self.Op = Op\n", " self.shapex = (self.r * self.n * self.m, self.r * self.n * self.k)\n", " self.shapey = (self.r * self.n * self.m, self.r * self.m * self.k)\n", "\n", " def __call__(self, x, y=None):\n", " if y is None:\n", " x, y = x[:self.r * self.n * self.k], x[self.r * self.n * self.k:]\n", " xold = self.x.copy()\n", " self.updatex(x)\n", " res = self.d - self._matvecy(y)\n", " self.updatex(xold)\n", " return np.linalg.norm(res)**2 / 2.\n", "\n", " def _matvecx(self, x):\n", " X = x.reshape(self.r, self.n, self.k)\n", " X = np.matmul(X, self.y.reshape(self.r, self.k, self.m))\n", " if self.Op is not None:\n", " X = self.Op @ X.ravel()\n", " return X.ravel()\n", "\n", " def _matvecy(self, y):\n", " Y = y.reshape(self.r, self.k, self.m)\n", " X = np.matmul(self.x.reshape(self.r, self.n, self.k), Y)\n", " if self.Op is not None:\n", " X = self.Op @ X.ravel()\n", " return X.ravel()\n", " \n", " def matvec(self, x):\n", " if x.size == self.shapex[1]:\n", " y = self._matvecx(x)\n", " else:\n", " y = self._matvecy(x)\n", " return y\n", " \n", " def _rmatvecx(self, x):\n", " if self.Op is not None:\n", " x = self.Op.H @ x\n", " X = x.reshape(self.r, self.n, self.m)\n", " X = X @ np.conj(self.y.reshape(self.r, self.k, self.m).transpose(0, 2, 1))\n", " return X.ravel()\n", "\n", " def _rmatvecy(self, x):\n", " if self.Op is not None:\n", " x = self.Op.H @ x\n", " Y = x.reshape(self.r, self.n, self.m)\n", " X = (np.conj(Y.transpose(0, 2, 1) @ self.x.reshape(self.r, self.n, self.k)) ).transpose(0, 2, 1)\n", " return X.ravel()\n", "\n", " def rmatvec(self, x, which=\"x\"):\n", " if which == \"x\":\n", " y = self._rmatvecx(x)\n", " else:\n", " y = self._rmatvecy(x)\n", " return y\n", " \n", " def gradx(self, x):\n", " pass\n", "\n", " def grady(self, y):\n", " pass\n", "\n", " def grad(self, x_or_y):\n", " pass\n", "\n", " def lx(self, x):\n", " pass\n", "\n", " def ly(self, y):\n", " pass" ] }, { "cell_type": "code", "execution_count": 16, "id": "92dc629f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Restriction operator\n", "r, n, m, k = 10, 4, 5, 2\n", "nsub = int(r*n*m*sub)\n", "iava = np.random.permutation(np.arange(r*n*m))[:nsub]\n", "Rop = Restriction(r*n*m, iava)\n", "\n", "U = np.random.normal(0., 1., (r, n, k))\n", "V = np.random.normal(0., 1., (r, m, k))\n", "\n", "LOp = LowRankFactorizedStackMatrix(U, V.transpose(0,2,1), y, Op=Rop)\n", "\n", "y = LOp._matvecx(U.ravel())\n", "LOp._matvecx(U.ravel()) - LOp._matvecy(V.transpose(0,2,1).ravel())" ] }, { "cell_type": "code", "execution_count": 17, "id": "d4e5f6d9", "metadata": {}, "outputs": [], "source": [ "LOp._rmatvecx(y).reshape(r, n, k);" ] }, { "cell_type": "code", "execution_count": 18, "id": "d6078643", "metadata": {}, "outputs": [], "source": [ "LOp._rmatvecy(y).reshape(r, k, m);" ] }, { "cell_type": "code", "execution_count": 19, "id": "d0408098", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.True_" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Fop = FunctionOperator(LOp._matvecx, LOp._rmatvecx, len(iava), r*n*k)\n", "dottest(Fop)" ] }, { "cell_type": "code", "execution_count": 20, "id": "8c64df25", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "np.True_" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Fop = FunctionOperator(LOp._matvecy, LOp._rmatvecy, len(iava), r*k*m)\n", "dottest(Fop)" ] } ], "metadata": { "kernelspec": { "display_name": "pylops", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "340px" }, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 5 }