in fairscale/nn/data_parallel/fully_sharded_data_parallel.py [0:0]
def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
    """
    Gather all shards of params.

    Note: this is idempotent if full params are already gathered. Callers
    assume this idempotency, so please keep it that way.

    Args:
        force_full_precision (bool, Optional): by default params will be gathered
            in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
            ``True``, in which case they will be gathered in full precision
            (e.g., FP32), possibly in fresh storage. The parameter that's being
            rebuilt will end up in full precision as well.

    Returns:
        A list of tuples, where the first element is the full-sized param
        and the second element is a bool indicating if it's safe for the
        caller to free the full-sized param. This will be ``None`` if
        ``force_full_precision=False`` and the full params are already gathered.
    """
    output_tensors: List[Tuple[torch.Tensor, bool]] = []

    def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
        """
        Helper function to update p.data pointer.

        Args:
            custom_output_tensor (torch.Tensor, Optional): if not None, this
                tensor contains the data we just gathered.
        """
        if custom_output_tensor is not None:
            assert p._is_sharded
            p.data = custom_output_tensor
            output_tensors.append((p.data, True))
        elif not p._is_sharded:
            if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                assert p._fp16_shard is not None
                p.data = p._fp16_shard
                output_tensors.append((p.data, True))
            else:
                # Here p.data == p._fp32_shard, so it's not safe to free.
                output_tensors.append((p.data, False))
        else:
            p.data = p._full_param_padded
            output_tensors.append((p.data, True))
        # Trim any padding and reshape to match original size.
        p.data = p.data[: p._orig_size.numel()].view(p._orig_size)
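
    # With SSD offload, each flat param shard lives on SSD-backed storage;
    # to_tensor() presumably materializes the shard in memory before any gathering.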
    if self.ssd_offload:
        for p in self.params:
            assert isinstance(p, SsdFlatParameter)
            p.to_tensor()

        self.has_full_params = False

    if self._has_shared_params:
        # The self.has_full_params flag can be out of sync if a shared param is
        # sharded by another FSDP instance. An example is the eval case with
        # reshard_after_forward=False here while the sharing instance has
        # reshard_after_forward=True. Then, on the second forward, the other
        # instance can shard the shared param, but this instance can
        # mistakenly think the full param is already gathered based on the
        # has_full_params flag.
        #
        # Therefore, we update the flag accordingly here.
        self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)

    # Early exit if we already have full params and don't need full precision.
    if self.has_full_params and not force_full_precision:
        for p in self.params:
            update_p_data()
        return output_tensors
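
    # We are about to gather, so mark full params as present up front; the
    # actual all-gathers below are issued on the side "all_gather" stream.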
    self.has_full_params = True

    with torch.cuda.stream(self._streams["all_gather"]):
        if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
            self._cast_fp32_param_shards_to_fp16()

        if self.move_params_to_cpu:
            if force_full_precision:
                # If the compute_dtype and storage dtype are the same,
                # use pinned memory. Otherwise move p.data to the compute
                # device.
                if self.params[0].dtype == self.compute_dtype:
                    self._cast_fp32_param_shards_to_fp16()
                else:
                    for p in self.params:
                        p.data = p.data.to(self.compute_device)
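
        # Gather each flattened param: every rank contributes its local shard
        # and receives the concatenation of all shards (padded to equal size).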
        for p in self.params:
            if not p._is_sharded:  # e.g., when world_size == 1
                update_p_data()
            else:
                # Skip if already built. Only a shared param can be rebuilt multiple times.
                # A corner case is p._orig_size = (1,), which means the shape equality is
                # not a perfect check. But we assume we don't share a param with shape (1,).
                if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
                    continue
                # If self.move_params_to_cpu and force_full_precision, we need to move
                # the FP32 CPU param to CUDA for the all-gather.
                p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

                p_size = p._full_param_padded.size()
                assert p_size.numel() % self.world_size == 0
                if self.mixed_precision and force_full_precision:
                    # Allocate a fresh tensor in full precision since we are in
                    # mixed precision and a full precision rebuild was requested.
                    output_tensor = p_data.new_zeros(p_size)
                else:
                    if p._full_param_padded.storage().size() != p_size.numel():
                        # Allocate based on full size from all shards.
                        alloc_storage_(p._full_param_padded, size=p_size)
                    output_tensor = p._full_param_padded

                # Fill output_tensor with p_data gathered from all world_size ranks.
                if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                    # Newer versions of PyTorch have _all_gather_base, which is faster
                    # than chunk followed by all_gather.
                    dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                else:
                    chunks = list(output_tensor.chunk(self.world_size))
                    dist.all_gather(chunks, p_data, group=self.process_group)

                # Set p.data = output_tensor (with padding trimmed).
                update_p_data(output_tensor)

                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                    self._free_fp16_param_shard([p])
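
                # With CPU params and matching compute/storage dtypes, the "fp16
                # shard" is assumed to be just a pinned-memory staging copy, so
                # it can be freed as well once the gather has been issued.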
                if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
                    self._free_fp16_param_shard([p])
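
    # Make the current (compute) stream wait for the gathers issued on the
    # all_gather stream so that params are fully materialized before use.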
    torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
    return output_tensors
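
# A minimal standalone sketch of the padded all-gather pattern used above,
# under the assumption of equal-sized flat shards per rank (the names below
# are illustrative, not part of FSDP's API):
#
#   import torch
#   import torch.distributed as dist
#
#   def gather_full_param(shard: torch.Tensor, orig_size: torch.Size,
#                         group: "dist.ProcessGroup") -> torch.Tensor:
#       world_size = dist.get_world_size(group)
#       full = shard.new_zeros(shard.numel() * world_size)   # padded full buffer
#       if hasattr(dist, "_all_gather_base"):
#           dist._all_gather_base(full, shard, group=group)  # single fused gather
#       else:
#           dist.all_gather(list(full.chunk(world_size)), shard, group=group)
#       return full[: orig_size.numel()].view(orig_size)     # trim padding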