---
layout: '@/layouts/Doc.astro'
title: '🚑 Torchtune Patch on Aurora'
date: 2025-03-23
description: 'A specific diff patch to resolve torchtune import issues with FSDP on Aurora.'
---

Sam Foreman
2025-03-23

- [Patch to get `torchtune` working on Aurora](#patch-to-get-torchtune-working-on-aurora)

## Patch to get `torchtune` working on Aurora

```diff
diff --git a/torchtune/training/_distributed.py b/torchtune/training/_distributed.py
index ff959c5f..c3966290 100644
--- a/torchtune/training/_distributed.py
+++ b/torchtune/training/_distributed.py
@@ -14,7 +14,11 @@
 import torch
 import torch.distributed as dist
 from torch import nn
-from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
+try:
+    from torch.distributed._composable.fsdp import fully_shard
+except (ImportError, ModuleNotFoundError):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+
 from torch.distributed._tensor import distribute_tensor, DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta
 from torch.distributed.checkpoint.state_dict import (
@@ -532,6 +536,11 @@ def shard_model(
     """
     fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
     if cpu_offload:
+        try:
+            from torch.distributed._composable.fsdp import CPUOffloadPolicy
+        except (ImportError, ModuleNotFoundError):
+            from torch.distributed._composable.fsdp._fsdp_api import MixedPrecisionPolicy, CPUOffloadPolicy
+        # from torch.distributed._composable import CPUOffloadPolicy
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
 
     # Shard the model with FSDP, iterating in reverse to start with
```