import math

import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import pyloudnorm as pyln

# CIE Lab piecewise-function constants used by both conversion directions.
_CIE_EPSILON = 0.008856
_CIE_KAPPA = 903.3

# D65 reference white (X, Y, Z) used to normalise XYZ before/after Lab.
_D65_WHITE = (0.95047, 1.0, 1.08883)

# Linear sRGB -> XYZ matrix and its inverse.
_XYZ_FROM_RGB = (
    (0.4124564, 0.3575761, 0.1804375),
    (0.2126729, 0.7151522, 0.0721750),
    (0.0193339, 0.1191920, 0.9503041),
)
_RGB_FROM_XYZ = (
    (3.2404542, -1.5371385, -0.4985314),
    (-0.9692660, 1.8760108, 0.0415560),
    (0.0556434, -0.2040259, 1.0572252),
)


def rgb_to_lab_torch(rgb: torch.Tensor) -> torch.Tensor:
    """Convert sRGB values to CIE Lab using the D65 white point.

    Args:
        rgb: Tensor whose last dimension holds the 3 RGB channels, with
            values expected in [0, 1]. Any number of leading dimensions
            is accepted.

    Returns:
        Tensor of the same shape as ``rgb`` whose last dimension holds
        (L, a, b).
    """
    # Undo the sRGB gamma encoding to obtain linear RGB.
    linear_rgb = torch.where(
        rgb > 0.04045,
        ((rgb + 0.055) / 1.055) ** 2.4,
        rgb / 12.92,
    )

    # Linear RGB -> XYZ. Flatten to (N, 3) for one matmul, then restore.
    xyz_from_rgb = torch.tensor(
        _XYZ_FROM_RGB, dtype=rgb.dtype, device=rgb.device
    )
    shape = linear_rgb.shape
    xyz = (linear_rgb.reshape(-1, 3) @ xyz_from_rgb.T).reshape(shape)

    # Normalise by the reference white. A (3,) tensor broadcasts over the
    # last axis for any input rank; indexing it up to 5-D (as was done
    # previously) silently added leading dims to lower-rank inputs.
    white = torch.tensor(_D65_WHITE, dtype=rgb.dtype, device=rgb.device)
    xyz_normalized = xyz / white

    # Clamp to [1e-8, 1] so the cube root never sees zero or negative input.
    xyz_normalized = torch.clamp(xyz_normalized, 1e-8, 1.0)
    f_xyz = torch.where(
        xyz_normalized > _CIE_EPSILON,
        xyz_normalized ** (1 / 3),
        (_CIE_KAPPA * xyz_normalized + 16) / 116,
    )

    L = 116 * f_xyz[..., 1] - 16  # lightness, from Y
    a = 500 * (f_xyz[..., 0] - f_xyz[..., 1])  # X - Y axis
    b = 200 * (f_xyz[..., 1] - f_xyz[..., 2])  # Y - Z axis

    return torch.stack([L, a, b], dim=-1)


