{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import cv2\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser(description=\"Camera Intrinsic Calibration\")\n",
    "parser.add_argument('-input', '--INPUT_TYPE', default='camera', type=str, help='Input Source: camera/video/image')\n",
    "parser.add_argument('-type', '--CAMERA_TYPE', default='fisheye', type=str, help='Camera Type: fisheye/normal')\n",
    "parser.add_argument('-id', '--CAMERA_ID', default=1, type=int, help='Camera ID')\n",
    "parser.add_argument('-path', '--INPUT_PATH', default='./data/', type=str, help='Input Video/Image Path')\n",
    "parser.add_argument('-video', '--VIDEO_FILE', default='video.mp4', type=str, help='Input Video File Name (eg.: video.mp4)')\n",
    "parser.add_argument('-image', '--IMAGE_FILE', default='img_raw', type=str, help='Input Image File Name Prefix (eg.: img_raw)')\n",
    "parser.add_argument('-mode', '--SELECT_MODE', default='auto', type=str, help='Image Select Mode: auto/manual')\n",
    "parser.add_argument('-fw','--FRAME_WIDTH', default=1280, type=int, help='Camera Frame Width')\n",
    "parser.add_argument('-fh','--FRAME_HEIGHT', default=1024, type=int, help='Camera Frame Height')\n",
    "parser.add_argument('-bw','--BORAD_WIDTH', default=7, type=int, help='Chess Board Width (corners number)')\n",
    "parser.add_argument('-bh','--BORAD_HEIGHT', default=6, type=int, help='Chess Board Height (corners number)')\n",
    "parser.add_argument('-size','--SQUARE_SIZE', default=10, type=int, help='Chess Board Square Size (mm)')\n",
    "parser.add_argument('-num','--CALIB_NUMBER', default=5, type=int, help='Least Required Calibration Frame Number')\n",
    "parser.add_argument('-delay','--FRAME_DELAY', default=12, type=int, help='Capture Image Time Interval (frame number)')\n",
    "parser.add_argument('-subpix','--SUBPIX_REGION', default=5, type=int, help='Corners Subpix Optimization Region')\n",
    "parser.add_argument('-fps','--CAMERA_FPS', default=20, type=int, help='Camera Frame per Second(FPS)')\n",
    "parser.add_argument('-fs', '--FOCAL_SCALE', default=0.5, type=float, help='Camera Undistort Focal Scale')\n",
    "parser.add_argument('-ss', '--SIZE_SCALE', default=1, type=float, help='Camera Undistort Size Scale')\n",
    "parser.add_argument('-store','--STORE_FLAG', default=False, type=bool, help='Store Captured Images (Ture/False)')\n",
    "parser.add_argument('-store_path', '--STORE_PATH', default='./data/', type=str, help='Path to Store Captured Images')\n",
    "parser.add_argument('-crop','--CROP_FLAG', default=False, type=bool, help='Crop Input Video/Image to (fw,fh) (Ture/False)')\n",
    "parser.add_argument('-resize','--RESIZE_FLAG', default=False, type=bool, help='Resize Input Video/Image to (fw,fh) (Ture/False)')\n",
    "args = parser.parse_args([])                 # Jupyter Notebook中直接运行时要加[], py文件则去掉"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# args.INPUT_TYPE = 'image'                  # 输入形式 相机/视频/图像\n",
    "# args.CAMERA_TYPE = 'normal'                # 相机类型 鱼眼/普通\n",
    "# args.CAMERA_ID = 1                         # 相机编号\n",
    "# args.INPUT_PATH = './data/'                # 图片、视频输入路径\n",
    "# args.VIDEO_FILE = 'video.mp4'              # 输入视频文件名(含扩展名)\n",
    "# args.IMAGE_FILE = 'raw'                    # 输入图像文件名前缀\n",
    "# args.SELECT_MODE = 'manual'                # 选择自动/手动模式\n",
    "# args.FRAME_WIDTH = 1280                    # 相机分辨率 帧宽度\n",
    "# args.FRAME_HEIGHT = 720                    # 相机分辨率 帧高度\n",
    "# args.BORAD_WIDTH = 7                       # 棋盘宽度 【内角点数】\n",
    "# args.BORAD_HEIGHT = 6                      # 棋盘高度 【内角点数】\n",
    "# args.SQUARE_SIZE = 10                      # 棋盘格边长 mm\n",
    "# args.CALIB_NUMBER = 10                     # 初始化最小标定图片采样数量\n",
    "# args.FRAME_DELAY = 15                      # 间隔多少帧数采样\n",
    "# args.SUBPIX_REGION = 3                     # 角点坐标亚像素优化时的搜索区域大小(根据图像分辨率调整)\n",
    "# args.STORE_FLAG = True                     # 是否保存抓取的图像\n",
    "# args.STORE_PATH = './data/'                # 保存抓取的图像的路径\n",
    "# args.CROP_FLAG = True                      # 是否将输入视频/图像尺寸裁剪至FRAME_WIDTH和FRAME_HEIGHT的设定值\n",
    "# args.RESIZE_FLAG = True                    # 是否将输入视频/图像尺寸缩放至FRAME_WIDTH和FRAME_HEIGHT的设定值(图像缩放会改变相机焦距)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 外部调用修改参数\n",
    "def getInCalibArgs():\n",
    "    return args\n",
    "\n",
    "def editInCalibArgs(new_args):\n",
    "    global args\n",
    "    args = new_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CalibData:                             # 标定数据类\n",
    "    def __init__(self):\n",
    "        self.type = None                     # 自定义数据类型\n",
    "        self.camera_mat = None               # 相机内参\n",
    "        self.dist_coeff = None               # 畸变参数\n",
    "        self.rvecs = None                    # 旋转向量\n",
    "        self.tvecs = None                    # 平移向量\n",
    "        self.map1 = None                     # 映射矩阵1\n",
    "        self.map2 = None                     # 映射矩阵2\n",
    "        self.reproj_err = None               # 重投影误差\n",
    "        self.ok = False                      # 数据采集完成标志"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.fisheye.calibrate ( objectPoints,      # 角点在棋盘中的空间坐标向量\n",
    "#                         imagePoints,       # 角点在图像中的坐标向量\n",
    "#                         image_size,        # 图片大小\n",
    "#                         K,                 # 相机内参矩阵\n",
    "#                         D,                 # 畸变参数向量\n",
    "#                         rvecs,             # 旋转向量\n",
    "#                         tvecs,             # 平移向量\n",
    "#                         flags,             # 操作标志\n",
    "#                         criteria           # 迭代优化算法的停止标准\n",
    "#                         )\n",
    "\n",
    "# flags:\n",
    "#     cv2.fisheye.CALIB_USE_INTRINSIC_GUESS     # 当相机内参矩阵包含有效的fx,fy,cx,cy初始值时,这些值会进一步进行优化\n",
    "#                                               # 否则,(cx,cy)初始化设置为图像中心(使用imageSize),并且以最小二乘法计算焦距。\n",
    "#     cv2.fisheye.CALIB_RECOMPUTE_EXTRINSIC     # 在每次内部参数优化迭代之后,将重新计算外部参数。\n",
    "#     cv2.fisheye.CALIB_CHECK_COND              # 检查条件编号的有效性\n",
    "#     cv2.fisheye.CALIB_FIX_SKEW                # 偏斜系数(alpha)设置为零,并保持为零\n",
    "#     cv2.fisheye.CALIB_FIX_K1 (K1-K4)          # 选定的畸变系数设置为零,并保持为零  CALIB_FIX_INTRINSIC则全为零\n",
    "\n",
    "# criteria:\n",
    "#     TermCriteria (int type, int maxCount, double epsilon)      # 类型、最大计数次数、最小精度\n",
    "#     Criteria type, can be one of: COUNT, EPS or COUNT + EPS    # 角点优化最大迭代次数或角点优化位置移动量小于epsilon值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.checkRange 检查元素非空及无异常值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "class Fisheye:           # 鱼眼相机\n",
    "    def __init__(self):\n",
    "        self.data = CalibData()\n",
    "        self.inited = False\n",
    "        self.BOARD = np.array([ [(j * args.SQUARE_SIZE, i * args.SQUARE_SIZE, 0.)]\n",
    "                               for i in range(args.BORAD_HEIGHT) \n",
    "                               for j in range(args.BORAD_WIDTH) ],dtype=np.float32)     # 棋盘角点二维坐标(乘上尺寸)\n",
    "        \n",
    "    # 更新标定数据,分为初始化和精调\n",
    "    def update(self, corners, frame_size):\n",
    "        board = [self.BOARD] * len(corners)\n",
    "        if not self.inited:\n",
    "            self._update_init(board, corners, frame_size)\n",
    "            self.inited = True\n",
    "        else:\n",
    "            self._update_refine(board, corners, frame_size)\n",
    "        self._calc_reproj_err(corners)\n",
    "        self._get_undistort_maps()\n",
    "    \n",
    "    # 得到一定数量标定样本时进行初始标定\n",
    "    def _update_init(self, board, corners, frame_size):\n",
    "        data = self.data\n",
    "        data.type = \"FISHEYE\"\n",
    "        data.camera_mat = np.eye(3, 3)\n",
    "        data.dist_coeff = np.zeros((4, 1))\n",
    "        data.ok, data.camera_mat, data.dist_coeff, data.rvecs, data.tvecs = cv2.fisheye.calibrate(\n",
    "            board, corners, frame_size, data.camera_mat, data.dist_coeff,\n",
    "            flags=cv2.fisheye.CALIB_FIX_SKEW|cv2.fisheye.CALIB_RECOMPUTE_EXTRINSIC,\n",
    "            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_COUNT, 30, 1e-6)) \n",
    "        data.ok = data.ok and cv2.checkRange(data.camera_mat) and cv2.checkRange(data.dist_coeff)\n",
    "\n",
    "    # 精调时启用CALIB_USE_INTRINSIC_GUESS\n",
    "    def _update_refine(self, board, corners, frame_size):\n",
    "        data = self.data\n",
    "        data.ok, data.camera_mat, data.dist_coeff, data.rvecs, data.tvecs = cv2.fisheye.calibrate(\n",
    "            board, corners, frame_size, data.camera_mat, data.dist_coeff,\n",
    "            flags=cv2.fisheye.CALIB_FIX_SKEW|cv2.fisheye.CALIB_RECOMPUTE_EXTRINSIC|cv2.CALIB_USE_INTRINSIC_GUESS,\n",
    "            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_COUNT, 10, 1e-6))\n",
    "        data.ok = data.ok and cv2.checkRange(data.camera_mat) and cv2.checkRange(data.dist_coeff)\n",
    "\n",
    "    # 计算重投影误差,单位为像素\n",
    "    def _calc_reproj_err(self, corners):\n",
    "        if not self.inited: return\n",
    "        data = self.data\n",
    "        data.reproj_err = []\n",
    "        for i in range(len(corners)):\n",
    "            corners_reproj, _ = cv2.fisheye.projectPoints(self.BOARD, data.rvecs[i], data.tvecs[i], data.camera_mat, data.dist_coeff)\n",
    "            err = cv2.norm(corners_reproj, corners[i], cv2.NORM_L2) / len(corners_reproj)\n",
    "            data.reproj_err.append(err)\n",
    "            \n",
    "    # 计算去畸变的新的相机内参,可以改变焦距和画幅\n",
    "    def _get_camera_mat_dst(self, camera_mat):\n",
    "        camera_mat_dst = camera_mat.copy()\n",
    "        camera_mat_dst[0][0] *= args.FOCAL_SCALE\n",
    "        camera_mat_dst[1][1] *= args.FOCAL_SCALE\n",
    "        camera_mat_dst[0][2] = args.FRAME_WIDTH / 2 * args.SIZE_SCALE\n",
    "        camera_mat_dst[1][2] = args.FRAME_HEIGHT / 2 * args.SIZE_SCALE\n",
    "        return camera_mat_dst\n",
    "    \n",
    "    # 计算去畸变的映射矩阵\n",
    "    def _get_undistort_maps(self):\n",
    "        data = self.data\n",
    "        camera_mat_dst = self._get_camera_mat_dst(data.camera_mat)\n",
    "        data.map1, data.map2 = cv2.fisheye.initUndistortRectifyMap(\n",
    "                                 data.camera_mat, data.dist_coeff, np.eye(3, 3), camera_mat_dst, \n",
    "                                 (int(args.FRAME_WIDTH * args.SIZE_SCALE), int(args.FRAME_HEIGHT * args.SIZE_SCALE)), cv2.CV_16SC2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.calibrateCamera (objectPoints,         # 角点在棋盘中的空间坐标向量        \n",
    "#                      imagePoints,          # 角点在图像中的坐标向量\n",
    "#                      image_size,           # 图片大小\n",
    "#                      K,                    # 相机内参矩阵\n",
    "#                      D,                    # 畸变参数向量\n",
    "#                      rvecs,                # 旋转向量\n",
    "#                      tvecs,                # 平移向量\n",
    "#                      flags,                # 操作标志\n",
    "#                      criteria              # 迭代优化算法的停止标准\n",
    "#                     )\n",
    "\n",
    "# flags:\n",
    "#     cv2.CALIB_USE_INTRINSIC_GUESS          # 当相机内参矩阵包含有效的fx,fy,cx,cy初始值时,这些值会进一步进行优化\n",
    "#                                            # 否则,(cx,cy)初始化设置为图像中心(使用imageSize),并且以最小二乘法计算焦距\n",
    "#     cv2.CALIB_FIX_PRINCIPAL_POINT          # 固定光轴点(当设置CALIB_USE_INTRINSIC_GUESS时可以使用)\n",
    "#     cv2.CALIB_FIX_ASPECT_RATIO             # 固定fx/fy的值,函数仅将fy视为自由参数\n",
    "#     cv2.CALIB_ZERO_TANGENT_DIST            # 切向畸变系数(p1,p2) 设置为零并保持为零\n",
    "#     cv2.CALIB_FIX_FOCAL_LENGTH             # 如果设置了CALIB_USE_INTRINSIC_GUESS,则在全局优化过程中不会更改焦距\n",
    "#     cv2.CALIB_FIX_K1 (K1-K6)               # 固定相应的径向畸变系数为0或给定的初始值\n",
    "#     cv2.CALIB_RATIONAL_MODEL               # 理想模型:启用系数k4,k5和k6。此时返回8个或更多的系数\n",
    "#     cv2.CALIB_THIN_PRISM_MODEL             # 薄棱镜模型:启用系数s1,s2,s3和s4。此时返回12个或更多的系数\n",
    "#     cv2.CALIB_FIX_S1_S2_S3_S4              # 固定薄棱镜畸变系数为0或给定的初始值\n",
    "#     cv2.CALIB_TILTED_MODEL                 # 倾斜模型:启用系数tauX和tauY。此时返回14个系数\n",
    "#     cv2.CALIB_FIX_TAUX_TAUY                # 固定倾斜传感器模型的系数为0或给定的初始值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Normal:           # 平面相机\n",
    "    def __init__(self):\n",
    "        self.data = CalibData()\n",
    "        self.inited = False\n",
    "        self.BOARD = np.array([ [(j * args.SQUARE_SIZE, i * args.SQUARE_SIZE, 0.)]\n",
    "                               for i in range(args.BORAD_HEIGHT) \n",
    "                               for j in range(args.BORAD_WIDTH) ],dtype=np.float32)\n",
    "        \n",
    "    def update(self, corners, frame_size):\n",
    "        board = [self.BOARD] * len(corners)\n",
    "        if not self.inited:\n",
    "            self._update_init(board, corners, frame_size)\n",
    "            self.inited = True\n",
    "        else:\n",
    "            self._update_refine(board, corners, frame_size)\n",
    "        self._calc_reproj_err(corners)\n",
    "        self._get_undistort_maps()\n",
    "        \n",
    "    def _update_init(self, board, corners, frame_size):\n",
    "        data = self.data\n",
    "        data.type = \"NORMAL\"\n",
    "        data.camera_mat = np.eye(3, 3)\n",
    "        data.dist_coeff = np.zeros((5, 1))     # 畸变向量的尺寸根据使用模型修改\n",
    "        data.ok, data.camera_mat, data.dist_coeff, data.rvecs, data.tvecs = cv2.calibrateCamera(\n",
    "            board, corners, frame_size, data.camera_mat, data.dist_coeff, \n",
    "            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_COUNT, 30, 1e-6))\n",
    "        data.ok = data.ok and cv2.checkRange(data.camera_mat) and cv2.checkRange(data.dist_coeff)\n",
    "        \n",
    "    def _update_refine(self, board, corners, frame_size):\n",
    "        data = self.data\n",
    "        data.ok, data.camera_mat, data.dist_coeff, data.rvecs, data.tvecs = cv2.calibrateCamera(\n",
    "            board, corners, frame_size, data.camera_mat, data.dist_coeff,  \n",
    "            flags = cv2.CALIB_USE_INTRINSIC_GUESS,\n",
    "            criteria=(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_COUNT, 10, 1e-6))\n",
    "        data.ok = data.ok and cv2.checkRange(data.camera_mat) and cv2.checkRange(data.dist_coeff)\n",
    "        \n",
    "    def _calc_reproj_err(self, corners):\n",
    "        if not self.inited: return\n",
    "        data = self.data\n",
    "        data.reproj_err = []\n",
    "        for i in range(len(corners)):\n",
    "            corners_reproj, _ = cv2.projectPoints(self.BOARD, data.rvecs[i], data.tvecs[i], data.camera_mat, data.dist_coeff)\n",
    "            err = cv2.norm(corners_reproj, corners[i], cv2.NORM_L2) / len(corners_reproj)\n",
    "            data.reproj_err.append(err)\n",
    "            \n",
    "    def _get_camera_mat_dst(self, camera_mat):\n",
    "        camera_mat_dst = camera_mat.copy()\n",
    "        camera_mat_dst[0][0] *= args.FOCAL_SCALE\n",
    "        camera_mat_dst[1][1] *= args.FOCAL_SCALE\n",
    "        camera_mat_dst[0][2] = args.FRAME_WIDTH / 2 * args.SIZE_SCALE\n",
    "        camera_mat_dst[1][2] = args.FRAME_HEIGHT / 2 * args.SIZE_SCALE\n",
    "        return camera_mat_dst\n",
    "    \n",
    "    def _get_undistort_maps(self):\n",
    "        data = self.data\n",
    "        camera_mat_dst = self._get_camera_mat_dst(data.camera_mat)\n",
    "        data.map1, data.map2 = cv2.initUndistortRectifyMap(\n",
    "                                 data.camera_mat, data.dist_coeff, np.eye(3, 3), camera_mat_dst, \n",
    "                                 (int(args.FRAME_WIDTH * args.SIZE_SCALE), int(args.FRAME_HEIGHT * args.SIZE_SCALE)), cv2.CV_16SC2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.findChessboardCorners ( image,         # 棋盘图像\n",
    "#                             patternSize,   # 棋盘格行和列的【内角点】数量\n",
    "#                             corners,       # 输出数组\n",
    "#                             flags          # 操作标志\n",
    "#                             )\n",
    "# flags:\n",
    "#     CV_CALIB_CB_ADAPTIVE_THRESH            # 使用自适应阈值处理将图像转换为黑白图像\n",
    "#     CV_CALIB_CB_NORMALIZE_IMAGE            # 对图像进行归一化。\n",
    "#     CV_CALIB_CB_FILTER_QUADS               # 过滤在轮廓检索阶段提取的假四边形。\n",
    "#     CALIB_CB_FAST_CHECK                    # 对查找棋盘角的图像进行快速检查"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.cornerSubPix (image,                        # 棋盘图像\n",
    "#                   corners,                      # 棋盘角点\n",
    "#                   winSize,                      # 搜索窗口边长的一半\n",
    "#                   zeroZone,                     # 搜索区域死区大小的一半, (-1,-1)代表无\n",
    "#                   criteria                      # 迭代停止标准\n",
    "#                  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cv2.fisheye.initUndistortRectifyMap (K,         # 相机内参矩阵\n",
    "#                                      D,         # 畸变向量\n",
    "#                                      R,         # 旋转矩阵\n",
    "#                                      P,         # 新的相机矩阵\n",
    "#                                      size,      # 输出图像大小\n",
    "#                                      m1type,    # 映射矩阵类型\n",
    "#                                      map1,      # 输出映射矩阵1\n",
    "#                                      map2       # 输出映射矩阵2\n",
    "#                                     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InCalibrator:                  # 内参标定器\n",
    "    def __init__(self, camera):\n",
    "        if camera == 'fisheye':\n",
    "            self.camera = Fisheye()  # 鱼眼相机类\n",
    "        elif camera == 'normal':\n",
    "            self.camera = Normal()   # 普通相机类\n",
    "        else:\n",
    "            raise Exception(\"camera should be fisheye/normal\")\n",
    "        self.corners = []\n",
    "    \n",
    "    # 获取args参数,供外部调用修改参数\n",
    "    @staticmethod\n",
    "    def get_args():\n",
    "        return args\n",
    "    \n",
    "    # 获取棋盘格角点坐标\n",
    "    def get_corners(self, img):\n",
    "        ok, corners = cv2.findChessboardCorners(img, (args.BORAD_WIDTH, args.BORAD_HEIGHT),\n",
    "                      flags = cv2.CALIB_CB_ADAPTIVE_THRESH|cv2.CALIB_CB_NORMALIZE_IMAGE|cv2.CALIB_CB_FAST_CHECK)\n",
    "        if ok: \n",
    "            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)\n",
    "            # 角点坐标亚像素优化\n",
    "            corners = cv2.cornerSubPix(gray, corners, (args.SUBPIX_REGION, args.SUBPIX_REGION), (-1, -1),\n",
    "                                       (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.01))\n",
    "        return ok, corners\n",
    "    \n",
    "    # 在图上绘制棋盘格角点\n",
    "    def draw_corners(self, img):\n",
    "        ok, corners = self.get_corners(img)\n",
    "        cv2.drawChessboardCorners(img, (args.BORAD_WIDTH, args.BORAD_HEIGHT), corners, ok)\n",
    "        return img\n",
    "    \n",
    "    # 图像去畸变\n",
    "    def undistort(self, img):\n",
    "        data = self.camera.data\n",
    "        return cv2.remap(img, data.map1, data.map2, cv2.INTER_LINEAR)\n",
    "    \n",
    "    # 使用现有角点坐标标定\n",
    "    def calibrate(self, img):\n",
    "        if len(self.corners) >= args.CALIB_NUMBER:\n",
    "            self.camera.update(self.corners, img.shape[1::-1])  # 更新标定数据\n",
    "        return self.camera.data\n",
    "    \n",
    "    def __call__(self, raw_frame):\n",
    "        ok, corners = self.get_corners(raw_frame)\n",
    "        result = self.camera.data\n",
    "        if ok:\n",
    "            self.corners.append(corners)          # 加入新的角点坐标\n",
    "            result = self.calibrate(raw_frame)    # 得到标定结果\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 居中裁剪\n",
    "def centerCrop(img,width,height):\n",
    "    if img.shape[1] < width or img.shape[0] < height:\n",
    "        raise Exception(\"CROP size should be smaller than original size\")\n",
    "    img = img[round((img.shape[0]-height)/2) : round((img.shape[0]-height)/2)+height,\n",
    "              round((img.shape[1]-width)/2) : round((img.shape[1]-width)/2)+width ]  \n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 筛选图片文件\n",
    "def get_images(PATH, NAME):\n",
    "    filePath = [os.path.join(PATH, x) for x in os.listdir(PATH) \n",
    "                if any(x.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])\n",
    "               ]                                                            # 得到给定路径下所有图片文件\n",
    "    filenames = [filename for filename in filePath if NAME in filename]     # 再筛选出包含给定名字的图片\n",
    "    if len(filenames) == 0:\n",
    "        raise Exception(\"from {} read images failed\".format(PATH))\n",
    "    return filenames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CalibMode():\n",
    "    def __init__(self, calibrator, input_type, mode):\n",
    "        self.calibrator = calibrator\n",
    "        self.input_type = input_type\n",
    "        self.mode = mode\n",
    "    \n",
    "    # 图片预处理\n",
    "    def imgPreprocess(self, img):\n",
    "        if args.CROP_FLAG:          # 裁剪图片尺寸\n",
    "            img = centerCrop(img, args.FRAME_WIDTH, args.FRAME_HEIGHT)\n",
    "        elif args.RESIZE_FLAG:      # 缩放图片尺寸\n",
    "            img = cv2.resize(img, (args.FRAME_WIDTH, args.FRAME_HEIGHT))\n",
    "        return img\n",
    "    \n",
    "    # 设置相机\n",
    "    def setCamera(self, cap):\n",
    "        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter.fourcc('M','J','P','G'))    # 设置编码格式为MJPG\n",
    "        cap.set(cv2.CAP_PROP_FRAME_WIDTH, args.FRAME_WIDTH)                      # 设置相机分辨率\n",
    "        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, args.FRAME_HEIGHT)\n",
    "        cap.set(cv2.CAP_PROP_FPS, args.CAMERA_FPS)                               # 设置相机帧率\n",
    "        return cap\n",
    "    \n",
    "    # 运行标定程序\n",
    "    def runCalib(self, raw_frame, display_raw=True, display_undist=True):\n",
    "        calibrator = self.calibrator\n",
    "        raw_frame = self.imgPreprocess(raw_frame)\n",
    "        result = calibrator(raw_frame)\n",
    "        raw_frame = calibrator.draw_corners(raw_frame)\n",
    "        if display_raw:\n",
    "            cv2.namedWindow(\"raw_frame\", flags = cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)\n",
    "            cv2.imshow(\"raw_frame\", raw_frame)\n",
    "        if len(calibrator.corners) > args.CALIB_NUMBER and display_undist: \n",
    "            undist_frame = calibrator.undistort(raw_frame)\n",
    "            cv2.namedWindow(\"undist_frame\", flags = cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)\n",
    "            cv2.imshow(\"undist_frame\", undist_frame)   \n",
    "        cv2.waitKey(1)\n",
    "        return result\n",
    "    \n",
    "    # 图片输入自动标定\n",
    "    def imageAutoMode(self):\n",
    "        calibrator = self.calibrator\n",
    "        filenames = get_images(args.INPUT_PATH, args.IMAGE_FILE)\n",
    "        for filename in filenames:\n",
    "            print(filename)\n",
    "            raw_frame = cv2.imread(filename)\n",
    "            result = self.runCalib(raw_frame)\n",
    "            key = cv2.waitKey(1)\n",
    "            if key == 27: break\n",
    "        cv2.destroyAllWindows() \n",
    "        return result\n",
    "    \n",
    "    # 图片输入手动挑选 按空格键确认 其他键丢弃该图片\n",
    "    def imageManualMode(self):\n",
    "        filenames = get_images(args.INPUT_PATH, args.IMAGE_FILE)\n",
    "        for filename in filenames:\n",
    "            print(filename)\n",
    "            raw_frame = cv2.imread(filename)\n",
    "            raw_frame = self.imgPreprocess(raw_frame)\n",
    "            img = raw_frame.copy()\n",
    "            img = self.calibrator.draw_corners(img)\n",
    "            display = \"raw_frame: press SPACE to SELECT, other key to SKIP, press ESC to QUIT\"\n",
    "            cv2.namedWindow(display, flags = cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)\n",
    "            cv2.imshow(display, img)\n",
    "            key = cv2.waitKey(0)\n",
    "            if key == 32:\n",
    "                result = self.runCalib(raw_frame, display_raw = False)\n",
    "            if key == 27: break\n",
    "        cv2.destroyAllWindows() \n",
    "        return result\n",
    "    \n",
    "    # 视频输入自动标定\n",
    "    def videoAutoMode(self):\n",
    "        cap = cv2.VideoCapture(args.INPUT_PATH + args.VIDEO_FILE)\n",
    "        if not cap.isOpened(): \n",
    "            raise Exception(\"from {} read video failed\".format(args.INPUT_PATH + args.VIDEO_FILE))\n",
    "        frame_id = 0\n",
    "        while True:\n",
    "            key = cv2.waitKey(1)\n",
    "            ok, raw_frame = cap.read()\n",
    "            raw_frame = self.imgPreprocess(raw_frame)\n",
    "            if frame_id % args.FRAME_DELAY == 0:\n",
    "                if args.STORE_FLAG:  # 存储该帧图像\n",
    "                    cv2.imwrite(args.STORE_PATH + 'img_raw{}.jpg'.format(len(self.calibrator.corners)), raw_frame)\n",
    "                result = self.runCalib(raw_frame) \n",
    "                print(len(self.calibrator.corners))\n",
    "            frame_id += 1 \n",
    "            key = cv2.waitKey(1)\n",
    "            if key == 27: break\n",
    "        cap.release()\n",
    "        cv2.destroyAllWindows() \n",
    "        return result\n",
    "    \n",
    "    # 视频输入手动挑选 按空格键采集图片\n",
    "    def videoManualMode(self):\n",
    "        cap = cv2.VideoCapture(args.INPUT_PATH + args.VIDEO_FILE)\n",
    "        if not cap.isOpened(): \n",
    "            raise Exception(\"from {} read video failed\".format(args.INPUT_PATH + args.VIDEO_FILE))\n",
    "        while True:\n",
    "            key = cv2.waitKey(1)\n",
    "            ok, raw_frame = cap.read()\n",
    "            raw_frame = self.imgPreprocess(raw_frame)\n",
    "            display = \"raw_frame: press SPACE to capture image\"\n",
    "            cv2.namedWindow(display, flags = cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)\n",
    "            cv2.imshow(display, raw_frame)\n",
    "            if key == 32:\n",
    "                if args.STORE_FLAG:  # 存储该帧图像\n",
    "                    cv2.imwrite(args.STORE_PATH + 'img_raw{}.jpg'.format(len(self.calibrator.corners)), raw_frame)\n",
    "                result = self.runCalib(raw_frame) \n",
    "                print(len(self.calibrator.corners))\n",
    "            if key == 27: break\n",
    "        cap.release()\n",
    "        cv2.destroyAllWindows() \n",
    "        return result\n",
    "    \n",
    "    # 相机输入在线标定\n",
    "    def cameraAutoMode(self):\n",
    "        cap = cv2.VideoCapture(args.CAMERA_ID)\n",
    "        if not cap.isOpened(): \n",
    "            raise Exception(\"from {} read video failed\".format(args.CAMERA_ID))\n",
    "        cap = self.setCamera(cap)\n",
    "        frame_id = 0\n",
    "        start_flag = False\n",
    "        while True:\n",
    "            key = cv2.waitKey(1)\n",
    "            ok, raw_frame = cap.read()\n",
    "            raw_frame = self.imgPreprocess(raw_frame)\n",
    "            if key == 32: start_flag = True\n",
    "            if key == 27: break\n",
    "            if not start_flag:\n",
    "                cv2.putText(raw_frame, 'press SPACE to start!', (args.FRAME_WIDTH//4,args.FRAME_HEIGHT//2), \n",
    "                             cv2.FONT_HERSHEY_COMPLEX, 1.5, (0,0,255), 2)\n",
    "                cv2.imshow(\"raw_frame\", raw_frame)\n",
    "                continue\n",
    "            if frame_id % args.FRAME_DELAY == 0:\n",
    "                if args.STORE_FLAG:  # 存储该帧图像\n",
    "                    cv2.imwrite(args.STORE_PATH + 'img_raw{}.jpg'.format(len(self.calibrator.corners)), raw_frame)\n",
    "                result = self.runCalib(raw_frame) \n",
    "                print(len(self.calibrator.corners))\n",
    "            frame_id += 1 \n",
    "        cap.release()\n",
    "        cv2.destroyAllWindows() \n",
    "        return result\n",
    "    \n",
    "    # 相机输入手动挑选 按空格键采集图片\n",
    "    def cameraManualMode(self):\n",
    "        cap = cv2.VideoCapture(args.CAMERA_ID)\n",
    "        if not cap.isOpened(): \n",
    "            raise Exception(\"from {} read video failed\".format(args.CAMERA_ID))\n",
    "        cap = self.setCamera(cap)\n",
    "        frame_id = 0\n",
    "        while True:\n",
    "            key = cv2.waitKey(1)\n",
    "            ok, raw_frame = cap.read()\n",
    "            raw_frame = self.imgPreprocess(raw_frame)\n",
    "            display = \"raw_frame: press SPACE to capture image\"\n",
    "            cv2.namedWindow(display, flags = cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)\n",
    "            cv2.imshow(display, raw_frame)\n",
    "            if key == 32:\n",
    "                if args.STORE_FLAG:  # 存储该帧图像\n",
    "                    cv2.imwrite(args.STORE_PATH + 'img_raw{}.jpg'.format(len(self.calibrator.corners)), raw_frame)\n",
    "                result = self.runCalib(raw_frame) \n",
    "                print(len(self.calibrator.corners))\n",
    "            if key == 27: break\n",
    "        cap.release()\n",
    "        cv2.destroyAllWindows() \n",
    "        return result\n",
    "\n",
    "    def __call__(self):\n",
    "        input_type = self.input_type\n",
    "        mode = self.mode\n",
    "        if input_type == 'image' and mode == 'auto':\n",
    "            result = self.imageAutoMode()\n",
    "        if input_type == 'image' and mode == 'manual':\n",
    "            result = self.imageManualMode()\n",
    "        if input_type == 'video' and mode == 'auto':\n",
    "            result = self.videoAutoMode()\n",
    "        if input_type == 'video' and mode == 'manual':\n",
    "            result = self.videoManualMode()\n",
    "        if input_type == 'camera' and mode == 'auto':\n",
    "            result = self.cameraAutoMode()\n",
    "        if input_type == 'camera' and mode == 'manual':\n",
    "            result = self.cameraManualMode()\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    calibrator = InCalibrator(args.CAMERA_TYPE)                            # 初始化内参标定器\n",
    "    calib = CalibMode(calibrator, args.INPUT_TYPE, args.SELECT_MODE)       # 选择标定模式\n",
    "    result = calib()                                                       # 开始标定\n",
    "                  \n",
    "    if len(calibrator.corners) == 0:                      # 标定失败 未找到棋盘或参数设置错误\n",
    "        raise Exception(\"Calibration failed. Chessboard not found, check the parameters\")  \n",
    "    if len(calibrator.corners) < args.CALIB_NUMBER:       # 标定样本小于初始化标定所需的图片数\n",
    "        raise Exception(\"Warning: Calibration images are not enough. At least {} valid images are needed.\".format(args.CALIB_NUMBER))            \n",
    "\n",
    "    print(\"Calibration Complete\")\n",
    "    print(\"Camera Matrix is : {}\".format(result.camera_mat.tolist()))                 # 相机内参\n",
    "    print(\"Distortion Coefficient is : {}\".format(result.dist_coeff.tolist()))        # 畸变向量\n",
    "    print(\"Reprojection Error is : {}\".format(np.mean(result.reproj_err)))            # 平均重投影误差\n",
    "    np.save('camera_{}_K.npy'.format(args.CAMERA_ID),result.camera_mat.tolist())\n",
    "    np.save('camera_{}_D.npy'.format(args.CAMERA_ID),result.dist_coeff.tolist())      # 输出并存储数据\n",
    "        \n",
    "if __name__ == '__main__':\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}