relogic/pretrainkit/multitask_trainer.py [873:941]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        # return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
        return PredictionOutputWithSize(predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        """All-gather a fixed-size tensor from every process and drop the sampler padding."""
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def distributed_concat_tensor(self, tensor: torch.Tensor):
        """All-gather a fixed-size tensor from every process and concatenate the results in rank order."""
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)
        return concat

    def distributed_concat_varsize_tensor(self, tensor: torch.Tensor):
        """All-gather 1-D tensors whose length differs across processes.

        Each process pads its tensor to the global maximum length, the padded
        tensors are all-gathered, and the padding is sliced away again so the
        result is the true concatenation in rank order.
        """
        assert self.args.local_rank != -1

        # Share every process's local length so all of them agree on the padding target.
        sizes = self.distributed_concat_tensor(tensor.new_full(size=(1,), fill_value=tensor.size(0)))
        # int() keeps the indices valid even when `tensor` (and hence `sizes`) is floating-point.
        max_size = int(sizes.max().item())

        padded = tensor.new_zeros(max_size)
        padded[:tensor.size(0)] = tensor

        padded_agg = self.distributed_concat_tensor(padded)
        # Cut each process's chunk back to its true length, dropping the padding.
        slices = []
        for i, size in enumerate(sizes):
            start_idx = i * max_size
            end_idx = start_idx + int(size.item())
            slices.append(padded_agg[start_idx: end_idx])
        ret = torch.cat(slices, dim=0)
        return ret


    def distributed_concat_with_size(self, tensor: torch.Tensor, size: torch.Tensor, num_total_examples: int):
        """All-gather variable-length predictions together with their per-example sizes.

        Both `tensor` and `size` may have a different length on each process, so the
        variable-size helper above is used instead of a plain fixed-size all_gather.
        `num_total_examples` is accepted for interface compatibility but is unused.
        Returns a `(concat, concat_sizes)` tuple.
        """
        assert self.args.local_rank != -1

        concat_sizes = self.distributed_concat_varsize_tensor(size)
        concat = self.distributed_concat_varsize_tensor(tensor)

        # Every gathered element must be accounted for by the gathered sizes.
        assert concat_sizes.sum() == concat.size(0)
        return concat, concat_sizes
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
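
The two files above share the same pad-to-max-then-slice idiom for gathering 1-D tensors whose length differs per process (torch.distributed.all_gather requires identical shapes on every rank). The standalone sketch below is not part of the repo: it spins up two CPU processes with the gloo backend and uses hypothetical names (varsize_all_gather, run_demo) purely to illustrate why padding to the global maximum and then slicing each chunk back to its true length reproduces the rank-order concatenation.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Illustration only -- not part of relogic. Assumes 1-D tensors, gloo backend, 2 CPU processes.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def varsize_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    """Gather 1-D tensors of different lengths from every rank, concatenated in rank order."""
    world_size = dist.get_world_size()

    # 1. Share each rank's length so everyone agrees on the padding target.
    local_size = torch.tensor([tensor.size(0)], dtype=torch.long)
    sizes = [torch.zeros(1, dtype=torch.long) for _ in range(world_size)]
    dist.all_gather(sizes, local_size)
    sizes = torch.cat(sizes)
    max_size = int(sizes.max().item())

    # 2. Pad to a common length so all_gather sees identical shapes everywhere.
    padded = tensor.new_zeros(max_size)
    padded[:tensor.size(0)] = tensor
    gathered = [torch.zeros_like(padded) for _ in range(world_size)]
    dist.all_gather(gathered, padded)

    # 3. Cut each rank's chunk back to its true length and concatenate.
    return torch.cat([chunk[:int(n)] for chunk, n in zip(gathered, sizes)])


def run_demo(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    # Rank 0 contributes 2 elements, rank 1 contributes 3.
    local = torch.arange(2 + rank, dtype=torch.long) + 10 * rank
    out = varsize_all_gather(local)
    if rank == 0:
        print(out)  # tensor([ 0,  1, 10, 11, 12]) -- no padding survives
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_demo, args=(2,), nprocs=2)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -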



relogic/pretrainkit/trainer.py [835:903]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        # return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
        return PredictionOutputWithSize(predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        """All-gather a fixed-size tensor from every process and drop the sampler padding."""
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def distributed_concat_tensor(self, tensor: torch.Tensor):
        """All-gather a fixed-size tensor from every process and concatenate the results in rank order."""
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)
        return concat

    def distributed_concat_varsize_tensor(self, tensor: torch.Tensor):
        """All-gather 1-D tensors whose length differs across processes.

        Each process pads its tensor to the global maximum length, the padded
        tensors are all-gathered, and the padding is sliced away again so the
        result is the true concatenation in rank order.
        """
        assert self.args.local_rank != -1

        # Share every process's local length so all of them agree on the padding target.
        sizes = self.distributed_concat_tensor(tensor.new_full(size=(1,), fill_value=tensor.size(0)))
        # int() keeps the indices valid even when `tensor` (and hence `sizes`) is floating-point.
        max_size = int(sizes.max().item())

        padded = tensor.new_zeros(max_size)
        padded[:tensor.size(0)] = tensor

        padded_agg = self.distributed_concat_tensor(padded)
        # Cut each process's chunk back to its true length, dropping the padding.
        slices = []
        for i, size in enumerate(sizes):
            start_idx = i * max_size
            end_idx = start_idx + int(size.item())
            slices.append(padded_agg[start_idx: end_idx])
        ret = torch.cat(slices, dim=0)
        return ret


    def distributed_concat_with_size(self, tensor: torch.Tensor, size: torch.Tensor, num_total_examples: int):
        """All-gather variable-length predictions together with their per-example sizes.

        Both `tensor` and `size` may have a different length on each process, so the
        variable-size helper above is used instead of a plain fixed-size all_gather.
        `num_total_examples` is accepted for interface compatibility but is unused.
        Returns a `(concat, concat_sizes)` tuple.
        """
        assert self.args.local_rank != -1

        concat_sizes = self.distributed_concat_varsize_tensor(size)
        concat = self.distributed_concat_varsize_tensor(tensor)

        # Every gathered element must be accounted for by the gathered sizes.
        assert concat_sizes.sum() == concat.size(0)
        return concat, concat_sizes
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
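
For completeness, here is a tiny, made-up illustration of how the flattened predictions and their per-example sizes (the pair returned by distributed_concat_with_size and carried by PredictionOutputWithSize as predictions / predictions_size) can be split back into per-example sequences; torch.split undoes the flattening.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Illustration only -- example values are made up.
import torch

preds = torch.tensor([5, 6, 7, 8, 9, 1, 2])   # all predictions, flattened in example order
preds_size = torch.tensor([3, 2, 2])          # length of each example's prediction

# torch.split recreates one tensor per example, preserving the original order.
per_example = torch.split(preds, preds_size.tolist())
assert [p.tolist() for p in per_example] == [[5, 6, 7], [8, 9], [1, 2]]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -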



