{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "775d1720-0fa3-40bf-a9e5-a8cc44744dca", "metadata": {}, "outputs": [], "source": [ "from IPython import display\n", "import sys\n", "sys.path.append(\"../\")\n", "\n", "import torch\n", "import numpy as np\n", "import cv2\n", "from PIL import Image\n", "\n", "from constants import classes" ] }, { "cell_type": "code", "execution_count": 2, "id": "76ed828a-a15a-489c-bfd8-dcf24d3db703", "metadata": {}, "outputs": [], "source": [ "path_to_model = \"../mvit16-1.pt\"\n", "path_to_input_video = \"f17a6060-6ced-4bd1-9886-8578cfbb864f.mp4\"\n", "path_to_output_video = \"output_torch.mp4\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "4a36baad-bd49-4126-b98a-4f20b7919caf", "metadata": {}, "outputs": [], "source": [ "model = torch.jit.load(path_to_model)\n", "model.eval()\n", "window_size = 16 # from model name\n", "threshold = 0.5\n", "frame_interval = 1\n", "mean = [123.675, 116.28, 103.53]\n", "std = [58.395, 57.12, 57.375]" ] }, { "cell_type": "code", "execution_count": 4, "id": "d72fb23e-3946-4b76-ac62-cfcc325ff657", "metadata": {}, "outputs": [], "source": [ "def resize(im, new_shape=(224, 224)):\n", " \"\"\"\n", " Resize and pad image while preserving aspect ratio.\n", "\n", " Parameters\n", " ----------\n", " im : np.ndarray\n", " Image to be resized.\n", " new_shape : Tuple[int]\n", " Size of the new image.\n", "\n", " Returns\n", " -------\n", " np.ndarray\n", " Resized image.\n", " \"\"\"\n", " shape = im.shape[:2] # current shape [height, width]\n", " if isinstance(new_shape, int):\n", " new_shape = (new_shape, new_shape)\n", "\n", " # Scale ratio (new / old)\n", " r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])\n", "\n", " # Compute padding\n", " new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))\n", " dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding\n", "\n", " dw /= 2\n", " dh /= 2\n", "\n", " if shape[::-1] != new_unpad: # resize\n", " im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)\n", " top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))\n", " left, right = int(round(dw - 0.1)), int(round(dw + 0.1))\n", " im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114)) # add border\n", " return im" ] }, { "cell_type": "code", "execution_count": 5, "id": "184ed911-6b9b-4250-a30b-c347e3be2ed1", "metadata": {}, "outputs": [], "source": [ "cap = cv2.VideoCapture(path_to_input_video)\n", "_,frame = cap.read()\n", "shape = frame.shape\n", "fourcc = cv2.VideoWriter_fourcc(*'H264')\n", "writer = cv2.VideoWriter(path_to_output_video, fourcc, 30, (frame.shape[1], frame.shape[0]+50))\n", "\n", "tensors_list = []\n", "prediction_list = []\n", "prediction_list.append(\"---\")\n", "\n", "frame_counter = 0\n", "while True:\n", " _, frame = cap.read()\n", " if frame is None:\n", " break\n", " frame_counter += 1\n", " if frame_counter == frame_interval:\n", " image = cv2.cvtColor(frame.copy(), cv2.COLOR_BGR2RGB)\n", " image = resize(image, (224, 224))\n", " image = (image - mean) / std\n", " image = np.transpose(image, [2, 0, 1])\n", " tensors_list.append(image)\n", " if len(tensors_list) == window_size:\n", " input_tensor = np.stack(tensors_list[: window_size], axis=1)[None][None]\n", " input_tensor = input_tensor.astype(np.float32)\n", " input_tensor = torch.from_numpy(input_tensor)\n", " with torch.no_grad():\n", " outputs = model(input_tensor)[0]\n", " gloss = str(classes[outputs.argmax().item()])\n", " if 
{ "cell_type": "code", "execution_count": 5, "id": "184ed911-6b9b-4250-a30b-c347e3be2ed1", "metadata": {}, "outputs": [], "source": [ "cap = cv2.VideoCapture(path_to_input_video)\n", "# Read the frame size and FPS from the capture metadata instead of\n", "# consuming the first frame, so no frames are dropped from the output.\n", "width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n", "height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n", "fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 if FPS metadata is missing\n", "fourcc = cv2.VideoWriter_fourcc(*'H264')\n", "writer = cv2.VideoWriter(path_to_output_video, fourcc, fps, (width, height + 50))\n", "\n", "tensors_list = []\n", "prediction_list = [\"---\"]\n", "\n", "frame_counter = 0\n", "while True:\n", "    ret, frame = cap.read()\n", "    if not ret:\n", "        break\n", "    frame_counter += 1\n", "    if frame_counter == frame_interval:\n", "        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", "        image = resize(image, (224, 224))\n", "        image = (image - mean) / std  # normalize with ImageNet statistics\n", "        image = np.transpose(image, [2, 0, 1])  # HWC -> CHW\n", "        tensors_list.append(image)\n", "        if len(tensors_list) == window_size:\n", "            # Stack the window along the time axis: (1, 1, C, T, H, W)\n", "            input_tensor = np.stack(tensors_list[:window_size], axis=1)[None][None]\n", "            input_tensor = torch.from_numpy(input_tensor.astype(np.float32))\n", "            with torch.no_grad():\n", "                outputs = model(input_tensor)[0]\n", "            gloss = str(classes[outputs.argmax().item()])\n", "            if outputs.max() > threshold:\n", "                if len(prediction_list) and gloss != prediction_list[-1]:\n", "                    if gloss != \"---\":\n", "                        prediction_list.append(gloss)\n", "            tensors_list.clear()\n", "        # Reset here (not only when the window fills) so every\n", "        # frame_interval-th frame is added to the window.\n", "        frame_counter = 0\n", "\n", "    text = \" \".join(prediction_list)\n", "    text_div = np.zeros((50, frame.shape[1], 3), dtype=np.uint8)\n", "    cv2.putText(text_div, text, (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.7, (255, 255, 255), 2)\n", "\n", "    frame = np.concatenate((frame, text_div), axis=0)\n", "    writer.write(frame)\n", "writer.release()\n", "cap.release()" ] }, { "cell_type": "code", "execution_count": 6, "id": "3c512a02-1d2b-4603-b3cd-9801216c3bdf", "metadata": {}, "outputs": [], "source": [ "from IPython.display import display, HTML" ] }, { "cell_type": "code", "execution_count": 7, "id": "53a41c5c-dcff-439b-a17a-07b9530525f8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "