# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

from __future__ import absolute_import, division, print_function

import os
import hashlib
import zipfile
from six.moves import urllib

import numpy as np
import cv2


def visualize_depth(depth, mask=None, depth_min=None, depth_max=None, direct=False):
    """Visualize a depth map with a colormap.

    If direct is False, the input is converted to disparity (1 / depth) before
    colouring. Values are rescaled so that depth_min and depth_max map to 0 and 1,
    respectively; values outside that range are clipped and invalid pixels are
    rendered black.
    """
    if not direct:
        depth = 1.0 / (depth + 1e-6)

    # Pixels that are NaN or non-finite (or excluded by the optional mask) are invalid.
    invalid_mask = np.logical_or(np.isnan(depth), np.logical_not(np.isfinite(depth)))
    if mask is not None:
        invalid_mask |= np.logical_not(mask)

    if depth_min is None:
        depth_min = np.percentile(depth[np.logical_not(invalid_mask)], 5)
    if depth_max is None:
        depth_max = np.percentile(depth[np.logical_not(invalid_mask)], 95)

    depth[depth < depth_min] = depth_min
    depth[depth > depth_max] = depth_max
    depth[invalid_mask] = depth_max

    depth_scaled = (depth - depth_min) / (depth_max - depth_min)
    depth_scaled_uint8 = np.uint8(depth_scaled * 255)
    depth_color = cv2.applyColorMap(depth_scaled_uint8, cv2.COLORMAP_MAGMA)
    depth_color[invalid_mask, :] = 0

    return depth_color


def compute_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths
    """
    thresh = np.maximum((gt / pred), (pred / gt))
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    rmse = (gt - pred) ** 2
    rmse = np.sqrt(rmse.mean())

    rmse_log = (np.log(gt) - np.log(pred)) ** 2
    rmse_log = np.sqrt(rmse_log.mean())

    abs_rel = np.mean(np.abs(gt - pred) / gt)
    sq_rel = np.mean(((gt - pred) ** 2) / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3


def readlines(filename):
    """Read all the lines in a text file and return as a list
    """
    with open(filename, 'r') as f:
        lines = f.read().splitlines()
    return lines


def normalize_image(x):
    """Rescale image pixels to span the range [0, 1]
    """
    ma = float(x.max().cpu().data)
    mi = float(x.min().cpu().data)
    d = ma - mi if ma != mi else 1e5
    return (x - mi) / d


def sec_to_hm(t):
    """Convert time in seconds to time in hours, minutes and seconds
    e.g. 10239 -> (2, 50, 39)
    """
    t = int(t)
    s = t % 60
    t //= 60
    m = t % 60
    t //= 60
    return t, m, s


def sec_to_hm_str(t):
    """Convert time in seconds to a nice string
    e.g. 10239 -> '02h50m39s'
    """
    h, m, s = sec_to_hm(t)
    return "{:02d}h{:02d}m{:02d}s".format(h, m, s)


def download_model_if_doesnt_exist(model_name):
    """If the pretrained KITTI model doesn't exist, download and unzip it
    """
    # values are tuples of (download URL, md5 checksum)
    download_paths = {
        "mono_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_640x192.zip",
             "a964b8356e08a02d009609d9e3928f7c"),
        "stereo_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_640x192.zip",
             "3dfb76bcff0786e4ec07ac00f658dd07"),
        "mono+stereo_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_640x192.zip",
             "c024d69012485ed05d7eaa9617a96b81"),
        "mono_no_pt_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_no_pt_640x192.zip",
             "9c2f071e35027c895a4728358ffc913a"),
        "stereo_no_pt_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_no_pt_640x192.zip",
             "41ec2de112905f85541ac33a854742d1"),
        "mono+stereo_no_pt_640x192":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_no_pt_640x192.zip",
             "46c3b824f541d143a45c37df65fbab0a"),
        "mono_1024x320":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono_1024x320.zip",
             "0ab0766efdfeea89a0d9ea8ba90e1e63"),
        "stereo_1024x320":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/stereo_1024x320.zip",
             "afc2f2126d70cf3fdf26b550898b501a"),
        "mono+stereo_1024x320":
            ("https://storage.googleapis.com/niantic-lon-static/research/monodepth2/mono%2Bstereo_1024x320.zip",
             "cdc5fc9b23513c07d5b19235d9ef08f7"),
    }

    if not os.path.exists("models"):
        os.makedirs("models")

    model_path = os.path.join("models", model_name)

    def check_file_matches_md5(checksum, fpath):
        if not os.path.exists(fpath):
            return False
        with open(fpath, 'rb') as f:
            current_md5checksum = hashlib.md5(f.read()).hexdigest()
        return current_md5checksum == checksum

    # see if we have the model already downloaded...
    if not os.path.exists(os.path.join(model_path, "encoder.pth")):

        model_url, required_md5checksum = download_paths[model_name]

        if not check_file_matches_md5(required_md5checksum, model_path + ".zip"):
            print("-> Downloading pretrained model to {}".format(model_path + ".zip"))
            urllib.request.urlretrieve(model_url, model_path + ".zip")

        if not check_file_matches_md5(required_md5checksum, model_path + ".zip"):
            print("   Failed to download a file which matches the checksum - quitting")
            quit()

        print("   Unzipping model...")
        with zipfile.ZipFile(model_path + ".zip", 'r') as f:
            f.extractall(model_path)

        print("   Model unzipped to {}".format(model_path))
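

# Minimal usage sketch, assuming synthetic inputs: the random arrays below are
# placeholders rather than real model outputs, and exist only to show the
# expected shapes and return values of compute_errors, visualize_depth and
# sec_to_hm_str.
if __name__ == "__main__":
    rng = np.random.RandomState(0)

    # Synthetic ground-truth depths in metres and predictions within +/-10%.
    gt = rng.uniform(1.0, 80.0, size=1000)
    pred = gt * rng.uniform(0.9, 1.1, size=1000)
    abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 = compute_errors(gt, pred)
    print("abs_rel {:.4f} | sq_rel {:.4f} | rmse {:.4f} | a1 {:.4f}".format(
        abs_rel, sq_rel, rmse, a1))

    # Colour-mapped visualization of a synthetic depth image; the result is a
    # BGR uint8 array of shape (H, W, 3).
    depth_image = rng.uniform(1.0, 80.0, size=(192, 640)).astype(np.float32)
    vis = visualize_depth(depth_image, direct=True)
    print("visualization shape: {}".format(vis.shape))

    # 10239 seconds formats as '02h50m39s'.
    print(sec_to_hm_str(10239))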