in optimum/fx/parallelization/parallel_layers/linear.py [0:0]
def __init__(self, ctx: ParallelExecutionCtx, linear: nn.Linear, gather_output: bool = True) -> None:
    super(ColumnParallelLinear, self).__init__()
    # tensor-parallel process group taken from the execution context
    self.process_group = ctx.tp_group
    world_size = dist.get_world_size(self.process_group)
    tp_rank = dist.get_rank(self.process_group)
    # out_features must split evenly across the tensor-parallel ranks
    ensure_divisibility(linear.out_features, world_size)

    # per-rank share of the output dimension, e.g. 4096 features over 4 ranks -> 1024 rows per rank
    out_features = linear.out_features // world_size
    bias = linear.bias is not None

    # modify meta information: mark the weight as sharded along dim 0 (out_features)
    # and record which contiguous row slice this rank owns
    weight_meta = getattr(linear.weight, "meta", None)
    assert isinstance(
        weight_meta, ParameterMeta
    ), "should have run `initialize_parameter_meta` after moving model to current device"

    if weight_meta.is_modified_meta:
        assert weight_meta.is_tied, "only tied parameters could already have modified meta"
    else:
        weight_meta.need_initialize = True
        weight_meta.is_parallel = True
        weight_meta.dim = 0
        for _, Slice in weight_meta.mapping.items():
            Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features)
        weight_meta.is_modified_meta = True

    # skip creating actual parameters and reuse the original parameter object
    self.weight = linear.weight
    self.gather_output = gather_output

    if bias:
        bias_meta = getattr(linear.bias, "meta", None)
        assert isinstance(
            bias_meta, ParameterMeta
        ), "should have run `initialize_parameter_meta` after moving model to current device"

        if bias_meta.is_modified_meta:
            assert bias_meta.is_tied, "only tied parameters could already have modified meta"
        else:
            bias_meta.need_initialize = True
            bias_meta.is_parallel = True
            # bias shards that need initialization are zero-initialized
            bias_meta.init_fn = torch.zero_
            bias_meta.dim = 0
            for _, Slice in bias_meta.mapping.items():
                Slice.index = slice(tp_rank * out_features, (tp_rank + 1) * out_features)
            bias_meta.is_modified_meta = True
        self.bias = linear.bias
    else:
        self.register_parameter("bias", None)