def lab_to_rgb_torch(lab: torch.Tensor) -> torch.Tensor:
    """Convert CIE Lab values (D65 white point) back to sRGB.

    Args:
        lab: Tensor whose last dimension holds (L, a, b). Any number of
            leading dimensions is accepted.

    Returns:
        Tensor of the same shape as ``lab`` whose last dimension holds
        RGB, clamped to [0, 1].
    """
    L = lab[..., 0]
    a = lab[..., 1]
    b = lab[..., 2]

    # Lab -> XYZ (relative to the reference white).
    f_y = (L + 16) / 116
    f_x = (a / 500) + f_y
    f_z = f_y - (b / 200)

    f_x_cubed = f_x ** 3
    f_z_cubed = f_z ** 3
    x = torch.where(
        f_x_cubed > _CIE_EPSILON, f_x_cubed, (116 * f_x - 16) / _CIE_KAPPA
    )
    y = torch.where(
        L > _CIE_KAPPA * _CIE_EPSILON,
        ((L + 16) / 116) ** 3,
        L / _CIE_KAPPA,
    )
    z = torch.where(
        f_z_cubed > _CIE_EPSILON, f_z_cubed, (116 * f_z - 16) / _CIE_KAPPA
    )

    # Scale by the reference white; (3,) broadcasts over the last axis for
    # any input rank, so the output keeps the input's shape.
    white = torch.tensor(_D65_WHITE, dtype=lab.dtype, device=lab.device)
    xyz = torch.stack([x, y, z], dim=-1) * white

    # XYZ -> linear RGB. Flatten to (N, 3) for one matmul, then restore.
    rgb_from_xyz = torch.tensor(
        _RGB_FROM_XYZ, dtype=lab.dtype, device=lab.device
    )
    shape = xyz.shape
    linear_rgb = (xyz.reshape(-1, 3) @ rgb_from_xyz.T).reshape(shape)

    # Apply the sRGB gamma encoding.
    rgb = torch.where(
        linear_rgb > 0.0031308,
        1.055 * (linear_rgb ** (1 / 2.4)) - 0.055,
        12.92 * linear_rgb,
    )

    # Out-of-gamut values are clipped into [0, 1].
    return torch.clamp(rgb, 0.0, 1.0)


def match_and_blend_colors_torch(
    source_chunk: torch.Tensor,
    reference_image: torch.Tensor,
    strength: float
) -> torch.Tensor:
    """Match a video chunk's colours to a reference image and blend.

    Every frame's per-channel Lab mean and standard deviation are moved
    to those of the reference image, and the corrected frames are then
    linearly blended with the originals. All frames and batch items are
    processed together with tensor ops on ``source_chunk``'s device.

    Args:
        source_chunk: Video chunk of shape (B, C, T, H, W) with C == 3,
            values expected in [-1, 1].
        reference_image: Reference image of shape (B, C, 1, H, W) with
            the same B, C, H, W as ``source_chunk``, values expected in
            [-1, 1].
        strength: Blend weight of the corrected frames. 1.0 is full
            correction; any value <= 0.0 returns an unmodified copy.

    Returns:
        Contiguous tensor of shape (B, C, T, H, W) with the device and
        dtype of ``source_chunk``.

    Raises:
        ValueError: If ``strength`` is greater than 1.0 (or NaN).
        AssertionError: If ``reference_image`` does not have shape
            (B, C, 1, H, W) or C is not 3.
    """
    # Non-positive strength means "no correction": return a copy as-is.
    if strength <= 0.0:
        return source_chunk.clone()

    # Negatives already returned above, so this rejects > 1.0 and NaN.
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"Strength必须在0.0-1.0之间,当前值:{strength}")

    # Shapes must agree on B, C, H, W; the reference has a single frame.
    B, C, T, H, W = source_chunk.shape
    assert reference_image.shape == (B, C, 1, H, W), \
        f"参考图像形状需为(B, C, 1, H, W),当前为{reference_image.shape}"
    assert C == 3, f"仅支持3通道RGB图像,当前通道数:{C}"

    # Work on the source's device and dtype throughout.
    device = source_chunk.device
    dtype = source_chunk.dtype
    reference_image = reference_image.to(device=device, dtype=dtype)

    # 1. Map [-1, 1] to [0, 1].
    source_01 = (source_chunk + 1.0) / 2.0
    ref_01 = (reference_image + 1.0) / 2.0

    # 2. Move channels last for the colour-space helpers:
    #    (B, C, T, H, W) -> (B, T, H, W, C); reference -> (B, 1, H, W, C).
    source_permuted = source_01.permute(0, 2, 3, 4, 1)
    ref_permuted = ref_01.permute(0, 2, 3, 4, 1)

    # 3. RGB -> Lab for every frame at once.
    source_lab = rgb_to_lab_torch(source_permuted)
    ref_lab = rgb_to_lab_torch(ref_permuted)  # (B, 1, H, W, 3)

    # 4. Per-channel statistics over the spatial dims (H, W).
    ref_mean = ref_lab.mean(dim=[2, 3], keepdim=True)  # (B, 1, 1, 1, 3)
    ref_std = ref_lab.std(dim=[2, 3], keepdim=True, unbiased=False)
    source_mean = source_lab.mean(dim=[2, 3], keepdim=True)  # (B, T, 1, 1, 3)
    source_std = source_lab.std(dim=[2, 3], keepdim=True, unbiased=False)

    # Replace (near-)zero source std with 1.0 to avoid dividing by zero.
    source_std_safe = torch.where(
        source_std < 1e-8, torch.ones_like(source_std), source_std
    )

    # Colour transfer: (src - src_mean) * (ref_std / src_std) + ref_mean.
    corrected_lab = (
        (source_lab - source_mean) * (ref_std / source_std_safe) + ref_mean
    )

    # 5. Lab -> RGB for the corrected frames (result is clamped to [0, 1]).
    corrected_rgb_01 = lab_to_rgb_torch(corrected_lab)

    # 6. Blend original and corrected frames by strength.
    blended_rgb_01 = (
        (1 - strength) * source_permuted + strength * corrected_rgb_01
    )

    # 7. Restore layout and range:
    #    (B, T, H, W, C) -> (B, C, T, H, W), [0, 1] -> [-1, 1].
    blended_rgb_01 = blended_rgb_01.permute(0, 4, 1, 2, 3)
    blended_rgb_minus1_1 = (blended_rgb_01 * 2.0) - 1.0

    # 8. Return contiguous memory on the original device/dtype.
    output = blended_rgb_minus1_1.contiguous().to(device=device, dtype=dtype)
    return output


