import os import torch import torch.distributed as dist from loguru import logger def load_pt_safetensors(in_path, remove_key=None, include_keys=None): include_keys = include_keys or [] ext = os.path.splitext(in_path)[-1] if ext in (".pt", ".pth", ".tar"): state_dict = torch.load(in_path, map_location="cpu", weights_only=True) keys_to_keep = [] for key in state_dict.keys(): if include_keys: if any(inc_key in key for inc_key in include_keys): keys_to_keep.append(key) else: if not (remove_key and remove_key in key): keys_to_keep.append(key) state_dict = {k: state_dict[k] for k in keys_to_keep} else: import safetensors tensors = {} with safetensors.safe_open(in_path, framework="pt", device="cpu") as f: for key in f.keys(): if include_keys: if any(inc_key in key for inc_key in include_keys): tensors[key] = f.get_tensor(key) else: if not (remove_key and remove_key in key): tensors[key] = f.get_tensor(key) state_dict = tensors return state_dict def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False, include_keys=None): if not dist.is_initialized() or not load_from_rank0: logger.info(f"Loading weights from {checkpoint_path}") return load_pt_safetensors(checkpoint_path, remove_key, include_keys) is_weight_loader = dist.get_rank() == 0 cpu_weight_dict = {} if is_weight_loader: logger.info(f"Loading weights from {checkpoint_path}") cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key) meta_dict = {} if is_weight_loader: for key, tensor in cpu_weight_dict.items(): meta_dict[key] = {"shape": tensor.shape, "dtype": tensor.dtype} obj_list = [meta_dict] if is_weight_loader else [None] dist.broadcast_object_list(obj_list, src=0) synced_meta_dict = obj_list[0] current_rank = dist.get_rank() if cpu_offload: target_device = "cpu" distributed_weight_dict = { key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items() } dist.barrier() else: target_device = torch.device(f"cuda:{current_rank}") distributed_weight_dict = { key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device) for key, meta in synced_meta_dict.items() } dist.barrier(device_ids=[torch.cuda.current_device()]) for key in sorted(synced_meta_dict.keys()): tensor_to_broadcast = distributed_weight_dict[key] if is_weight_loader: tensor_to_broadcast.copy_(cpu_weight_dict[key], non_blocking=True) if cpu_offload: if is_weight_loader: gpu_tensor = tensor_to_broadcast.cuda() dist.broadcast(gpu_tensor, src=0) tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True) del gpu_tensor torch.cuda.empty_cache() else: gpu_tensor = torch.empty_like(tensor_to_broadcast, device="cuda") dist.broadcast(gpu_tensor, src=0) tensor_to_broadcast.copy_(gpu_tensor.cpu(), non_blocking=True) del gpu_tensor torch.cuda.empty_cache() else: dist.broadcast(tensor_to_broadcast, src=0) if is_weight_loader: del cpu_weight_dict if cpu_offload: torch.cuda.empty_cache() logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}") return distributed_weight_dict