in fairscale/nn/data_parallel/fully_sharded_data_parallel.py [0:0]
def _init_param_attributes(self, p: Parameter) -> None:
"""
We manage several attributes on each Parameter instance. The first two
are set by :func:`_shard_parameters_`:
``_is_sharded``: ``True`` if the Parameter is sharded or ``False``
if the Parameter is intentionally not sharded (in which case we
will all-reduce grads for this param).
``_orig_size``: the size of the original Parameter (before sharding)
The remaining attributes are set here:
``_fp32_shard``: a single shard of the parameters in full precision
(typically FP32, but this depends on the dtype of the model as it is
passed in by the user). This can be on CPU or GPU depending on the
value of ``move_params_to_cpu``.
``_fp16_shard``: a single shard of the parameters in ``compute_dtype``
(typically FP16), used for all-gather. Despite the name, this can be
FP16 or FP32 depending on the value of ``compute_dtype`` and whether
params are offloaded to CPU.
``_full_param_padded``: the full weight (padded to be evenly
divisible by ``world_size``), used for computation in the
forward and backward pass. This will be resized in place and
only materialized (via all-gather) as needed.
"""
assert hasattr(p, "_is_sharded") and hasattr(p, "_orig_size")
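# If _fp32_shard already exists, this parameter has been initialized
# before; the attributes below only need to be set up once.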
if hasattr(p, "_fp32_shard"):
return
# A single shard of the parameters in full precision.
# TODO(another-pjohnson) - I believe this will cause memory leakage with ssd
# offload: p.data returns a pointer to a handle, and that handle has its
# ref count incremented by p._fp32_shard, so this tensor will never be
# freed even if we call p.to_disk(). Investigate after PR #887 is merged.
p._fp32_shard = p.data
if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32
if self.move_params_to_cpu:
assert p._fp32_shard.device == torch.device("cpu")
# We don't pin memory when using ssd_offload since that results in OOM when
# the memory requirements of a model are larger than host memory.
if not self.ssd_offload:
# If we plan to keep the FP32 parameters on CPU, then pinning
# memory allows us to later use non-blocking transfers when moving
# the FP32 param shard to compute_device.
p._fp32_shard = p._fp32_shard.pin_memory()
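# Re-point p.data at the (possibly pinned) CPU shard so that later
# transfers to compute_device start from this storage.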
p.data = p._fp32_shard
if self.move_params_to_cpu or self.mixed_precision:
# In mixed precision mode (or when offloading params to CPU), we
# maintain a reduced-precision (typically FP16) parameter shard on
# compute_device for performing the computation in the forward/backward
# pass. We resize the storage to size 0 at init (here) and
# re-materialize it (by copying from _fp32_shard) as needed. If
# offloading params to CPU, the dtype of this shard depends on
# ``compute_dtype``.
p._fp16_shard = torch.zeros_like(p._fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
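# free_storage_ shrinks the underlying storage to zero bytes while keeping
# the tensor's shape/dtype/device metadata, so the shard consumes no memory
# until it is re-materialized.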
free_storage_(p._fp16_shard)
if self.mixed_precision:
assert p._fp32_shard.dtype == torch.float32
if not self.mixed_precision and not self.move_params_to_cpu:
# Use _fp32_shard directly when neither mixed precision nor CPU
# offloading of params is enabled.
p._fp16_shard = None
# We also maintain a full-sized parameter of type self.compute_dtype
# (FP16 for mixed_precision or FP32 otherwise). We resize the
# storage to size 0 at init (here) and only materialize as needed. The
# storage may contain padding elements so that it is evenly divisible by
# world_size, although these padding elements will be removed before the
# relevant computation.
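# For example, with world_size=4 and an original parameter of 10 elements,
# each shard holds 3 elements and _full_param_padded holds 12; the final 2
# padding elements are dropped before the parameter is used.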
if p._is_sharded:
p._full_param_padded = torch.zeros(
p.data.numel() * self.world_size, device=self.compute_device, dtype=self.compute_dtype
)
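# As with _fp16_shard above, free the storage now; the all-gather that
# materializes the full param re-allocates it on demand.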
free_storage_(p._full_param_padded)
if self.move_grads_to_cpu and self.training:
# We can optionally move the grad shard to CPU during the backward
# pass. In this case, it's important to pre-allocate the CPU grad
# shard in pinned memory so that we can do a non-blocking transfer.
# This is only needed during training and not evaluation.
if self.ssd_offload:
assert isinstance(p, SsdFlatParameter)
# Gradients also need to be offloaded to SSD; otherwise this can result in
# OOMs when the memory requirements of a model are larger than host memory.
p._cpu_grad = ssd_offload.SsdTensorHandle.from_tensor(torch.zeros_like(p.data, device="cpu"))
p._cpu_grad.set_file_params(p.filename + "_grad", 0)
p._cpu_grad.to_file()
else:
p._cpu_grad = torch.zeros_like(p.data, device="cpu").pin_memory()
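# A minimal, illustrative sketch (not part of this class) of the
# "free now, re-materialize later" storage pattern used for _fp16_shard and
# _full_param_padded above, assuming the tensor owns its storage from offset 0.
# The helper name free_storage_sketch is hypothetical; it mirrors what
# free_storage_ does by resizing the underlying storage to zero elements:
#
#   import torch
#
#   def free_storage_sketch(t: torch.Tensor) -> None:
#       # Keep the tensor's shape/dtype/device metadata but drop its bytes.
#       assert t.storage_offset() == 0
#       t.storage().resize_(0)
#
#   buf = torch.zeros(1024)
#   free_storage_sketch(buf)            # buf now holds no allocated storage
#   buf.storage().resize_(buf.numel())  # re-allocate before the next use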