{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "2669ff01",
"metadata": {},
"source": [
"# Implicit Eigen Decomposition\n",
"\n",
"This tutorial notebook applies the [Deep Declarative Networks (DDNs)](https://ieeexplore.ieee.org/document/9355027) framework to the implicit differentiation of Eigen decomposition. For a square matrix $\\mathbf{x} \\in \\mathbb{R}^{m \\times m}$, it can be formulated by solving the problem: $\\mathbf{x} \\in \\mathbb{R}^{m \\times m} \\rightarrow ( \\boldsymbol{\\lambda} \\in \\mathbb{R}^{n}, \\mathbf{u} \\in \\mathbb{R}^{m \\times n})$ with $\\boldsymbol{\\lambda}$ as a vector of $n$ Eigenvalues and $\\mathbf{u}$ as the corresponding Eigenvectors, as follows,\n",
"\n",
"\\begin{equation}\n",
"\\begin{aligned}\n",
"\\min_{\\mathbf{u} \\in \\mathbb{R}^{m \\times n}} f(\\mathbf{x}, \\mathbf{u}) &= -\\text{tr} \\left( \\mathbf{u}^T \\mathbf{x} \\mathbf{u} \\right)\\ ,\\\\\n",
"\\text{subject to} \\quad h \\left( \\mathbf{u} \\right) &= \\mathbf{u}^T \\mathbf{u} = \\mathbf{I}_n\\ ,\n",
"\\end{aligned}\n",
"\\end{equation}\n",
"\n",
"where tr() is trace function of the diagonal elements, and $\\mathbf{I}_n$ is a $n \\times n$ identity matrix. The optimal solution is\n",
"\n",
"\\begin{equation}\n",
"\\mathbf{y} = \\text{argmin}_{\\mathbf{u} \\in \\mathbb{R}^{m \\times n}} f(\\mathbf{x}, \\mathbf{u})\n",
"\\end{equation}\n",
"\n",
"satisfying\n",
"\n",
"\\begin{equation}\n",
"\\mathbf{x} \\mathbf{y}_i = \\lambda_i \\mathbf{y}_i, \\forall i \\in \\mathcal{N} = \\{1, ..., n\\}\\ .\n",
"\\end{equation}\n",
"\n",
"### Solver for Eigenvalues and Eigenvectors\n",
"\n",
"For the greatest Eigenvalue, one can use [the Power Iteration algorithm](https://en.wikipedia.org/wiki/Power_iteration). To be more general, we consider [the Simultaneous Iteration algorithm with QR decomposition](https://dspace.mit.edu/bitstream/handle/1721.1/75282/18-335j-fall-2006/contents/lecture-notes/lec15.pdf) for multiple Eigenvalues.\n",
"\n",
"[WARNING] Since the Eigenvector of an Eigenvalue has two solutions with the reverse $L_2$ based direction against each other, that is when $\\mathbf{y}=-\\mathbf{y}$, we apply the cosine similarity based solutions such that the gradient of the Eigenvector will not hover between these two directions but in one of them.\n",
"\n",
"For the similarity of directions, we use $\\mathbf{y} \\in \\mathbb{R}^{m \\times n}$ to denote the optimal solution and $\\mathbf{r}$ to denote a one-hot reference vector (dimension $\\mathbb{R}^{m}$ with value 1 indicating the direction base) or a reference matrix (dimension $\\mathbb{R}^{m \\times n}$).\n",
"\n",
"- $\\textbf{Solution 1:}$ the motivation of using the one-hot reference vector $\\mathbf{r}$ is to ensure either $\\mathbf{y}$ or $-\\mathbf{y}$ that has a similar (with a positive cosine similarity) or reverse (with a negative cosine similarity) direction against a specific axis (among $m$ axes) is chosen as the optimal solution. It follows\n",
"\n",
"\\begin{equation}\n",
"\\mathbf{s} = \\left( \\mathbf{y}^T \\mathbf{r} \\right) \\in \\mathbb{R}^{n}\\ ,\n",
"\\end{equation}\n",
"\n",
"for $n$ Eigenvectors. \n",
"\n",
"- $\\textbf{Solution 2:}$ for a $m \\times n$ reference matrix $\\mathbf{r}$, which can be the solution $\\mathbf{r} = \\mathbf{y}_{t-1}$ at $(t-1)$ iteration, the cosine similarity between $\\mathbf{y}$ and $\\mathbf{r}$ indicates the updating direction of $\\mathbf{y}$ from $(t-1)$ to $t$. $\\mathbf{s} > 0$ means $\\mathbf{y}$ is changed gradually in a similar direction from the direction in the last iteration. Here,\n",
"\n",
"\\begin{equation}\n",
"\\mathbf{s} = \\text{diagnoal} \\left( \\mathbf{y}^T \\mathbf{r} \\right) \\in \\mathbb{R}^{n}\\ ,\n",
"\\end{equation}\n",
"\n",
"where $\\text{diagnoal}(\\mathbf{x})$ transforms the diagnoal elements of matrix $\\mathbf{x}$ as a vector.\n",
"\n",
"Next, one can choose either $\\mathbf{s}_i > 0$ or $\\mathbf{s}_i < 0$ as the sign condition. If $\\mathbf{y}_i$ violates this condition, $\\mathbf{y}_i=-\\mathbf{y}_i$, for all $i \\in \\mathcal{N}$. The implementation can be found in function [uniform_solution_direction()](eigen_decomposition.py).\n",
"\n",
"### Implicit Gradients with Structure Exploited\n",
"\n",
"Differentiating $\\mathbf{y}$ over $\\mathbf{x}$ is achieved by using Eq. (24) of [Deep Declarative Networks (DDNs)](https://ieeexplore.ieee.org/document/9355027). Then, we have\n",
"\n",
"\\begin{equation}\n",
"D_{X} \\mathbf{y} = H^{-1} A^{T} \\left( A H^{-1} A^T \\right)^{-1} \\left( A H^{-1} B \\right) - H^{-1} B \\in \\mathbb{R}^{(m \\times n) \\times (m \\times m)}\\ ,\n",
"\\end{equation}\n",
"\n",
"where\n",
"\n",
"\\begin{equation}\n",
"\\begin{aligned}\n",
"A &= D_{Y} h(\\mathbf{y}) \\in \\mathbb{R}^{n \\times (m \\times n)}\\ ,\\\\\n",
"B &= D^2_{XY} f(\\mathbf{x}, \\mathbf{y}) \\in \\mathbb{R}^{(m \\times n) \\times (m \\times m)}\\ ,\\\\\n",
"H &= D^2_{YY} f(\\mathbf{x}, \\mathbf{y}) + \\boldsymbol{\\lambda}^T D^2_{YY} h(\\mathbf{y}) \\in \\mathbb{R}^{(m \\times n) \\times (m \\times n)}\\ .\n",
"\\end{aligned}\n",
"\\end{equation}\n",
"\n",
"[ATTENTION] In Eq.(16) of [DDNs](https://ieeexplore.ieee.org/document/9355027), the Lagrangian form uses \"-\" between the objective function and the constraint equalities. While this sign has no effects on the problem defined, it affects $H$ because the Eigenvalues $\\boldsymbol{\\lambda}$ and those calculated by using Eq.(25) have different signs. Hence, we use \"+\" instead of \"-\" in the calculation of $H$. Alternatively, one can use \"-\" in the calculation of $H$ and simply set $\\boldsymbol{\\lambda}=-\\boldsymbol{\\lambda}$.\n",
"\n",
"For an efficient implementation, we first initialize these matrices with all zeros and then assign values by exploiting their matrix structures as\n",
"\n",
"\\begin{equation}\n",
"\\begin{aligned}\n",
"&\\{ D_Y f(\\mathbf{x}, \\mathbf{y})(:, i) = -\\left( \\mathbf{x} + \\mathbf{x}^T \\right) \\mathbf{y}(:, i), \\forall i \\in \\mathcal{N} \\} \\in \\mathbb{R}^{m \\times n}\\ ,\\\\\n",
"&\\{D^2_{YY} f(\\mathbf{x}, \\mathbf{y})(:, i, :, i) = -\\left( \\mathbf{x} + \\mathbf{x}^T \\right), \\forall i \\in \\mathcal{N} \\} \\in \\mathbb{R}^{m \\times n \\times m \\times n}\\ ,\\\\\n",
"&\\{A_i = D_{Y} h(\\mathbf{y})(i, :, i) = 2 \\mathbf{y}(:, i), \\forall i \\in \\mathcal{N}\\} \\in \\mathbb{R}^{n \\times m \\times n}\\ ,\\\\\n",
"&\\{D^2_{YY} h(\\mathbf{y})(i, :, i, :, i) = 2 \\mathbf{I}_m, \\forall i \\in \\mathcal{N}\\} \\in \\mathbb{R}^{n \\times m \\times n \\times m \\times n}\\ .\n",
"\\end{aligned}\n",
"\\end{equation}\n",
"\n",
"However, since $D^2_{YY} h(\\mathbf{y})$ consumes more memory than the others while its $i$th block is merely $2\\mathbf{I}_m$, further exploitation of this structure with the vector-matrix multiplication of $\\boldsymbol{\\lambda}^T D^2_{YY} h(\\mathbf{y})$, which is denoted by $D^2_{YY} h_{\\lambda}(\\mathbf{y})$ for simplicity, can reduce the memory consumption by $n$ times. One can implement it by following\n",
"\n",
"\\begin{equation}\n",
"% \\{D^2_{YY} h_{\\lambda} (\\mathbf{y}) \\left( i \\times n:(i+1) \\times n, i \\times n:(i+1) \\times n \\right) =2 \\text{diag}_n ([ \\lambda_1, ..., \\lambda_n] ), \\forall i \\in \\mathcal{M}=\\{0, ..., m-1\\}\\} \\in \\mathbb{R}^{(m \\times n) \\times (m \\times n)}\\ ,\\\\\n",
"\\{D^2_{YY} h_{\\lambda} (\\mathbf{y}) \\left( :, i, :, i \\right) = 2 \\lambda_i \\mathbf{I}_m, \\forall i \\in \\mathcal{N}\\} \\in \\mathbb{R}^{m \\times n \\times m \\times n}\\ .\n",
"\\end{equation}\n",
"\n",
"Then, the first core for memory reduction is to avoid explicitly storing $D^2_{YY}f(\\mathbf{x}, \\mathbf{y})$ and $D^2_{YY}h(\\mathbf{y})$ by using\n",
"\n",
"\\begin{equation}\n",
"H_i = H \\left(:, i, :, i \\right) = -\\left( \\mathbf{x} + \\mathbf{x}^T \\right) + 2 \\lambda_i \\mathbf{I}_m, \\forall i \\in \\mathcal{N}\\ . \n",
"\\end{equation}\n",
"\n",
"The formulation of $D^2_{XY} f(\\mathbf{x}, \\mathbf{y}) \\in \\mathbb{R}^{m \\times n \\times m \\times m}$ requires two steps for all $i \\in \\mathcal{N}$ and $j \\in \\mathcal{M}=\\{1, ..., m\\}$,\n",
"\n",
"\\begin{equation}\n",
"\\begin{aligned}\n",
"D^2_{XY} f(\\mathbf{x}, \\mathbf{y})(j,i,:,j) &\\mathrel{-}= \\mathbf{y}(:,i)\\ ,\\\\\n",
"D^2_{XY} f(\\mathbf{x}, \\mathbf{y})(j,i,j,:) &\\mathrel{-}= \\mathbf{y}^T(:,i)\\ .\n",
"\\end{aligned}\n",
"\\end{equation}\n",
"\n",
"We observe that $B=D^2_{XY} f(\\mathbf{x}, \\mathbf{y})$ consists of $\\mathbf{y}$ and its tranpose with a specific structure. It is feasible to avoid explicitly storing $B$ which could consume rather large memory. Prior to that, however, we first highlight the second core for memory reduction by looping over each Eigenvalue and then sum over for $D_X \\mathcal{L}(\\mathbf{y})=D_Y \\mathcal{L}(\\mathbf{y}) D_X \\mathbf{y}$ where $\\mathcal{L}$ is the value of loss function of $\\mathbf{y}_i \\in \\mathbb{R}^m$, that is\n",
"\n",
"\\begin{equation}\n",
"D_X \\mathcal{L}(\\mathbf{y})\n",
"=\\sum_{i \\in \\mathcal{N}} D_Y \\mathcal{L}(\\mathbf{y}_i) D_X \\mathbf{y}_i\\ .\n",
"\\end{equation}\n",
"\n",
"In this case, if explicitly stored, $A$, $B$, and $H$ can be reduced by $n^2$, $n$, and $n^2$ times respectively. In addition to this, however, $B_i \\in \\mathbb{R}^{m \\times (m \\times m)}$, for all $i \\in \\mathcal{N}$, still requires large memory. For instance, for $m=256$ with batch size $64$ in single-precision floating point format, it requires $4$ gigabytes.\n",
"\n",
"Now, as the third core for memory reduction, we avoid storing $B$ but instead using $\\mathbf{y}$ to calculate $D_X \\mathcal{L}(\\mathbf{y})$. \n",
"We assume the implicit differentiation is on the $i$th Eigenvalue with index $i$ in all related variables and then define that\n",
"\n",
"\\begin{equation}\n",
"D_X \\mathcal{L}(\\mathbf{y}_i)\n",
"= D_Y \\mathcal{L}(\\mathbf{y}_i) D_X \\mathbf{y}_i\n",
"= D_Y \\mathcal{L}(\\mathbf{y}_i) \n",
"\\left( H_i^{-1} A_i^{T} \\left( A_i H_i^{-1} A_i^T \\right)^{-1} A_i H_i^{-1} - H_i^{-1} \\right) B_i\n",
"= K_i^T B_i \\in \\mathbb{R}^{m \\times m}\\ ,\n",
"\\end{equation}\n",
"\n",
"where $K_i \\in \\mathbb{R}^{m}$ is a vector given $\\mathcal{L} \\in \\mathbb{R}$. Due to the aforementioned special structure of $D^2_{XY} f(\\mathbf{x}, \\mathbf{y})$, we find that\n",
"\n",
"\\begin{equation}\n",
"D_X \\mathcal{L}(\\mathbf{y}_i) = -K_i \\mathbf{y}_i^T - \\mathbf{y}_i K_i^T\\ .\n",
"\\end{equation}\n",
"\n",
"This greatly reduces the memory requirement considering the large data size of $B$.\n",
"To make it more clearly, we take $m=2$, $K_i=[ k_1, k_2 ]^T$, and $\\mathbf{y}_i = [ y_1, y_2 ]^T$ as an example.\n",
"Then,\n",
"\n",
"\\begin{equation}\n",
"B_i\n",
"= B_i^1 + B_i^2\n",
"= \\begin{bmatrix}\n",
"-y_1 & 0 & -y_2 & 0 \\\\\n",
"0 & -y_1 & 0 & -y_2\n",
"\\end{bmatrix}\n",
"+\n",
"\\begin{bmatrix}\n",
"-y_1 & -y_2 & 0 & 0 \\\\\n",
"0 & 0 & -y_1 & -y_2\n",
"\\end{bmatrix}\n",
"\\end{equation}\n",
"\n",
"and\n",
"\n",
"\\begin{equation}\n",
"K_i^T B_i\n",
"= K_i^T B^1_i + K_i^T B^2_i =\n",
"\\begin{bmatrix}\n",
"-k_1 y_1 & -k_2 y_1 & -k_1 y_2 & -k_2 y_2\n",
"\\end{bmatrix}\n",
"+\n",
"\\begin{bmatrix}\n",
"-k_1 y_1 & -k_1 y_2 & -k_2 y_1 & -k_2 y_2\n",
"\\end{bmatrix}\\ ,\n",
"\\end{equation}\n",
"\n",
"where $K_i B^1_i$ is an outer-product of $K_i$ and $\\mathbf{y}_i$, $K_i B^1_i$ and $K_i B^2_i$ are tranposed when reshaping to $m \\times m$ matrices. The memory required by vectors $\\mathbf{y}_i \\in \\mathbb{R}^{m}$ is much less than the one required by $B_i \\in \\mathbb{R}^{m \\times (m \\times m)}$. The structure of $B$ also makes $D_X \\mathcal{L}(\\mathbf{y})$ symmetric.\n",
"\n",
"In summary, we greatly reduce the memory required by the calculation of $D_X \\mathcal{L}(\\mathbf{y})$ by 1) looping over each Eigenvalue instead of stacking all Eigenvalues for high-dimensional matrix computing, 2) avoiding storing the memory-inefficient $B$, and 3) simplifying the formulation of $H_i=H(:,i,:,i)$ by using only the vector $\\mathbf{x}$ and Eigenvalues."
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "bbf238b6",
"metadata": {},
"source": [
"### Implicit Function Theorem\n",
"\n",
"By using the Power Iteration algorithm for the greatest Eigenvalue (assume it is positive), the implicit function is defined as\n",
"\n",
"\\begin{equation}\n",
"f(\\mathbf{x}, \\mathbf{u}_k, \\mathbf{u}_{k+1}) = \\mathbf{u}_{k+1} - \\frac{\\mathbf{x} \\mathbf{u}_k}{\\| \\mathbf{x} \\mathbf{u}_k \\|}\\ ,\n",
"\\end{equation}\n",
"where the source matrix $\\mathbf{x} \\in \\mathbb{R}^{m \\times m}$ and the solution $\\mathbf{u} \\in \\mathbb{R}^{m \\times n}$.\n",
"Upon the convergence of $\\mathbf{u}_k$ to a fixed point, we have $\\mathbf{y}=\\mathbf{u}_{k+1}=\\mathbf{u}_k$, and thus,\n",
"\n",
"\\begin{equation}\n",
"f(\\mathbf{x}, \\mathbf{y}) = \\mathbf{y}-\\frac{\\mathbf{x} \\mathbf{y}}{\\| \\mathbf{x} \\mathbf{y} \\|}\\ .\n",
"\\end{equation}\n",
"\n",
"Now, since applying the implicit function theorem to $f(\\mathbf{x}, \\mathbf{y})$ gives\n",
"\n",
"\\begin{equation}\n",
"\\frac{\\partial f}{\\partial \\mathbf{x}} + \\frac{\\partial f}{\\partial \\mathbf{y}} \\frac{\\partial \\mathbf{y}}{\\partial \\mathbf{x}} = 0\\ ,\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\frac{\\partial \\mathbf{y}}{\\partial \\mathbf{x}} = -\\left( \\frac{\\partial f}{\\partial \\mathbf{y}} \\right)^{-1} \\frac{\\partial f}{\\partial \\mathbf{x}}\\ .\n",
"\\end{equation}\n",
"\n",
"For the notation simplicity, we denote $\\mathcal{A} = \\partial f / \\partial \\mathbf{y} \\in \\mathbb{R}^{1 \\times (mn)}, \\mathcal{B} = \\partial f / \\partial \\mathbf{x} \\in \\mathbb{R}^{1 \\times (mm)}$, and $\\mathcal{K} = \\partial \\mathbf{y} / \\partial \\mathbf{x} \\in \\mathbb{R}^{(mn) \\times (mm)}$.\n",
"Hence, $\\mathcal{K} = -\\mathcal{A}^{-1} \\mathcal{B}$.\n",
"\n",
"Since the solution $\\mathbf{y}$ is used to calculate the learning loss $\\mathcal{L}$, we denote the gradient of $\\mathcal{L}$ over $\\mathbf{y}$ as $\\mathcal{Q} = \\partial \\mathcal{L} / \\partial \\mathbf{y} \\in \\mathbb{R}^{1 \\times (mn)}$ and the one over $\\mathbf{x}$ is\n",
"\n",
"\\begin{equation}\n",
"\\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{x}} = \\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{y}} \\frac{\\partial \\mathbf{y}}{\\partial \\mathbf{x}} = \\mathcal{Q} \\mathcal{K} = -\\mathcal{Q} \\mathcal{A}^{-1} \\mathcal{B} = -\\mathcal{H} \\mathcal{B} \\in \\mathbb{R}^{1 \\times (mm)}\\ ,\n",
"\\end{equation}\n",
"where $\\mathcal{H} = \\mathcal{Q} \\mathcal{A}^{-1} \\in \\mathbb{R}$. Since\n",
"\\begin{aligned}\n",
"\\mathcal{Q} \\mathcal{A}^{-1} &= \\mathcal{H} \\\\\n",
"\\mathcal{Q} &= \\mathcal{H} \\mathcal{A} \\\\\n",
"\\mathcal{Q}^T &= \\mathcal{A}^T \\mathcal{H}^T \\\\\n",
"\\mathcal{A} \\mathcal{Q}^T &= \\mathcal{A} \\mathcal{A}^T \\mathcal{H}^T\\ ,\n",
"\\end{aligned}\n",
"\n",
"$\\mathcal{H}$ can be calculated by $\\text{nn.linalg.solve}(\\mathcal{A} \\mathcal{A}^T, \\mathcal{A} \\mathcal{Q}^T)^T$, followed by the calculation of $\\partial \\mathcal{L} / \\partial \\mathbf{x}$."
]
},
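{
"attachments": {},
"cell_type": "markdown",
"id": "3f9d2c6e",
"metadata": {},
"source": [
"The fixed-point gradient above can be sketched end to end for a single Eigenvector ($n=1$). This is a minimal illustration under assumptions, not the notebook's `FixedPointBackprop`: the symmetric test matrix with Eigenvalues $(10, 3, 2, 1, 0.5)$, the all-ones initial iterate, the toy loss $\\mathcal{L} = \\sum_j y_j$ (so $\\mathcal{Q}$ is a vector of ones), and all helper names are hypothetical. It forms $\\mathcal{A}$ and $\\mathcal{B}$ with `torch.autograd.functional.jacobian`, solves for $\\mathcal{H}$ as described above, and compares $-\\mathcal{H} \\mathcal{B}$ with central finite differences of the converged Power Iteration.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"m, dtype = 5, torch.float64\n",
"# Symmetric test matrix with a clearly dominant positive Eigenvalue (ratio 3/10)\n",
"q, _ = torch.linalg.qr(torch.randn(m, m, dtype=dtype))\n",
"x = q @ torch.diag(torch.tensor([10.0, 3.0, 2.0, 1.0, 0.5], dtype=dtype)) @ q.T\n",
"u0 = torch.ones(m, dtype=dtype)  # fixed initial iterate, so the sign of y is reproducible\n",
"\n",
"def power_iteration(x, num_iters=300):\n",
"    u = u0\n",
"    for _ in range(num_iters):\n",
"        u = x @ u / torch.linalg.norm(x @ u)\n",
"    return u\n",
"\n",
"def fixed_point_residual(x, y):\n",
"    # f(x, y) = y - x y / ||x y||, which is zero at the fixed point\n",
"    return y - x @ y / torch.linalg.norm(x @ y)\n",
"\n",
"y = power_iteration(x)\n",
"J_x, J_y = torch.autograd.functional.jacobian(fixed_point_residual, (x, y))\n",
"A_cal = J_y                     # df/dy, shape (m, m)\n",
"B_cal = J_x.reshape(m, m * m)   # df/dx, shape (m, m*m)\n",
"Q_cal = torch.ones(m, dtype=dtype)  # dL/dy for L = y.sum()\n",
"\n",
"# H^T = solve(A A^T, A Q^T), then dL/dx = -H B\n",
"H_cal = torch.linalg.solve(A_cal @ A_cal.T, A_cal @ Q_cal)\n",
"dLdx_ift = -(H_cal @ B_cal).reshape(m, m)\n",
"\n",
"# Reference: central finite differences of L(x) = power_iteration(x).sum()\n",
"eps = 1e-6\n",
"dLdx_fd = torch.zeros(m, m, dtype=dtype)\n",
"for a in range(m):\n",
"    for b in range(m):\n",
"        e = torch.zeros(m, m, dtype=dtype)\n",
"        e[a, b] = eps\n",
"        dLdx_fd[a, b] = (power_iteration(x + e).sum() - power_iteration(x - e).sum()) / (2 * eps)\n",
"\n",
"print((dLdx_ift - dLdx_fd).abs().max())  # maximum absolute gap between the two gradients\n",
"```\n",
"\n",
"Finite differences serve only as a reference here: they need $2m^2$ extra forward solves, which is impractical beyond toy sizes."
]
},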
{
"attachments": {},
"cell_type": "markdown",
"id": "0af82613",
"metadata": {},
"source": [
"### Implementation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6985df96",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/zhiwei/anaconda3/envs/DDN112/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch, os\n",
"from copy import deepcopy\n",
"from datetime import datetime\n",
"from utils import generate_random_data, method_mode_explaination\n",
"from utils import run_precision_statistics, run_speed_memory_statistics\n",
"from utils import visual_speed_memory, visual_precision\n",
"\n",
"# Import forward instances\n",
"from ied_forward import EigenAuto, SimultaneousIteration, PowerIteration\n",
"\n",
"# Import backward instances\n",
"from ied_backward import AutoBackprop, DDNBackprop, FixedPointBackprop\n",
"\n",
"# Build visualization tools\n",
"disp_fnc = lambda x: x.detach().cpu().numpy()\n",
"diff_fnc = lambda x, y: disp_fnc((x - y.view(x.shape)).abs().max())\n",
"\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '0'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ea0e9c80",
"metadata": {},
"outputs": [],
"source": [
"# Build instances\n",
"num_iters = 100\n",
"uniform_solution_method = 'positive'\n",
"\n",
"# Generate random data\n",
"enable_stop_condition = False\n",
"enable_symmetric = True\n",
"\n",
"# 'solve' achieves lower memory but dLdx accuracy is bad because of the library ligalg.solve issue;\n",
"# 'pinverse' is more accurate in dLdx but needs more memory.\n",
"backprop_inverse_mode = 'solve' # solve/pinverse\n",
"\n",
"seed = 0\n",
"epsilon = 1.0e-13\n",
"batch, m = 1, 64\n",
"n = 1 # the number of Eigenvalues\n",
"dtype = torch.float32\n",
"\n",
"#\n",
"save_dir = 'results'\n",
"time_string = datetime.now().strftime('%Y%m%d-%H%M%S')\n",
"save_dir = os.path.join(save_dir, time_string, f\"{str(num_iters)}iters\")\n",
"os.makedirs(save_dir, exist_ok=True)\n",
"g_save_path = save_dir + f\"/{backprop_inverse_mode}\"\n",
"\n",
"if enable_symmetric: g_save_path += '_symmetric'\n",
"else: g_save_path += '_nonsymmetric'\n",
"\n",
"if enable_stop_condition: g_save_path += '_stopTrue'\n",
"else: g_save_path += '_stopFalse'\n",
"\n",
"#\n",
"obj_AT = EigenAuto(\n",
" uniform_solution_method=uniform_solution_method, solver_back=AutoBackprop,\n",
" num_eigen_values=n, backprop_inverse_mode=backprop_inverse_mode, enable_symmetric=enable_symmetric)\n",
"obj_PI = PowerIteration(\n",
" num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=DDNBackprop,\n",
" num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)\n",
"obj_SI = SimultaneousIteration(\n",
" num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=DDNBackprop,\n",
" num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)\n",
"obj_PI_IFT = PowerIteration(\n",
" num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=FixedPointBackprop,\n",
" num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)\n",
"obj_SI_IFT = SimultaneousIteration(\n",
" num_iters=num_iters, uniform_solution_method=uniform_solution_method, solver_back=FixedPointBackprop,\n",
" num_eigen_values=n, enable_stop_condition=enable_stop_condition, backprop_inverse_mode=backprop_inverse_mode)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "27e598f9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dLdx_at_ddn vs dLdx_at_ddn_indiv : 1.4551915e-11\n",
"dLdx_si_ddn vs dLdx_si_ddn_indiv : 0.0\n",
"[ Check Eigenvalues ]\n",
"Auto : [101.05782]\n",
"PI : [101.05673]\n",
"SI : [101.056725]\n",
"Max diff : 0.0010986328\n",
"\n",
"[ Check Eigenvectors ]\n",
"Auto : [0.12875485 0.1301477 0.13224962 0.11636169 0.13526331 0.11521947\n",
" 0.12188618 0.13483587 0.12095354 0.11549785 0.13366696 0.13038024\n",
" 0.12198585 0.1215284 0.14015691 0.11579662 0.11842807 0.12208147\n",
" 0.12053845 0.13782045 0.12166768 0.1162324 0.13253036 0.11562862\n",
" 0.12917513 0.12311543 0.12204126 0.12370913 0.11222141 0.11167143\n",
" 0.1246989 0.12367364 0.12886107 0.13705821 0.118035 0.10938529\n",
" 0.1294138 0.11872555 0.1348373 0.12000456 0.13498007 0.12891752\n",
" 0.1170295 0.12522595 0.10646042 0.1321402 0.11618277 0.13953795\n",
" 0.12810251 0.12542069 0.1277837 0.13592969 0.11688785 0.13464409\n",
" 0.12372686 0.12716451 0.12047588 0.12865123 0.12061165 0.11782203\n",
" 0.1311092 0.11846078 0.12677294 0.13404019]\n",
"PI : [0.12875395 0.13014698 0.13224885 0.11636101 0.13526262 0.11521886\n",
" 0.1218853 0.13483502 0.12095293 0.11549731 0.13366619 0.13037927\n",
" 0.12198532 0.12152781 0.14015584 0.11579592 0.11842736 0.12208086\n",
" 0.1205376 0.13781972 0.12166704 0.11623182 0.13252953 0.11562788\n",
" 0.1291745 0.12311484 0.1220405 0.12370848 0.11222087 0.11167082\n",
" 0.12469833 0.12367295 0.12886041 0.13705751 0.11803435 0.10938465\n",
" 0.12941329 0.11872499 0.13483651 0.12000401 0.13497937 0.1289169\n",
" 0.11702888 0.12522534 0.10645989 0.13213956 0.11618212 0.13953719\n",
" 0.12810177 0.12542002 0.12778303 0.13592893 0.11688729 0.13464339\n",
" 0.12372625 0.12716393 0.12047534 0.12865065 0.12061101 0.11782142\n",
" 0.13110845 0.11846016 0.12677234 0.13403955]\n",
"SI : [0.1287539 0.130147 0.13224885 0.11636102 0.13526261 0.11521886\n",
" 0.12188531 0.13483502 0.12095292 0.11549731 0.13366619 0.13037929\n",
" 0.12198532 0.12152781 0.14015584 0.11579591 0.11842735 0.12208088\n",
" 0.1205376 0.1378197 0.12166705 0.11623181 0.13252953 0.11562788\n",
" 0.1291745 0.12311484 0.1220405 0.12370849 0.11222085 0.11167081\n",
" 0.12469833 0.12367296 0.12886041 0.13705751 0.11803433 0.10938465\n",
" 0.12941329 0.11872499 0.13483652 0.12000401 0.13497937 0.1289169\n",
" 0.11702889 0.12522535 0.10645989 0.13213958 0.1161821 0.1395372\n",
" 0.12810175 0.12542002 0.12778303 0.13592893 0.11688727 0.1346434\n",
" 0.12372626 0.12716395 0.12047534 0.12865064 0.12061102 0.11782143\n",
" 0.13110846 0.11846013 0.12677234 0.13403954]\n",
"Max diff : 1.0728836e-06\n",
"\n",
"[ Check Eigen Gaps ]\n",
"Auto : [0.00018311]\n",
"PI : [2.861023e-06]\n",
"SI : [5.722046e-06]\n",
"\n",
"[ Check Fixed Point Gaps ]\n",
"Auto : [1.0728836e-06]\n",
"PI : [2.9802322e-08]\n",
"SI : [4.4703484e-08]\n",
"\n",
"[ Check Gradient Gaps (Max Diff.) ]\n",
"dLdx_at vs dLdx_si : 0.0012523589\n",
"dLdx_at vs dLdx_si_ddn : 0.0012523589\n",
"dLdx_at vs dLdx_at_iter : 5.120455e-10\n",
"dLdx_at vs dLdx_at_ddn : 5.9236423e-05\n",
"dLdx_pi vs dLdx_pi_unroll : 0.0039889836\n",
"dLdx_pi vs dLdx_si : 0.0040286533\n",
"dLdx_si vs dLdx_si_unroll : 0.0012523588\n",
"dLdx_si vs dLdx_si_ddn : 0.0\n",
"dLdx_at vs dLdx_pi_ift_auto: 1.6370905e-09\n",
"dLdx_at vs dLdx_si_ift_auto: 1.6880222e-09\n",
"dLdx_at vs dLdx_pi_ift_stru: 1.553417e-09\n",
"dLdx_at vs dLdx_si_ift_stru: 1.8590072e-09\n"
]
}
],
"source": [
"x_org, dLdy = generate_random_data(\n",
" batch, m, n, seed, enable_symmetric=enable_symmetric, dtype=dtype, enable_grad_one=True,\n",
" distribution_mode='gaussian')\n",
"\n",
"# Get Eigenvalues and Eigenvectors\n",
"# ---- Autogradient\n",
"x = deepcopy(x_org)\n",
"lambd_at, y_at = obj_AT(x, backward_fnc_name='unroll')\n",
"\n",
"# ---- 1. Autogradient: autoback eigh\n",
"x.retain_grad()\n",
"loss = y_at.sum()\n",
"loss.backward()\n",
"dLdx_at = x.grad\n",
"\n",
"# ---- 2. Autogradient: get auto dydx then multiply dLdy\n",
"x = deepcopy(x_org)\n",
"dLdx_at_iter = obj_AT.solver_back.dLdx_auto_iter_fnc(x, y_at, lambd_at, dLdy)\n",
"\n",
"# ---- 3. Autogradient: get auto A, B, H, then DDN dydx and multiply dLdy\n",
"x = deepcopy(x_org)\n",
"dLdx_at_ddn = obj_AT.solver_back.dLdx_DDN_fnc(x, y_at, lambd_at, dLdy, enable_B=True)\n",
"\n",
"# ---- 4. [AVOID STORING B] This is supposed to be the same as dLdx_at_ddn\n",
"dLdx_at_ddn_indiv = 0.0\n",
"\n",
"for i in range(y_at.shape[-1]):\n",
" dLdx_at_ddn_indiv += obj_AT.solver_back.dLdx_DDN_fnc(\n",
" x, y_at[:, :, i:i + 1], lambd_at[:, i:i + 1], dLdy[:, :, i:i + 1], enable_B=False)\n",
"\n",
"print('dLdx_at_ddn vs dLdx_at_ddn_indiv :', diff_fnc(dLdx_at_ddn, dLdx_at_ddn_indiv))\n",
"\n",
"# ---- Simultaneous Iteration\n",
"# ---- 1. Simultaneous Iteration: DDN\n",
"x = deepcopy(x_org)\n",
"lambd_si, y_si = obj_SI(x, backward_fnc_name='dLdx_DDN_fnc')\n",
"loss = y_si.sum()\n",
"x.retain_grad()\n",
"loss.backward()\n",
"dLdx_si = x.grad\n",
"\n",
"# The same as dLdx_si\n",
"dLdx_si_ddn = obj_SI.solver_back.dLdx_DDN_fnc(x, y_si, lambd_si, dLdy, enable_B=True)\n",
"\n",
"# ---- 2. [AVOID STORING B] This is supposed to be the same as dLdx_si_ddn\n",
"dLdx_si_ddn_indiv = 0.0\n",
"\n",
"for i in range(y_si.shape[-1]):\n",
" dLdx_si_ddn_indiv += obj_SI.solver_back.dLdx_DDN_fnc(\n",
" x, y_si[:, :, i:i + 1], lambd_si[:, i:i + 1], dLdy[:, :, i:i + 1], enable_B=False)\n",
"\n",
"print('dLdx_si_ddn vs dLdx_si_ddn_indiv :', diff_fnc(dLdx_si_ddn, dLdx_si_ddn_indiv))\n",
"\n",
"# ---- 3. Simultaneous Iteration: unrolling\n",
"x = deepcopy(x_org)\n",
"lambd_si_unroll, y_si_unroll = obj_SI(x, backward_fnc_name='unroll')\n",
"x.retain_grad()\n",
"loss = y_si_unroll.sum()\n",
"loss.backward()\n",
"dLdx_si_unroll = x.grad\n",
"\n",
"# ---- Power Iteration:\n",
"# ---- 1. Unrolling\n",
"x = deepcopy(x_org)\n",
"lambd_pi_unroll, y_pi_unroll = obj_PI(x, backward_fnc_name='unroll')\n",
"x.retain_grad()\n",
"loss = y_pi_unroll.sum()\n",
"loss.backward()\n",
"dLdx_pi_unroll = x.grad\n",
"\n",
"# ---- 2. DDN\n",
"x = deepcopy(x_org)\n",
"lambd_pi, y_pi = obj_PI(x, backward_fnc_name='dLdx_DDN_fnc')\n",
"loss = y_pi.sum()\n",
"x.retain_grad()\n",
"loss.backward()\n",
"dLdx_pi = x.grad\n",
"\n",
"# ---- Fixed Point Theorem:\n",
"dLdx_pi_ift_auto = obj_PI_IFT.solver_back.dLdx_fnc(x, y_pi, lambd_pi, dLdy, enable_B=False)\n",
"dLdx_pi_ift_stru = obj_PI_IFT.solver_back.dLdx_structured_fnc(x, y_pi, lambd_pi, dLdy, enable_B=False)\n",
"\n",
"dLdx_si_ift_auto = obj_SI_IFT.solver_back.dLdx_fnc(x, y_si, lambd_si, dLdy, enable_B=False)\n",
"dLdx_si_ift_stru = obj_SI_IFT.solver_back.dLdx_structured_fnc(x, y_si, lambd_si, dLdy, enable_B=False)\n",
"\n",
"# ----\n",
"# Zhiwei: torch.eigh() and DDN give symmetric gradients while\n",
"# unrolling does not, and they should be symmetric, so add a patch below.\n",
"# The observation is that no matter x is symmetric or not, unrolling never\n",
"# get symmetric gradients, unless we manually do the following for symmetric\n",
"# x; not for asymmetric x as it makes no sense.\n",
"if enable_symmetric:\n",
" dLdx_si_unroll = 0.5 * (dLdx_si_unroll + dLdx_si_unroll.permute(0, 2, 1))\n",
" dLdx_pi_unroll = 0.5 * (dLdx_pi_unroll + dLdx_pi_unroll.permute(0, 2, 1))\n",
"\n",
"# Compare solution and gradients\n",
"method_mode_explaination()\n",
"\n",
"print('[ Check Eigenvalues ]')\n",
"print('Auto :', disp_fnc(lambd_at).flatten())\n",
"print('PI :', disp_fnc(lambd_pi).flatten())\n",
"print('SI :', disp_fnc(lambd_si).flatten())\n",
"print('Max diff :', diff_fnc(lambd_at, lambd_si))\n",
"\n",
"print('')\n",
"print('[ Check Eigenvectors ]')\n",
"print('Auto :', disp_fnc(y_at.permute(0, 2, 1)).flatten())\n",
"print('PI :', disp_fnc(y_pi.permute(0, 2, 1)).flatten())\n",
"print('SI :', disp_fnc(y_si.permute(0, 2, 1)).flatten())\n",
"print('Max diff :', diff_fnc(y_at, y_si))\n",
"\n",
"#\n",
"eigen_gap_at = obj_AT.check_eigen_gap(lambd_at, y_at, x)\n",
"eigen_gap_pi = obj_PI.check_eigen_gap(lambd_pi, y_pi, x)\n",
"eigen_gap_si = obj_SI.check_eigen_gap(lambd_si, y_si, x)\n",
"\n",
"print('')\n",
"print('[ Check Eigen Gaps ]')\n",
"print('Auto :', disp_fnc(eigen_gap_at).flatten())\n",
"print('PI :', disp_fnc(eigen_gap_pi).flatten())\n",
"print('SI :', disp_fnc(eigen_gap_si).flatten())\n",
"\n",
"#\n",
"fp_gap_at = obj_AT.check_fixed_point_gap(y_at, x)\n",
"fp_gap_pi = obj_PI.check_fixed_point_gap(y_pi, x)\n",
"fp_gap_si = obj_SI.check_fixed_point_gap(y_si, x)\n",
"\n",
"print('')\n",
"print('[ Check Fixed Point Gaps ]')\n",
"print('Auto :', disp_fnc(fp_gap_at).flatten())\n",
"print('PI :', disp_fnc(fp_gap_pi).flatten())\n",
"print('SI :', disp_fnc(fp_gap_si).flatten())\n",
"\n",
"#\n",
"print('')\n",
"print('[ Check Gradient Gaps (Max Diff.) ]')\n",
"print('dLdx_at vs dLdx_si :', diff_fnc(dLdx_at, dLdx_si))\n",
"print('dLdx_at vs dLdx_si_ddn :', diff_fnc(dLdx_at, dLdx_si_ddn))\n",
"print('dLdx_at vs dLdx_at_iter :', diff_fnc(dLdx_at, dLdx_at_iter))\n",
"print('dLdx_at vs dLdx_at_ddn :', diff_fnc(dLdx_at, dLdx_at_ddn))\n",
"print('dLdx_pi vs dLdx_pi_unroll :', diff_fnc(dLdx_pi, dLdx_pi_unroll))\n",
"print('dLdx_pi vs dLdx_si :', diff_fnc(dLdx_pi, dLdx_si))\n",
"print('dLdx_si vs dLdx_si_unroll :', diff_fnc(dLdx_si, dLdx_si_unroll))\n",
"print('dLdx_si vs dLdx_si_ddn :', diff_fnc(dLdx_si, dLdx_si_ddn))\n",
"print('dLdx_at vs dLdx_pi_ift_auto:', diff_fnc(dLdx_at, dLdx_pi_ift_auto))\n",
"print('dLdx_at vs dLdx_si_ift_auto:', diff_fnc(dLdx_at, dLdx_si_ift_auto))\n",
"print('dLdx_at vs dLdx_pi_ift_stru:', diff_fnc(dLdx_at, dLdx_pi_ift_stru))\n",
"print('dLdx_at vs dLdx_si_ift_stru:', diff_fnc(dLdx_at, dLdx_si_ift_stru))\n",
"\n",
"del dLdx_at, dLdx_at_iter, dLdx_at_ddn, dLdx_si, dLdx_si_ddn, dLdx_si_unroll, x_org, x, loss, dLdy\n",
"torch.cuda.empty_cache()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7fc50274",
"metadata": {},
"source": [
"- Statistics of precision"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "921bc5c8",
"metadata": {},
"outputs": [],
"source": [
"obj_dict = {\n",
" 'AT': obj_AT,\n",
" 'PI': obj_PI,\n",
" 'PI_IFT': obj_PI_IFT,\n",
" 'SI': obj_SI,\n",
" 'SI_IFT': obj_SI_IFT\n",
" }\n",
"batch = 5\n",
"data_sizes = [\n",
" (batch, 32, 1),\n",
" (batch, 64, 1),\n",
" (batch, 128, 1),\n",
" (batch, 256, 1),\n",
" (batch, 512, 1),\n",
" (batch, 1024, 1)\n",
" ]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "86fdaee4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start precision statistics torch.float32...\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 32, 1) ]\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 64, 1) ]\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 128, 1) ]\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 256, 1) ]\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 512, 1) ]\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 1024, 1) ]\n",
"\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 32, 1) ]\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 64, 1) ]\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 128, 1) ]\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 256, 1) ]\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 512, 1) ]\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 1024, 1) ]\n",
"\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 32, 1) ]\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 64, 1) ]\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 128, 1) ]\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 256, 1) ]\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 512, 1) ]\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 1024, 1) ]\n",
"\n",
"Done!\n"
]
}
],
"source": [
"# ==== Run Precision Statistics\n",
"with torch.no_grad():\n",
" num_seeds = 100\n",
"\n",
" # * and *_IFT have the same solver solution, so show one only\n",
" methods = [\n",
" 'AT',\n",
" 'PI',\n",
" 'SI'\n",
" ]\n",
" mode_dict = {\n",
" 'AT': ['dLdx_DDN_fnc'],\n",
" 'PI': ['dLdx_DDN_fnc'],\n",
" 'PI_IFT': ['dLdx_structured_fnc'],\n",
" 'SI': ['dLdx_DDN_fnc'],\n",
" 'SI_IFT': ['dLdx_structured_fnc']\n",
" }\n",
"\n",
" enable_legend = True\n",
" distribution_mode = 'gaussian' # gaussian/uniform/vonmise/choice + _resnet50\n",
" uniform_sample_max = 1.0\n",
" choice_max = 10.0\n",
"\n",
" # ResNet50 will create m*m neurons, for m=256, it will cause out-of-memory issue, so just run under 256\n",
" if distribution_mode.find('resnet') > -1: data_sizes = [v for v in data_sizes if v[1] <= 128]\n",
"\n",
" for dtype_cur in [torch.float32]:\n",
" if num_seeds > 100 and dtype_cur == torch.float64: continue\n",
" if distribution_mode.find('resnet') > -1 and dtype_cur == torch.float64: continue\n",
"\n",
" print(f'Start precision statistics {dtype_cur}...')\n",
" save_path = f\"{g_save_path}_{str(dtype_cur).replace('torch.', '')}_{distribution_mode}_numseeds{num_seeds}\"\n",
" if distribution_mode == 'uniform': save_path += f'_max{uniform_sample_max}'\n",
"\n",
" precision_info = run_precision_statistics(\n",
" num_seeds, methods, data_sizes, mode_dict, obj_dict, enable_symmetric, dtype_cur,\n",
" distribution_mode=distribution_mode, uniform_sample_max=uniform_sample_max,\n",
" choice_max=choice_max)\n",
"\n",
" enable_legend = visual_precision(precision_info, data_sizes, save_path, enable_legend=enable_legend)\n",
"\n",
" del precision_info\n",
" torch.cuda.empty_cache()\n",
"\n",
" print('Done!')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "54d8fde2",
"metadata": {},
"source": [
"- Statistics of running time and memory: test the forward and backward feasibility of the PI and SI solvers (and their IFT variants) against the autogradient (unrolled) version by measuring the running time and memory cost of each method and backward mode at each matrix size."
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f95c0dbb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Start time and memory statistics...\n",
"[ Method: AT+unroll, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+unroll, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+unroll, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+unroll, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+unroll, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+unroll, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: AT+dLdx_DDN_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"\n",
"[ Method: PI+unroll, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+unroll, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+unroll, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+unroll, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+unroll, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+unroll, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc_B, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI+dLdx_DDN_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"\n",
"[ Method: PI_IFT+dLdx_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU time skipped for IFT.\n",
"!!!CPU memory skipped.\n",
"!!! GPU memory skipped for IFT+auto with size >= 1024.\n",
"[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: PI_IFT+dLdx_structured_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"\n",
"[ Method: SI+unroll, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+unroll, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+unroll, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+unroll, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+unroll, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+unroll, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc_B, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI+dLdx_DDN_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"\n",
"[ Method: SI_IFT+dLdx_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU time skipped for IFT.\n",
"!!!CPU memory skipped.\n",
"!!! GPU memory skipped for IFT+auto with size >= 1024.\n",
"[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 32, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 64, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 128, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 256, 1) ]\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 512, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"[ Method: SI_IFT+dLdx_structured_fnc, Size: (5, 1024, 1) ]\n",
"!!!CPU time skipped.\n",
"!!!CPU memory skipped.\n",
"\n",
"Done!\n"
]
}
],
"source": [
"# ==== Run Time and Memory Statistics\n",
"print('Start time and memory statistics...')\n",
"methods = [\n",
" 'AT',\n",
" 'PI',\n",
" 'PI_IFT',\n",
" 'SI',\n",
" 'SI_IFT'\n",
" ]\n",
"mode_dict = {\n",
" 'AT': ['unroll', 'dLdx_DDN_fnc'],\n",
" 'PI': ['unroll', 'dLdx_DDN_fnc_B', 'dLdx_DDN_fnc'],\n",
" 'PI_IFT': ['dLdx_fnc', 'dLdx_structured_fnc'],\n",
" 'SI': ['unroll', 'dLdx_DDN_fnc_B', 'dLdx_DDN_fnc'],\n",
" 'SI_IFT': ['dLdx_fnc', 'dLdx_structured_fnc']\n",
" }\n",
"dtype = torch.float32\n",
"save_path = f\"{g_save_path}_{str(dtype).replace('torch.', '')}\"\n",
"num_seeds = 10\n",
"\n",
"cost_info = run_speed_memory_statistics(\n",
" num_seeds, methods, data_sizes, mode_dict, obj_dict, enable_symmetric, dtype)\n",
"\n",
"visual_speed_memory(cost_info, data_sizes, save_path)\n",
"\n",
"del cost_info\n",
"torch.cuda.empty_cache()\n",
"print('Done!')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}