from functools import partial
import math
import argparse
from collections import OrderedDict
from typing import Union, Dict, Optional, Tuple, Callable, Mapping, Any
import warnings

import requests
from PIL import Image
import numpy as np
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from einops import repeat, rearrange


########################################################################
# Utility functions

def pair(t):
    if t is None:
        return None
    if isinstance(t, tuple):
        return t
    elif isinstance(t, list):
        return tuple(t)
    else:
        return (t, t)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def build_2d_sincos_posemb(h, w, embed_dim=1024, temperature=10000.):
    """Sine-cosine positional embeddings from MoCo-v3

    Source: https://github.com/facebookresearch/moco-v3/blob/main/vits.py
    """
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
    assert embed_dim % 4 == 0, (
        'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
    )
    pos_dim = embed_dim // 4
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature ** omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
    pos_emb = torch.cat(
        [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)],
        dim=1
    )[None, :, :]
    pos_emb = rearrange(pos_emb, 'b (h w) d -> b d h w', h=h, w=w, d=embed_dim)
    return pos_emb

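# Illustrative sketch (not used by the model; the grid size and embedding
# dimension below are assumptions for the example only): for a 16x16 token grid
# and 768-dim tokens, `build_2d_sincos_posemb` returns a fixed [1, 768, 16, 16]
# grid of sin-cos features that `PatchedInputAdapter` later flattens and adds
# to the patch tokens.
#
#   posemb = build_2d_sincos_posemb(h=16, w=16, embed_dim=768)
#   assert posemb.shape == (1, 768, 16, 16)
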
def load_process_image(link: str):
    """Download an image, resize it to 512x512, convert it to grayscale,
    and return a [1, 1, 512, 512] float tensor in [0, 1]."""
    img = Image.open(requests.get(link, stream=True).raw)
    img = img.resize((512, 512), Image.Resampling.BILINEAR)  # Resize to 512x512
    img = img.convert('L')  # Convert to grayscale
    img = np.array(img)
    img = torch.tensor(img).unsqueeze(0).unsqueeze(0).float()
    img = img / 255.0  # Normalize to [0, 1]
    print('Input:', img.dtype, img.shape, img.min(), img.max())
    return img


########################################################################
# Input adapter

class PatchedInputAdapter(nn.Module):
    """Adapter for spatial inputs, like images or feature maps.
    Creates tokens from patches over the image.

    Args:
        num_channels: Number of input channels of the image/feature map
        stride_level: Stride level compared to the full-sized image.
            E.g. 4 for 1/4th the size of the image.
        patch_size_full: Int or tuple of the patch size over the full image size.
            Patch size for smaller inputs will be computed accordingly.
        dim_tokens: Dimension of output tokens. Can be set using init method.
        sincos_pos_emb: Set to True (default) to use fixed 2D sin-cos positional embeddings
        learnable_pos_emb: Set to True to learn positional embeddings instead
        image_size: Default image size. Used to initialize size of positional embeddings.
    """
    def __init__(
        self,
        num_channels: int,
        stride_level: int,
        patch_size_full: Union[int, Tuple[int, int]],
        dim_tokens: Optional[int] = None,
        sincos_pos_emb: bool = True,
        learnable_pos_emb: bool = False,
        image_size: Union[int, Tuple[int, int]] = 224
    ):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size_full = pair(patch_size_full)
        self.dim_tokens = dim_tokens
        self.sincos_pos_emb = sincos_pos_emb
        self.learnable_pos_emb = learnable_pos_emb
        self.image_size = pair(image_size)
        print(f'Image size: {self.image_size}, Patch size: {self.patch_size_full}')
        self.num_patches = (
            (self.image_size[0] // self.patch_size_full[0])
            * (self.image_size[1] // self.patch_size_full[1])
        )

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size_full[0] // stride_level)
        self.P_W = max(1, self.patch_size_full[1] // stride_level)

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)

    def init(self, dim_tokens: int = 768):
        """
        Initialize parts of encoder that are dependent on dimension of tokens.
        Should be called when setting up MIRAGE.

        Args:
            dim_tokens: Dimension of tokens
        """
        self.dim_tokens = dim_tokens

        # Fixed-size positional embeddings. Can be interpolated to different input sizes
        h_posemb = self.image_size[0] // (self.stride_level * self.P_H)
        w_posemb = self.image_size[1] // (self.stride_level * self.P_W)
        if self.sincos_pos_emb:
            self.pos_emb = build_2d_sincos_posemb(h=h_posemb, w=w_posemb, embed_dim=self.dim_tokens)
            self.pos_emb = nn.Parameter(self.pos_emb, requires_grad=self.learnable_pos_emb)
        else:
            self.pos_emb = nn.Parameter(torch.zeros(1, self.dim_tokens, h_posemb, w_posemb))
            trunc_normal_(self.pos_emb, std=0.02)

        # Image -> tokens projection
        self.proj = nn.Conv2d(
            in_channels=self.num_channels, out_channels=self.dim_tokens,
            kernel_size=(self.P_H, self.P_W), stride=(self.P_H, self.P_W)
        )

    @torch.jit.ignore  # type: ignore
    def no_weight_decay(self):
        return {'pos_emb'}

    def forward(self, x):
        """
        Forward pass through input adapter, transforming image to sequence of tokens.
        Adds task and positional encodings.

        Args:
            x: Input image tensor
        """
        _B, _C, H, W = x.shape
        assert self.dim_tokens is not None, 'Need to call init(dim_tokens) function first'
        assert (H % self.P_H == 0) and (W % self.P_W == 0), \
            f'Image sizes {H}x{W} must be divisible by patch sizes {self.P_H}x{self.P_W}'
        N_H, N_W = H // self.P_H, W // self.P_W  # Number of patches in height and width

        # Create patches [B, C, H, W] -> [B, (H*W), C]
        x_patch = rearrange(self.proj(x), 'b d nh nw -> b (nh nw) d')

        # Create positional embedding
        x_pos_emb = F.interpolate(self.pos_emb, size=(N_H, N_W), mode='bicubic', align_corners=False)
        x_pos_emb = rearrange(x_pos_emb, 'b d nh nw -> b (nh nw) d')

        # Add patches and positional embeddings
        x = x_patch + x_pos_emb
        return x

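# Illustrative sketch (not part of the model; numbers are assumptions for this
# example only): tokenizing a single-channel 512x512 image with 32x32 patches.
# `init` must be called before the adapter can be used.
#
#   adapter = PatchedInputAdapter(num_channels=1, stride_level=1,
#                                 patch_size_full=32, image_size=512)
#   adapter.init(dim_tokens=768)
#   tokens = adapter(torch.zeros(1, 1, 512, 512))
#   assert tokens.shape == (1, 256, 768)   # (512/32) * (512/32) = 256 patch tokens
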
########################################################################
# MIRAGE building blocks

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    @staticmethod
    def _drop_path(x, drop_prob: float = 0., training: bool = False):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

        This is the same as the DropConnect impl I created for EfficientNet, etc networks,
        however, the original name is misleading as 'Drop Connect' is a different form of
        dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
        ... I've opted for changing the layer and argument names to 'drop path' rather than
        mix DropConnect as a layer name and use 'survival rate' as the argument.
        """
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return self._drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(
            B, N, 3, self.num_heads, C // self.num_heads
        ).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # for torchscript (cannot use tensor as tuple)

        # Only apply attention dropout during training
        x = F.scaled_dot_product_attention(
            q, k, v, scale=self.scale,
            dropout_p=self.attn_drop if self.training else 0.
        ).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

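# Note (sketch, assuming q, k, v of shape [B, num_heads, N, head_dim]): the fused
# F.scaled_dot_product_attention call above is equivalent, up to numerics and with
# dropout disabled, to the explicit formulation
#
#   attn = (q @ k.transpose(-2, -1)) * scale
#   attn = attn.softmax(dim=-1)
#   out = attn @ v
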
class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.,
        qkv_bias: bool = False,
        drop: float = 0.,
        attn_drop: float = 0.,
        drop_path: float = 0.,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias,
            attn_drop=attn_drop, proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,  # type: ignore
            drop=drop
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

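# Illustrative sketch (shapes are assumptions for this example only): a pre-norm
# Transformer block maps a token sequence to a sequence of the same shape.
#
#   block = Block(dim=768, num_heads=12, mlp_ratio=4., qkv_bias=True)
#   tokens = torch.zeros(2, 257, 768)   # [batch, num_tokens, dim]
#   assert block(tokens).shape == tokens.shape
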
########################################################################
# MIRAGE model and wrapper

class MIRAGELight(nn.Module):
    """MultiViT: Multi-modal Vision Transformer
    This is MIRAGE without masking and with a simplified / faster forward pass

    Args:
        input_adapters: Dictionary of task -> input adapters
        output_adapters: Optional dictionary of task -> output adapters
        num_global_tokens: Number of additional global tokens to add (like cls tokens), default is 1
        dim_tokens: Dimension of encoder tokens
        depth: Depth of encoder
        num_heads: Number of attention heads
        mlp_ratio: MLP hidden dim ratio
        qkv_bias: Set to False to disable bias
        drop_rate: Dropout after MLPs and Attention
        attn_drop_rate: Attention matrix drop rate
        drop_path_rate: DropPath drop rate
        norm_layer: Type of normalization layer
    """
    def __init__(
        self,
        args,
        input_adapters: Dict[str, nn.Module],
        output_adapters: Optional[Dict[str, nn.Module]],
        num_global_tokens: int = 1,
        dim_tokens: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        norm_layer: partial[nn.LayerNorm] = partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        self.args = args

        # Initialize input and output adapters
        for adapter in input_adapters.values():
            adapter.init(dim_tokens=dim_tokens)
        self.input_adapters = nn.ModuleDict(input_adapters)
        if output_adapters is not None:
            for adapter in output_adapters.values():
                adapter.init(dim_tokens_enc=dim_tokens)
            self.output_adapters = nn.ModuleDict(output_adapters)
        else:
            self.output_adapters = None

        # Additional learnable tokens that can be used by encoder to process/store global information
        self.num_global_tokens = num_global_tokens
        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, dim_tokens))
        trunc_normal_(self.global_tokens, std=0.02)

        # Transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.encoder = nn.Sequential(*[
            Block(
                dim=dim_tokens, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer
            )
            for i in range(depth)
        ])

        self.apply(self._init_weights)
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if 'qkv' in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif 'kv' in name:
                    # treat the weights of K, V separately
                    val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
            elif isinstance(m, nn.Conv2d):
                if '.proj' in name:
                    # From MAE, initialize projection like nn.Linear (instead of nn.Conv2d)
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        self.input_info = None
        self.token_dist = None

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        return len(self.encoder)

    @torch.jit.ignore  # type: ignore
    def no_weight_decay(self):
        no_wd_set = {'global_tokens'}
        for task, adapter in self.input_adapters.items():
            if hasattr(adapter, 'no_weight_decay'):
                to_skip = adapter.no_weight_decay()
                to_skip = set([f'input_adapters.{task}.{name}' for name in to_skip])
                no_wd_set = no_wd_set | to_skip
        if self.output_adapters is not None:
            for task, adapter in self.output_adapters.items():
                if hasattr(adapter, 'no_weight_decay'):
                    to_skip = adapter.no_weight_decay()
                    to_skip = set([f'output_adapters.{task}.{name}' for name in to_skip])
                    no_wd_set = no_wd_set | to_skip
        return no_wd_set

    def generate_input_info(self, input_task_tokens, image_size):
        input_info = OrderedDict()
        i = 0
        input_info['tasks'] = {}
        for domain, tensor in input_task_tokens.items():
            num_tokens = tensor.shape[1]
            d = {
                'num_tokens': num_tokens,
                'has_posemb': True,
                'start_idx': i,
                'end_idx': i + num_tokens,
            }
            if isinstance(image_size, dict):
                d['image_size'] = image_size[domain]
            if self.args.grid_sizes is not None:
                d['grid_size'] = self.args.grid_sizes[domain]
            i += num_tokens
            input_info['tasks'][domain] = d

        if isinstance(image_size, int):
            input_info['image_size'] = image_size
        input_info['num_task_tokens'] = i
        input_info['num_global_tokens'] = self.num_global_tokens
        return input_info

    def process_input(self, x):
        # If input x is a plain Tensor, assume it's the B-scan modality
        x = {'bscan': x} if isinstance(x, torch.Tensor) else x

        # Need image size for tokens->image reconstruction
        if 'bscan' in x:
            B, _, H, W = x['bscan'].shape
        elif 'semseg' in x:
            B, H, W = x['semseg'].shape
            H *= self.input_adapters['semseg'].stride_level
            W *= self.input_adapters['semseg'].stride_level
        else:
            # TODO: Deal with case where not all have same shape
            B, _, H, W = list(x.values())[0].shape

        # Encode selected inputs to tokens
        input_task_tokens = {
            domain: self.input_adapters[domain](tensor)
            for domain, tensor in x.items()
            if domain in self.input_adapters
        }

        input_info = self.generate_input_info(input_task_tokens=input_task_tokens, image_size=(H, W))
        input_tokens = torch.cat([task_tokens for task_tokens in input_task_tokens.values()], dim=1)

        # Add global tokens to input tokens
        global_tokens = repeat(self.global_tokens, '() n d -> b n d', b=B)
        input_tokens = torch.cat([input_tokens, global_tokens], dim=1)

        return input_tokens, input_info

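    # Illustrative sketch (shapes are assumptions for this example only): with a
    # single 'bscan' input of size 512x512, 32x32 patches, dim_tokens=768 and
    # num_global_tokens=1, process_input concatenates the patch tokens with the
    # learnable global tokens:
    #
    #   tokens, info = model.process_input({'bscan': torch.zeros(1, 1, 512, 512)})
    #   tokens.shape             # -> [1, 257, 768]  (256 patch tokens + 1 global token)
    #   info['num_task_tokens']  # -> 256
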
    def forward(  # type: ignore
        self,
        x: Union[Dict[str, torch.Tensor], torch.Tensor],
        return_all_layers=False,
        **kwargs
    ):
        """Forward pass through input adapters, transformer encoder and output adapters.

        Args:
            x: Input tensor or dictionary of tensors
            return_all_layers: Set to True to return all transformer layers
        """
        input_tokens, input_info = self.process_input(x)

        # Pass tokens through Transformer
        if not return_all_layers:
            encoder_tokens = self.encoder(input_tokens)
        else:
            # Optionally access every intermediate layer
            encoder_tokens = []
            tokens = input_tokens
            for block in self.encoder:
                tokens = block(tokens)
                encoder_tokens.append(tokens)

        if self.output_adapters is None:
            return encoder_tokens

        # Decode tokens for each task using task-specific output adapters
        preds = {
            domain: self.output_adapters[domain](
                encoder_tokens=encoder_tokens,
                input_info=input_info,
            )
            for domain in self.output_adapters
        }
        return preds


class MIRAGEWrapper(nn.Module):
    def __init__(
        self,
        input_size=512,
        patch_size=32,
        modalities='bscan-slo',
        size='base',  # 'base' or 'large'
    ):
        super().__init__()
        self.domain_conf = {
            'bscan': self.default_domain_conf(),
            'slo': self.default_domain_conf(),
        }
        self.size = size

        args = argparse.Namespace()
        args.num_global_tokens = 1
        args.drop_path = 0.0
        args.in_domains = modalities.split('-')
        input_size = pair(input_size)
        patch_size = pair(patch_size)
        assert input_size is not None
        assert patch_size is not None
        args.patch_size = {}
        args.input_size = {}
        args.grid_sizes = {}
        for domain in args.in_domains:
            args.patch_size[domain] = patch_size
            args.input_size[domain] = input_size
            args.grid_sizes[domain] = []
            for i in range(len(input_size)):
                args.grid_sizes[domain].append(input_size[i] // patch_size[i])
        self.args = args

        self.model = self.get_model()

    def default_domain_conf(self):
        return {
            'channels': 1,
            'stride_level': 1,
            'input_adapter': partial(PatchedInputAdapter, num_channels=1),
            'output_adapter': None,
        }

    def get_model(self):
        """Creates and returns model from arguments."""
        print(f"Creating MIRAGE model for inputs {self.args.in_domains}")
        input_adapters = {
            domain: self.domain_conf[domain]['input_adapter'](
                stride_level=self.domain_conf[domain]['stride_level'],
                patch_size_full=tuple(self.args.patch_size[domain]),
                image_size=self.args.input_size[domain],
            )
            for domain in self.args.in_domains
        }
        common_args = dict(
            args=self.args,
            input_adapters=input_adapters,
            output_adapters=None,
            num_global_tokens=self.args.num_global_tokens,
            drop_path_rate=self.args.drop_path,
            mlp_ratio=4,
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        if self.size == 'large':
            model = MIRAGELight(
                **common_args,
                dim_tokens=1024,
                depth=24,
                num_heads=16,
            )
        elif self.size == 'base':
            model = MIRAGELight(
                **common_args,
                dim_tokens=768,
                depth=12,
                num_heads=12,
            )
        else:
            raise ValueError(f'Unknown model size: {self.size}')
        return model

    def forward(self, x: dict):
        """
        Args:
            x: Dict[str, (B, C, H, W) tensor]. H and W are determined by the
                input_size parameter in the constructor. Each tensor is expected
                to be in the range [0, 1].

        Returns:
            Encoder token tensor of shape (B, N, D), since no output adapters
            are attached in this light configuration.
        """
        return self.model(x)

    def load_state_dict(
        self,
        state_dict: Mapping[str, Any],
        strict: bool = True,
        assign: bool = False
    ):
        return self.model.load_state_dict(state_dict, strict=strict, assign=assign)

    @property
    def device(self):
        return next(self.parameters()).device

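# Illustrative sketch (randomly initialized weights; shapes are assumptions for
# this example only): a base-size, bscan-only wrapper turns one 512x512 image
# into 256 patch tokens plus 1 global token of dimension 768.
#
#   wrapper = MIRAGEWrapper(input_size=512, patch_size=32, modalities='bscan', size='base')
#   with torch.no_grad():
#       feats = wrapper({'bscan': torch.rand(1, 1, 512, 512)})
#   assert feats.shape == (1, 257, 768)
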
########################################################################
# Main function to test the MIRAGE model

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=str, default='base', help='Model size', choices=['base', 'large'])
    parser.add_argument('--device', type=str, default='cuda', help='Device to use', choices=['cuda', 'cpu'])
    parser.add_argument('--modalities', type=str, default='bscan', help='Input modality',
                        choices=['bscan', 'slo', 'bscan-slo'])
    args = parser.parse_args()

    input_modalities = args.modalities.split('-')
    bscan_link = 'https://upload.wikimedia.org/wikipedia/commons/2/2d/Macular_OCT_depicting_Central_Serous_Chorioretinopathy_in_the_Left_Eye.png'
    slo_link = 'https://upload.wikimedia.org/wikipedia/commons/6/6a/Ocular_OCT_OS_IR30_overlay.jpg'
    input_data = {}
    if 'bscan' in input_modalities:
        input_data['bscan'] = load_process_image(bscan_link)
    if 'slo' in input_modalities:
        input_data['slo'] = load_process_image(slo_link)
    input_data = {domain: tensor.to(args.device) for domain, tensor in input_data.items()}
    print('Input data:')
    for domain, tensor in input_data.items():
        print(f' {domain}: {tensor.shape}, min: {tensor.min()}, max: {tensor.max()}')

    # NOTE: ViT-Base and ViT-Large versions of MIRAGE are available
    if args.size == 'base':
        weights = '../__weights/MIRAGE-Base.pth'
    else:
        weights = '../__weights/MIRAGE-Large.pth'

    model = MIRAGEWrapper(modalities=args.modalities, size=args.size)
    model.to(args.device)
    model.eval()

    state_dict = torch.load(weights, map_location='cpu', weights_only=False)
    model_state_dict = state_dict["model"]
    msg = model.model.load_state_dict(model_state_dict, strict=False)
    print(' # Missing keys:', len(msg.missing_keys))
    print(' # Unexpected keys:', len(msg.unexpected_keys))
    print(f'Using device: {model.device}')

    with torch.no_grad():
        out = model(input_data)
    print('Features:', out.shape)