def resize_and_centercrop(cond_image, target_size):
    """Resize an image or tensor to cover ``target_size``, then center-crop.

    The input is scaled by the larger of the height and width ratios, so
    the result covers the target in both dimensions without padding, and
    is then center-cropped to exactly ``target_size``.

    Args:
        cond_image: Either a ``torch.Tensor`` whose last two dims are
            (H, W) -- 3-D (C, H, W) or 4-D (N, C, H, W) -- or an object
            exposing ``height``, ``width`` and ``resize`` (a PIL image).
        target_size: ``(target_h, target_w)``.

    Returns:
        For tensor input, a tensor of the same rank as the input with
        spatial size ``target_size`` (resized with nearest-neighbour).
        For image input, a tensor of shape (1, C, 1, target_h, target_w)
        built from the bilinearly resized image; this assumes the image
        converts to an (H, W, C) array -- TODO confirm for grayscale.
    """
    is_tensor = isinstance(cond_image, torch.Tensor)

    # Original spatial size. Reading the last two dims (instead of
    # unpacking exactly three) lets 4-D tensors through as well.
    if is_tensor:
        orig_h, orig_w = cond_image.shape[-2:]
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # Scale by the larger ratio so both dims reach at least the target.
    scale = max(target_h / orig_h, target_w / orig_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    if is_tensor:
        # interpolate needs a batch dim; add one for (C, H, W) input and
        # remove only that one afterwards.
        added_batch = cond_image.dim() == 3
        if added_batch:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(
            cond_image, size=(final_h, final_w), mode='nearest'
        ).contiguous()
        cropped_tensor = transforms.functional.center_crop(
            resized_tensor, target_size
        )
        if added_batch:
            cropped_tensor = cropped_tensor.squeeze(0)
    else:
        # PIL's resize takes (width, height).
        resized_image = cond_image.resize(
            (final_w, final_h), resample=Image.BILINEAR
        )
        resized_image = np.array(resized_image)
        # (H, W, C) -> (1, C, H, W), then crop.
        resized_tensor = (
            torch.from_numpy(resized_image)[None, ...]
            .permute(0, 3, 1, 2)
            .contiguous()
        )
        cropped_tensor = transforms.functional.center_crop(
            resized_tensor, target_size
        )
        # Insert a singleton dim at index 2: (1, C, 1, H, W).
        cropped_tensor = cropped_tensor[:, :, None, :, :]

    return cropped_tensor