{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "775d1720-0fa3-40bf-a9e5-a8cc44744dca",
"metadata": {},
"outputs": [],
"source": [
"from IPython import display\n",
"import sys\n",
"sys.path.append(\"../\")\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import cv2\n",
"from PIL import Image\n",
"\n",
"from constants import classes"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "76ed828a-a15a-489c-bfd8-dcf24d3db703",
"metadata": {},
"outputs": [],
"source": [
"path_to_model = \"../mvit16-1.pt\"\n",
"path_to_input_video = \"f17a6060-6ced-4bd1-9886-8578cfbb864f.mp4\"\n",
"path_to_output_video = \"output_torch.mp4\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4a36baad-bd49-4126-b98a-4f20b7919caf",
"metadata": {},
"outputs": [],
"source": [
"# Load the TorchScript model onto the CPU: the inference loop feeds it tensors\n",
"# created with torch.from_numpy, which live on the CPU, so a checkpoint saved\n",
"# from a GPU must be remapped to avoid a device mismatch.\n",
"model = torch.jit.load(path_to_model, map_location=\"cpu\")\n",
"model.eval()\n",
"\n",
"window_size = 16  # number of frames stacked per inference clip (from model name)\n",
"threshold = 0.5  # minimum top score required to accept a prediction\n",
"frame_interval = 1  # only every N-th frame of the video is sampled\n",
"\n",
"# Per-channel constants used to normalise RGB frames before inference.\n",
"mean = [123.675, 116.28, 103.53]\n",
"std = [58.395, 57.12, 57.375]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d72fb23e-3946-4b76-ac62-cfcc325ff657",
"metadata": {},
"outputs": [],
"source": [
"def resize(im, new_shape=(224, 224)):\n",
"    \"\"\"\n",
"    Resize and pad image while preserving aspect ratio.\n",
"\n",
"    The image is scaled so that it fits inside ``new_shape`` and is then\n",
"    centred on a grey (114, 114, 114) canvas of exactly ``new_shape``.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    im : np.ndarray\n",
"        Image to be resized; the first two axes are (height, width).\n",
"    new_shape : Tuple[int, int] or int\n",
"        Target (height, width). A single int requests a square target.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray\n",
"        Resized and padded image whose height and width equal ``new_shape``.\n",
"    \"\"\"\n",
"    shape = im.shape[:2]  # current shape [height, width]\n",
"    if isinstance(new_shape, int):\n",
"        new_shape = (new_shape, new_shape)\n",
"\n",
"    # Scale ratio (new / old); the smaller ratio keeps the whole image visible\n",
"    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])\n",
"\n",
"    # Scaled size as (width, height), the order cv2.resize expects. Clamp each\n",
"    # side to one pixel so extreme aspect ratios never request an empty image.\n",
"    new_unpad = max(1, int(round(shape[1] * r))), max(1, int(round(shape[0] * r)))\n",
"    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding\n",
"\n",
"    # Divide the padding between the two opposite sides\n",
"    dw /= 2\n",
"    dh /= 2\n",
"\n",
"    if shape[::-1] != new_unpad:  # resize\n",
"        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)\n",
"    # The -/+0.1 offsets send the spare pixel of an odd padding to bottom/right\n",
"    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))\n",
"    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))\n",
"    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114))  # add border\n",
"    return im"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "184ed911-6b9b-4250-a30b-c347e3be2ed1",
"metadata": {},
"outputs": [],
"source": [
"cap = cv2.VideoCapture(path_to_input_video)\n",
"# The first frame provides the output geometry; fail early with a clear\n",
"# message if the video cannot be opened or decoded.\n",
"ok, frame = cap.read()\n",
"if not ok or frame is None:\n",
"    cap.release()\n",
"    raise IOError(f\"Could not read any frame from {path_to_input_video!r}\")\n",
"shape = frame.shape\n",
"\n",
"# Reuse the source frame rate so playback speed is preserved; fall back to\n",
"# 30 fps when the backend reports no usable value (0, negative or NaN).\n",
"fps = cap.get(cv2.CAP_PROP_FPS)\n",
"if not fps > 0:\n",
"    fps = 30\n",
"fourcc = cv2.VideoWriter_fourcc(*'H264')\n",
"# The output is 50 px taller than the input to hold the caption strip.\n",
"writer = cv2.VideoWriter(path_to_output_video, fourcc, fps, (shape[1], shape[0] + 50))\n",
"\n",
"tensors_list = []\n",
"prediction_list = []\n",
"prediction_list.append(\"---\")\n",
"\n",
"frame_counter = 0\n",
"try:\n",
"    # The frame read above for sizing is processed first, so no frame is lost.\n",
"    while ok and frame is not None:\n",
"        frame_counter += 1\n",
"        if frame_counter == frame_interval:\n",
"            # Preprocess: BGR -> RGB, letterbox to 224x224, normalise, HWC -> CHW\n",
"            image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
"            image = resize(image, (224, 224))\n",
"            image = (image - mean) / std\n",
"            image = np.transpose(image, [2, 0, 1])\n",
"            tensors_list.append(image)\n",
"            if len(tensors_list) == window_size:\n",
"                # Stack frames along a new axis 1 -> (C, T, H, W), then add two\n",
"                # leading singleton axes (presumably batch and clip -- confirm\n",
"                # against the exported model's expected input layout).\n",
"                input_tensor = np.stack(tensors_list[: window_size], axis=1)[None][None]\n",
"                input_tensor = input_tensor.astype(np.float32)\n",
"                input_tensor = torch.from_numpy(input_tensor)\n",
"                with torch.no_grad():\n",
"                    outputs = model(input_tensor)[0]\n",
"                gloss = str(classes[outputs.argmax().item()])\n",
"                if outputs.max() > threshold:\n",
"                    # Skip the \"---\" class and immediate repeats of the last\n",
"                    # accepted gloss; check emptiness before indexing [-1].\n",
"                    if prediction_list and gloss != prediction_list[-1]:\n",
"                        if gloss != \"---\":\n",
"                            prediction_list.append(gloss)\n",
"                # Start collecting a fresh, non-overlapping window\n",
"                tensors_list.clear()\n",
"            frame_counter = 0\n",
"\n",
"        # Render the accumulated predictions on a black strip under the frame\n",
"        text = \" \".join(prediction_list)\n",
"        text_div = np.zeros((50, frame.shape[1], 3), dtype=np.uint8)\n",
"        cv2.putText(text_div, text, (10, 30), cv2.FONT_HERSHEY_COMPLEX, 0.7, (255, 255, 255), 2)\n",
"\n",
"        frame = np.concatenate((frame, text_div), axis=0)\n",
"        writer.write(frame)\n",
"\n",
"        ok, frame = cap.read()\n",
"finally:\n",
"    # Release both handles even if inference fails part-way through\n",
"    writer.release()\n",
"    cap.release()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3c512a02-1d2b-4603-b3cd-9801216c3bdf",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display, HTML"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "53a41c5c-dcff-439b-a17a-07b9530525f8",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Embed the annotated video written to path_to_output_video in the notebook.\n",
"# The path is resolved by the browser relative to the notebook's location.\n",
"video_tag = f'<video controls src=\"{path_to_output_video}\"></video>'\n",
"display(HTML(video_tag))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "226bd051-c0f4-48eb-9555-5c24526cccf5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}