def _rebuild_full_params()

in fairscale/nn/data_parallel/fully_sharded_data_parallel.py [0:0]


    def _rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Note: this is idempotent if the full params are already gathered.
        Callers assume this idempotency, so please keep it that way.

        Args:
            force_full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage. The parameter that's being
                rebuilt will end up in full precision as well.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating whether it is safe for
            the caller to free the full-sized param. When the full params are
            already gathered and ``force_full_precision=False``, the early-exit
            path still returns this list (with the safe-to-free flags set
            accordingly).
        """
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            """
            Helper function to update p.data pointer.

            Args:
                custom_output_tensor (torch.Tensor, Optional): if not None, this
                tensor contains the data we just gathered.
            """
            if custom_output_tensor is not None:
                assert p._is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p._is_sharded:
                if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                    assert p._fp16_shard is not None
                    p.data = p._fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p._full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[: p._orig_size.numel()].view(p._orig_size)

        if self.ssd_offload:
            for p in self.params:
                assert isinstance(p, SsdFlatParameter)
                p.to_tensor()

            self.has_full_params = False

        if self._has_shared_params:
            # The self.has_full_params flag can be out of sync if a shared param
            # is resharded by another FSDP instance. For example, in eval this
            # instance may have reshard_after_forward=False while the instance
            # sharing the param has reshard_after_forward=True. On the second
            # forward, the other instance reshards the shared param, but this
            # instance may still mistakenly conclude from has_full_params that
            # the full param is already gathered.
            #
            # Therefore, we update the flag accordingly here.
            self.has_full_params = not any(p._full_param_padded.storage().size() == 0 for p in self.params)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not force_full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

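        # Optimistically mark the full params as gathered; the actual gathering
        # happens below on the all_gather stream.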
        self.has_full_params = True

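        # Run dtype casts and all-gathers on a dedicated CUDA stream so they can
        # overlap with work on the default stream; we synchronize at the end.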
        with torch.cuda.stream(self._streams["all_gather"]):
            if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                self._cast_fp32_param_shards_to_fp16()

            if self.move_params_to_cpu:
                if force_full_precision:
                    # If the compute_dtype and storage dtype are the same,
                    # use pinned memory. Otherwise move p.data to the compute
                    # device.
                    if self.params[0].dtype == self.compute_dtype:
                        self._cast_fp32_param_shards_to_fp16()
                    else:
                        for p in self.params:
                            p.data = p.data.to(self.compute_device)

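            # Gather each param: unsharded params just swap their .data pointer;
            # sharded params are all-gathered into their padded full-size buffer.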
            for p in self.params:
                if not p._is_sharded:  # e.g., when world_size == 1
                    update_p_data()
                else:
                    # Skip if already built; only a shared param can be rebuilt
                    # multiple times. A corner case is p._orig_size = (1,), for
                    # which the shape-equality check is not perfect. But we assume
                    # we don't share a param with shape (1,).
                    if p.data.shape == p._orig_size and hasattr(p, "_is_shared") and p._is_shared:
                        continue
                    # If self.move_params_to_cpu and force_full_precision, we need
                    # to copy the FP32 CPU shard to CUDA for the all-gather.
                    p_data = p.data.to(p._full_param_padded.device, non_blocking=True)

                    p_size = p._full_param_padded.size()
                    assert p_size.numel() % self.world_size == 0
                    if self.mixed_precision and force_full_precision:
                        # Allocate a fresh tensor in full precision since we are in
                        # mixed precision and a full-precision rebuild was requested.
                        output_tensor = p_data.new_zeros(p_size)
                    else:
                        if p._full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage_(p._full_param_padded, size=p_size)
                        output_tensor = p._full_param_padded

                    # Fill output_tensor with p.data from each of the world_size shards.
                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                        # Newer versions of PyTorch have _all_gather_base, which is faster than chunking and then calling all_gather.
                        dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                    else:
                        chunks = list(output_tensor.chunk(self.world_size))
                        dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

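                    # The FP16 / offload shard served only as all-gather input;
                    # free it now that the full param has been materialized.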
                    if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
                        self._free_fp16_param_shard([p])

                    if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
                        self._free_fp16_param_shard([p])

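        # Make the compute (current) stream wait for the all-gathers issued on
        # the all_gather stream before the rebuilt full params are consumed.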
        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors
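
A minimal caller-side sketch (not part of the original file) of how the returned
``(tensor, safe_to_free)`` pairs might be consumed. ``fsdp_module`` is an
illustrative FSDP instance, and ``free_storage_`` stands in for a helper that
releases a tensor's storage, analogous to the ``alloc_storage_`` call used above:

    # Hypothetical usage; the names here are illustrative assumptions.
    output = fsdp_module._rebuild_full_params(force_full_precision=True)
    if output is not None:
        try:
            # Full params are now materialized in each p.data; use them, e.g.
            # to snapshot a full-precision copy of the weights.
            snapshot = {
                name: p.detach().cpu().clone()
                for name, p in fsdp_module.named_parameters()
            }
        finally:
            for full_tensor, safe_to_free in output:
                if safe_to_free:
                    # Free only the tensors the method flagged as safe to free.
                    free_storage_(full_tensor)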