def __init__()

in optimum/fx/parallelization/parallel_layers/linear.py [0:0]


    def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True) -> None:
        """Turn ``linear`` into a column-parallel layer over the tensor-parallel group of ``ctx``.

        No new parameters are allocated. The original ``linear.weight`` / ``linear.bias``
        objects are reused, and only the ``ParameterMeta`` attached to them (``param.meta``)
        is rewritten. The meta is flagged ``need_initialize``/``is_parallel`` and every slice
        in ``meta.mapping`` is pointed at this rank's block of output features along dim 0
        (the ``out_features`` dimension of an ``nn.Linear`` weight).

        Args:
            ctx: execution context; ``ctx.tp_group`` is the process group that is sharded over.
            linear: the ``nn.Linear`` to parallelize. Its parameters must already carry a
                ``ParameterMeta`` in their ``meta`` attribute.
            gather_output: stored as ``self.gather_output``; presumably read by ``forward`` to
                decide whether to all-gather the sharded output — TODO confirm.

        Raises:
            AssertionError: (plain ``assert``, so skipped under ``python -O``) if a parameter has
                no ``ParameterMeta``, or its meta was already modified but it is not tied.
        """
        super().__init__()
        self.process_group = ctx.tp_group
        world_size = dist.get_world_size(self.process_group)
        tp_rank = dist.get_rank(self.process_group)
        # presumably rejects out_features not divisible by world_size — TODO confirm error type
        ensure_divisibility(linear.out_features, world_size)

        # Each rank owns a contiguous block of `shard_size` output features.
        shard_size = linear.out_features // world_size

        # modify meta information, then skip creating actual parameters (reuse the original one)
        self._shard_meta_along_dim0(linear.weight, tp_rank, shard_size)
        self.weight = linear.weight
        self.gather_output = gather_output

        if linear.bias is not None:
            # bias meta additionally gets `torch.zero_` as its init_fn
            self._shard_meta_along_dim0(linear.bias, tp_rank, shard_size, init_fn=torch.zero_)
            self.bias = linear.bias
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _shard_meta_along_dim0(param, tp_rank: int, shard_size: int, init_fn=None) -> None:
        """Rewrite ``param.meta`` so it selects rows ``[tp_rank*shard_size, (tp_rank+1)*shard_size)`` of dim 0.

        ``init_fn`` is assigned to the meta only when given; otherwise the existing value is kept.
        Meta that was already modified is left untouched; this is only allowed for tied parameters.
        """
        meta = getattr(param, "meta", None)
        assert isinstance(
            meta, ParameterMeta
        ), "should have run `initialize_parameter_meta` after moving model to current device"

        if meta.is_modified_meta:
            # Another layer sharing this (tied) parameter already rewrote the meta.
            # NOTE(review): the existing dim/slice is not checked against dim-0 sharding here;
            # confirm that whoever tied the parameter guarantees a compatible layout.
            assert meta.is_tied, "only tied parameters could already have modified meta"
            return

        meta.need_initialize = True
        meta.is_parallel = True
        if init_fn is not None:
            meta.init_fn = init_fn
        meta.dim = 0
        rank_block = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        for param_slice in meta.mapping.values():
            param_slice.index = rank_block
        meta.is_modified_meta = True