{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "正文开始:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. 第一章 - 基本概念: Var"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "目录:\n",
    "* 1.1 Var 的基本介绍\n",
    "* 1.2 如何生成 Var\n",
    "* 1.3 Var 的常见属性\n",
    "* 1.4 Var 的常见操作\n",
    "* 1.5 Var 的运算\n",
    "* 1.6 与 NumPy 的相互转化"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.1 Var 的基本介绍"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Var:计图的基本数据类型\n",
    "\n",
    "在计图 (Jittor) - 开源深度学习框架中,Var 类型是最基本的数据类型。计图的所有操作,都将基于 Var 进行。所以,学习并掌握 Var 是使用计图的第一步。请不要感到畏难,我们之所以引入 Var 类型,是为了优化深度学习的计算能力与效率,目的是方便您更加高效地进行开发,请您尽情享受。  \n",
    "\n",
    "简单来说,Var 是一种特殊的数据结构,它与数组和矩阵非常相似。您可简单理解为,一维 Var 即是数组,二维 Var 即是矩阵。在计图中,我们使用 Var 来编码模型的输入和输出,以及模型的参数。\n",
    "\n",
    "**计图的 Var 与 NumPy 的 ndarray 极其相似**,但 Var 可以在 GPU 或其他专用硬件上运行,以加速计算。如果您熟悉 ndarray,您就会十分熟悉 Var 的相关操作。如果没有,请按照本教程,对 Var 类型进行快速演练。本教程已覆盖 Var 的基础操作,可以基本满足您的开发需求。如您想要更加深入了解并灵活使用 Var,敬请查阅我们的 [API 文档](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html)。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.2 如何生成 Var"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先,我们在 Python 中导入计图:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[38;5;2m[i 0202 22:49:45.609290 80 compiler.py:847] Jittor(1.2.2.23) src: /home/llt/.local/lib/python3.7/site-packages/jittor\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.611377 80 compiler.py:848] g++ at /usr/bin/g++\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.612691 80 compiler.py:849] cache_path: /home/llt/.cache/jittor/default/g++\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.626682 80 __init__.py:257] Found /usr/local/cuda/bin/nvcc(10.2.89) at /usr/local/cuda/bin/nvcc.\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.697894 80 __init__.py:257] Found gdb(8.1.0) at /usr/bin/gdb.\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.715477 80 __init__.py:257] Found addr2line(2.30) at /usr/bin/addr2line.\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.768340 80 compiler.py:889] pybind_include: -I/usr/include/python3.7m -I/usr/local/lib/python3.7/dist-packages/pybind11/include\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.792263 80 compiler.py:891] extension_suffix: .cpython-37m-x86_64-linux-gnu.so\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:45.974687 80 __init__.py:169] Total mem: 62.78GB, using 16 procs for compiling.\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.131834 80 jit_compiler.cc:21] Load cc_path: /usr/bin/g++\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.304843 80 init.cc:54] Found cuda archs: [75,]\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.491697 80 __init__.py:257] Found mpicc(2.1.1) at /usr/bin/mpicc.\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.549639 80 compiler.py:654] handle pyjt_include/home/llt/.local/lib/python3.7/site-packages/jittor/extern/mpi/inc/mpi_warper.h\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.596179 80 compile_extern.py:287] Downloading nccl...\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.683976 80 compile_extern.py:16] found /usr/local/cuda/include/cublas.h\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:46.686093 80 compile_extern.py:16] found /usr/lib/x86_64-linux-gnu/libcublas.so\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:47.293398 80 compile_extern.py:16] found /usr/local/cuda/include/cudnn.h\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:47.294725 80 compile_extern.py:16] found /usr/local/cuda/lib64/libcudnn.so\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:47.332204 80 compiler.py:654] handle pyjt_include/home/llt/.local/lib/python3.7/site-packages/jittor/extern/cuda/cudnn/inc/cudnn_warper.h\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:49.192471 80 compile_extern.py:16] found /usr/local/cuda/include/curand.h\u001b[m\n",
      "\u001b[38;5;2m[i 0202 22:49:49.193815 80 compile_extern.py:16] found /usr/local/cuda/lib64/libcurand.so\u001b[m\n"
     ]
    }
   ],
   "source": [
    "import jittor as jt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**从 “普通数据” 导入生成 Var**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data: \n",
      " [[1, 2], [3, 4], [5, 6]]\n",
      "var_data: \n",
      " jt.Var([[1 2]\n",
      " [3 4]\n",
      " [5 6]], dtype=int32)\n",
      "var_data的类型是: <class 'jittor.jittor_core.Var'>\n"
     ]
    }
   ],
   "source": [
    "data = [[1, 2],[3, 4],[5, 6]]                   # 生成一个普通的列表\n",
    "print(\"data: \\n\", data)\n",
    "\n",
    "var_data = jt.array(data)                         # 通过列表生成 Var\n",
    "print(\"var_data: \\n\", var_data)\n",
    "\n",
    "print(\"var_data的类型是:\", type(var_data))                          # 打印类型,可以发现生成的是jittor.Var\n",
    "var_data = jt.Var(data)                        # 当然,jittor.Var也可以直接使用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**从 “NumPy 数组” 导入生成 Var**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "np_array: \n",
      " [[1 2]\n",
      " [3 4]\n",
      " [5 6]]\n",
      "var_np_array: \n",
      " jt.Var([[1 2]\n",
      " [3 4]\n",
      " [5 6]], dtype=int64)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "np_array = np.array([[1, 2],[3, 4],[5, 6]])      # 生成一个 NumPy 的数组\n",
    "print(\"np_array: \\n\", np_array)\n",
    "\n",
    "var_np_array=jt.array(np_array)                    # 通过 NumPy 的数组生成 Var\n",
    "print(\"var_np_array: \\n\", var_np_array)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**从 “其他 Var 数据” 导入生成新的 Var**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Zeros Var: \n",
      " jt.Var([[0 0]\n",
      " [0 0]\n",
      " [0 0]], dtype=int32)\n",
      "Ones Var: \n",
      " jt.Var([[1 1]\n",
      " [1 1]\n",
      " [1 1]], dtype=int32)\n",
      "full Var: \n",
      " jt.Var([[5 5]\n",
      " [5 5]\n",
      " [5 5]], dtype=int32)\n",
      "Random Var: \n",
      " jt.Var([[0.27817145 0.22768587]\n",
      " [0.71629596 0.78609395]\n",
      " [0.8813336  0.5737404 ]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 准备一个原始的 Var\n",
    "var_data = jt.Var([[1, 2],[3, 4],[5, 6]])\n",
    "\n",
    "# 生成一个形状相同,元素全为 0 的 Var\n",
    "var_zeros = jt.zeros_like(var_data) \n",
    "print(\"Zeros Var: \\n\", var_zeros)\n",
    "\n",
    "# 生成一个形状相同,元素全为 1 的 Var\n",
    "var_ones = jt.ones_like(var_data) \n",
    "print(\"Ones Var: \\n\", var_ones)\n",
    "\n",
    "# 生成一个形状相同,并指定元素值的 Var\n",
    "var_full = jt.full_like(var_data, 5)\n",
    "print(\"full Var: \\n\", var_full)\n",
    "\n",
    "# 生成一个形状相同,元素为随机数的 Var\n",
    "var_rand = jt.rand(var_data.shape, dtype=jt.float)    # 重写元素数据类型\n",
    "print(\"Random Var: \\n\", var_rand)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**生成一些特定的 Var**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random Var: \n",
      " jt.Var([[0.8556649  0.16008338 0.5214482  0.98024756]\n",
      " [0.02092005 0.6033252  0.08662996 0.9897186 ]\n",
      " [0.20087309 0.07411052 0.57556623 0.54145   ]], dtype=float32)\n",
      "Ones Var: \n",
      " jt.Var([[1. 1. 1. 1.]\n",
      " [1. 1. 1. 1.]\n",
      " [1. 1. 1. 1.]], dtype=float32)\n",
      "Zeros Var: \n",
      " jt.Var([[0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]\n",
      " [0. 0. 0. 0.]], dtype=float32)\n",
      "Arange Var 1: \n",
      " jt.Var([0 1 2 3 4 5 6 7 8], dtype=int32)\n",
      "Arange Var 2: \n",
      " jt.Var([0 2 4 6 8], dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "shape = (3,4,)                                 # 指定将要生成 Var 的形状\n",
    "rand_var = jt.rand(shape)                      # 生成元素为随机数的 Var\n",
    "ones_var = jt.ones(shape)                      # 生成元素全为 1 的 Var\n",
    "zeros_var = jt.zeros(shape)                    # 生成元素全为 0 的 Var\n",
    "\n",
    "print(\"Random Var: \\n\", rand_var)\n",
    "print(\"Ones Var: \\n\", ones_var)\n",
    "print(\"Zeros Var: \\n\", zeros_var)\n",
    "\n",
    "arange_var1 = jt.arange(9)                     # 生成元素为等差数组的一维 Var (只指定一个结束参数)\n",
    "arange_var2 = jt.arange(0, 9, 2)               # 生成元素为等差数组的一维 Var (指定起始、结束、步长)\n",
    "print(\"Arange Var 1: \\n\", arange_var1)\n",
    "print(\"Arange Var 2: \\n\", arange_var2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.3 Var 的常见属性"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Var 的形状:** 通过 var.shape 查看"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Var: \n",
      " jt.Var([[ 1  2  3  4]\n",
      " [ 5  6  7  8]\n",
      " [ 9 10 11 12]], dtype=int32)\n",
      "Shape of Var: \n",
      " [3,4,]\n"
     ]
    }
   ],
   "source": [
    "var = jt.array([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]])\n",
    "print(\"Var: \\n\", var)\n",
    "print(\"Shape of Var: \\n\", var.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Var 的数据类型:** 通过 var.dtype 查看"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Var: \n",
      " jt.Var([[1. 1. 1. 1.]\n",
      " [1. 1. 1. 1.]\n",
      " [1. 1. 1. 1.]], dtype=float32)\n",
      "Datatype of Var: \n",
      " float32\n"
     ]
    }
   ],
   "source": [
    "var = jt.ones((3,4))\n",
    "print(\"Var: \\n\", var)\n",
    "print(\"Datatype of Var: \\n\", var.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.4 Var 的常见操作"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**重塑形状 Reshape**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Var1: \n",
      " jt.Var([ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n",
      " 24 25 26], dtype=int32)\n",
      "Shape of Var1: \n",
      " [27,]\n",
      "Var2: \n",
      " jt.Var([[[ 0  1  2]\n",
      "  [ 3  4  5]\n",
      "  [ 6  7  8]]\n",
      "\n",
      " [[ 9 10 11]\n",
      "  [12 13 14]\n",
      "  [15 16 17]]\n",
      "\n",
      " [[18 19 20]\n",
      "  [21 22 23]\n",
      "  [24 25 26]]], dtype=int32)\n",
      "Shape of Var2: \n",
      " [3,3,3,]\n",
      "Var3: \n",
      " jt.Var([[ 0  1  2  3  4  5  6  7  8]\n",
      " [ 9 10 11 12 13 14 15 16 17]\n",
      " [18 19 20 21 22 23 24 25 26]], dtype=int32)\n",
      "Shape of Var3: \n",
      " [3,9,]\n"
     ]
    }
   ],
   "source": [
    "var1 = jt.arange(27)                         # 创建一个元素为 0-26 的一维 Var1\n",
    "print(\"Var1: \\n\", var1)\n",
    "print(\"Shape of Var1: \\n\", var1.shape)\n",
    "\n",
    "var2 = var1.reshape(3,3,3)                   # 将 Var1 重塑形状为 3*3*3 的 Var2\n",
    "print(\"Var2: \\n\", var2)\n",
    "print(\"Shape of Var2: \\n\", var2.shape)\n",
    "\n",
    "var3 = var2.reshape(3,-1)                    # 参数值为 -1,代表将根据其他形状参数自动计算出合适的值,再进行重塑形状\n",
    "print(\"Var3: \\n\", var3)\n",
    "print(\"Shape of Var3: \\n\", var3.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**转置 Transpose**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 矩阵转置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Var1: \n",
      " jt.Var([[ 0  1  2]\n",
      " [ 3  4  5]\n",
      " [ 6  7  8]\n",
      " [ 9 10 11]], dtype=int32)\n",
      "Shape of Var1: \n",
      " [4,3,]\n",
      "Var1 Transpose: \n",
      " jt.Var([[ 0  3  6  9]\n",
      " [ 1  4  7 10]\n",
      " [ 2  5  8 11]], dtype=int32)\n",
      "Shape of Var1 Transpose: \n",
      " [3,4,]\n"
     ]
    }
   ],
   "source": [
    "# (提示:这里我们用 arange 初始化一个 0 到 11 的一维 Var,然后通过 reshape 将其转化为 4*3 的 Var1。)\n",
    "var1 = jt.arange(12).reshape(4,3)                           # 创建一个 4*3 的二维矩阵 Var1\n",
    "print(\"Var1: \\n\", var1)\n",
    "print(\"Shape of Var1: \\n\", var1.shape)\n",
    "\n",
    "var1_t = var1.transpose()                                   # 矩阵转置成 3*4 的二维矩阵 Var1 Transpose\n",
    "print(\"Var1 Transpose: \\n\", var1_t)\n",
    "print(\"Shape of Var1 Transpose: \\n\", var1_t.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* 指定轴转置(较难)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Var1: \n",
      " jt.Var([[[ 0  1  2  3]\n",
      "  [ 4  5  6  7]\n",
      "  [ 8  9 10 11]]\n",
      "\n",
      " [[12 13 14 15]\n",
      "  [16 17 18 19]\n",
      "  [20 21 22 23]]], dtype=int32)\n",
      "Shape of Var1: \n",
      " [2,3,4,]\n",
      "Var1 Transpose: \n",
      " jt.Var([[[ 0  4  8]\n",
      "  [ 1  5  9]\n",
      "  [ 2  6 10]\n",
      "  [ 3  7 11]]\n",
      "\n",
      " [[12 16 20]\n",
      "  [13 17 21]\n",
      "  [14 18 22]\n",
      "  [15 19 23]]], dtype=int32)\n",
      "Shape of Var1 Transpose: \n",
      " [2,4,3,]\n"
     ]
    }
   ],
   "source": [
    "var1 = jt.arange(24).reshape(2,3,4)                        # 创建一个 2*3*4 的三维 Var1\n",
    "print(\"Var1: \\n\", var1)                                    # 其含义为,第一个轴上 2 个元素;第二个轴上 3 个元素;第三个轴上 4 个元素。\n",
    "print(\"Shape of Var1: \\n\", var1.shape)\n",
    "\n",
    "var1_t = var1.transpose(0,2,1)                             # 根据轴索引转置。三维轴索引为 0,1,2,调整索引顺序即代表对应轴调整顺序。\n",
    "print(\"Var1 Transpose: \\n\", var1_t)                        # 从默认的“0,1,2”调整为“0,2,1”,即代表后两轴进行互换。\n",
    "print(\"Shape of Var1 Transpose: \\n\", var1_t.shape)         # 互换后,第一个轴上 2 个元素;第二个轴上 4 个元素;第三个轴上 3 个元素。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**切片和索引**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 Var: \n",
      " jt.Var([[0 1 2]\n",
      " [3 4 5]\n",
      " [6 7 8]], dtype=int32)\n",
      "2 Var: \n",
      " jt.Var([[0 0 2]\n",
      " [3 0 5]\n",
      " [6 0 8]], dtype=int32)\n",
      "3 Line 2 in Var: \n",
      " jt.Var([3 0 5], dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "var = jt.arange(9).reshape(3,3)                     # 初始化一个 3*3 的 Var\n",
    "print(1, \"Var: \\n\", var)\n",
    "\n",
    "var[:,1] = 0                                        # 利用切片,将任意行的第二个元素(即第二列)重置为 0\n",
    "print(2, \"Var: \\n\", var)\n",
    "\n",
    "line_2 = var[1]                                     # 利用索引,拿到第二行元素\n",
    "print(3, \"Line 2 in Var: \\n\", line_2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**原地操作**  \n",
    "在计图中,我们使用 assign() 函数实现在原变量上直接操作。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 Var0: \n",
      " jt.Var([[0 1 2]\n",
      " [3 4 5]], dtype=int32)\n",
      "1 Var1: \n",
      " jt.Var([[1 2 3]\n",
      " [4 5 6]], dtype=int32)\n",
      "1 Var0: \n",
      " jt.Var([[0 1 2]\n",
      " [3 4 5]], dtype=int32)\n",
      "2 Var2: \n",
      " jt.Var([[1 2 3]\n",
      " [4 5 6]], dtype=int32)\n",
      "2 Var0: \n",
      " jt.Var([[1 2 3]\n",
      " [4 5 6]], dtype=int32)\n",
      "3 Var3: \n",
      " jt.Var([[1 4]\n",
      " [2 5]\n",
      " [3 6]], dtype=int32)\n",
      "3 Var0: \n",
      " jt.Var([[1 2 3]\n",
      " [4 5 6]], dtype=int32)\n",
      "4 Var4: \n",
      " jt.Var([[1 4]\n",
      " [2 5]\n",
      " [3 6]], dtype=int32)\n",
      "4 Var0: \n",
      " jt.Var([[1 4]\n",
      " [2 5]\n",
      " [3 6]], dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "var0 = jt.arange(6).reshape(2,3)              # 初始化一个 2*3 的 Var0\n",
    "print(0, \"Var0: \\n\", var0)                    # 打印查看 Var0\n",
    "\n",
    "var1 = var0.add(1)                            # 使 Var1 等于 Var0 上每个元素加 1\n",
    "print(1, \"Var1: \\n\", var1)                    # 打印发现 Var1 赋值成功\n",
    "print(1, \"Var0: \\n\", var0)                    # 但 Var0 并没有发生变化\n",
    "\n",
    "var2 = var0.assign(var0.add(1))               # 利用 assign() 函数,使 Var2 等于 Var0 上每个元素加 1\n",
    "print(2, \"Var2: \\n\", var2)                    # 打印发现 Var2 赋值成功\n",
    "print(2, \"Var0: \\n\", var0)                    # 同时 Var0 也发生了改变\n",
    "\n",
    "var3 = var0.transpose()                       # 使 Var3 等于 Var0 的转置\n",
    "print(3, \"Var3: \\n\", var3)                    # 打印发现 Var3 赋值成功\n",
    "print(3, \"Var0: \\n\", var0)                    # 但 Var0 并没有发生变化\n",
    "\n",
    "var4 = var0.assign(var0.transpose())          # 利用 assign() 函数,使 Var4 等于 Var0 的转置\n",
    "print(4, \"Var4: \\n\", var4)                    # 打印发现 Var4 赋值成功\n",
    "print(4, \"Var0: \\n\", var0)                    # 同时 Var0 也发生了改变"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.5 Var 的运算"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**开启 GPU 加速**  \n",
    "在计图中,开启 GPU 加速的方式非常简单。我们使用 jt.flags.use_cuda 表示是否使用 GPU 训练。\n",
    "\n",
    "  \n",
    "* 如果 jt.flags.use_cuda = 0,表示使用 CPU 训练;\n",
    "* 如果 jt.flags.use_cuda = 1,表示使用 GPU 训练。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[38;5;2m[i 0130 19:00:46.379159 88 cuda_flags.cc:26] CUDA enabled.\u001b[m\n"
     ]
    }
   ],
   "source": [
    "jt.flags.use_cuda = 1         # 使用 GPU 训练 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Var 的加法和连接**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var1: \n",
      " jt.Var([[0 1 2]\n",
      " [3 4 5]\n",
      " [6 7 8]], dtype=int32)\n",
      "var2: \n",
      " jt.Var([[1. 1. 1.]\n",
      " [1. 1. 1.]\n",
      " [1. 1. 1.]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 初始化两个Var\n",
    "var1 = jt.arange(9).reshape(3,3)\n",
    "var2 = jt.ones((3,3))\n",
    "print(\"var1: \\n\", var1)\n",
    "print(\"var2: \\n\", var2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var3: \n",
      " jt.Var([[ 3  4  5]\n",
      " [ 6  7  8]\n",
      " [ 9 10 11]], dtype=int32)\n",
      "var4: \n",
      " jt.Var([[1. 2. 3.]\n",
      " [4. 5. 6.]\n",
      " [7. 8. 9.]], dtype=float32)\n",
      "var5: \n",
      " jt.Var([[0. 1. 2. 1. 1. 1.]\n",
      " [3. 4. 5. 1. 1. 1.]\n",
      " [6. 7. 8. 1. 1. 1.]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# Var 与实数相加\n",
    "var3 = var1 + 3\n",
    "print(\"var3: \\n\", var3)\n",
    "\n",
    "# 两个 Var 相加\n",
    "var4 = var1 + var2\n",
    "print(\"var4: \\n\", var4)\n",
    "\n",
    "# 连接两个 Var \n",
    "var5 = jt.concat([var1, var2], dim=1)\n",
    "print(\"var5: \\n\", var5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Var 的乘法:标量相乘 / 矩阵相乘**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var1: \n",
      " jt.Var([[0 1 2]\n",
      " [3 4 5]\n",
      " [6 7 8]], dtype=int32)\n",
      "var2: \n",
      " jt.Var([[1. 1. 1.]\n",
      " [1. 1. 1.]\n",
      " [1. 1. 1.]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 初始化两个Var\n",
    "var1 = jt.arange(9).reshape(3,3)\n",
    "var2 = jt.ones((3,3))\n",
    "print(\"var1: \\n\", var1)\n",
    "print(\"var2: \\n\", var2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* **Var 与实数相乘:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var3: \n",
      " jt.Var([[ 0  3  6]\n",
      " [ 9 12 15]\n",
      " [18 21 24]], dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# Var 与实数相乘\n",
    "var3 = var1 * 3\n",
    "print(\"var3: \\n\", var3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* **两个 Var 标量相乘:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var4: \n",
      " jt.Var([[0. 1. 2.]\n",
      " [3. 4. 5.]\n",
      " [6. 7. 8.]], dtype=float32)\n",
      "var5: \n",
      " jt.Var([[0. 1. 2.]\n",
      " [3. 4. 5.]\n",
      " [6. 7. 8.]], dtype=float32)\n",
      "var6: \n",
      " jt.Var([[0. 1. 2.]\n",
      " [3. 4. 5.]\n",
      " [6. 7. 8.]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 标量相乘 方式一\n",
    "var4 = var1 * var2\n",
    "print(\"var4: \\n\", var4)\n",
    "\n",
    "# 标量相乘 方式二\n",
    "var5 = var1.multiply(var2)\n",
    "print(\"var5: \\n\", var5)\n",
    "\n",
    "# 标量相乘 方式三\n",
    "var6 = jt.multiply(var1, var2)\n",
    "print(\"var6: \\n\", var6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* **两个 Var 矩阵相乘:**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var7: \n",
      " jt.Var([[ 3.  3.  3.]\n",
      " [12. 12. 12.]\n",
      " [21. 21. 21.]], dtype=float32)\n",
      "var8: \n",
      " jt.Var([[ 3.  3.  3.]\n",
      " [12. 12. 12.]\n",
      " [21. 21. 21.]], dtype=float32)\n",
      "var9: \n",
      " jt.Var([[ 3.  3.  3.]\n",
      " [12. 12. 12.]\n",
      " [21. 21. 21.]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "# 矩阵相乘 方式一\n",
    "var7 = var1 @ var2\n",
    "print(\"var7: \\n\", var7)\n",
    "\n",
    "# 矩阵相乘 方式二\n",
    "var8 = var1.matmul(var2)\n",
    "print(\"var8: \\n\", var8)\n",
    "\n",
    "# 矩阵相乘 方式三\n",
    "var9 = jt.matmul(var1, var2)\n",
    "print(\"var9: \\n\", var9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.6 与 NumPy 的相互转化"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**从 Var 转化成 NumPy 数组**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "var: \n",
      " jt.Var([[1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]], dtype=float32)\n",
      "type(var): <class 'jittor.jittor_core.Var'>\n",
      "ndarray: \n",
      " [[1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]]\n",
      "type(ndarray): <class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "# 新建一个 Var\n",
    "var = jt.ones((4,5))\n",
    "print(\"var: \\n\", var)\n",
    "print(\"type(var):\", type(var))\n",
    "\n",
    "# 将其转化为 NumPy 数组\n",
    "ndarray = var.numpy()\n",
    "print(\"ndarray: \\n\", ndarray)\n",
    "print(\"type(ndarray):\", type(ndarray))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**从 NumPy 数组转化成 Var**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ndarray: \n",
      " [[1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]]\n",
      "type(ndarray): <class 'numpy.ndarray'>\n",
      "var: \n",
      " jt.Var([[1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]\n",
      " [1. 1. 1. 1. 1.]], dtype=float64)\n",
      "type(var): <class 'jittor.jittor_core.Var'>\n"
     ]
    }
   ],
   "source": [
    "# 新建一个 NumPy 数组\n",
    "ndarray = np.ones((4,5))\n",
    "print(\"ndarray: \\n\", ndarray)\n",
    "print(\"type(ndarray):\", type(ndarray))\n",
    "\n",
    "# 将其转化为 Var 类型\n",
    "var = jt.array(ndarray)\n",
    "print(\"var: \\n\", var)\n",
    "print(\"type(var):\", type(var))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 📣\n",
    "至此,第一章 - 基本概念: Var - 结束!\n",
    "\n",
    "恭喜您!已成功解锁了 Jittor (计图) 最核心、最基本的数据类型 - Var 类型的基本使用。  \n",
    "\n",
    "在掌握了 Var 类型的基本操作之后,您只需要进一步熟悉神经网络训练的一般模式,便可以利用 Jittor 解决更高阶、更实战的问题。  \n",
    "请不要停下脚步,只差一点点,你就可以轻松地用 Jittor 来进行神经网络的训练啦!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}