in fairscale/nn/misc/flatten_params_wrapper.py [0:0]
def __new__(cls, params: Sequence[nn.Parameter], requires_grad: bool = True) -> "FlatParameter":
    """Concatenate ``params`` into one 1-D tensor and delegate to the parent ``__new__``.

    Every entry of ``params`` is reshaped to 1-D, and the pieces are joined in
    order along dim 0. The joined tensor is passed to the parent class's
    ``__new__`` together with ``requires_grad``.

    Args:
        params: Non-empty list or tuple. Normally every item is an
            ``nn.Parameter``. During unpickling a single plain ``Tensor`` is
            passed instead, and ``__init__`` is expected to set the correct
            ``_param_numels`` / ``_param_shapes`` afterwards.
        requires_grad: Forwarded unchanged to the parent ``__new__``.

    Raises:
        ValueError: If ``params`` is not a non-empty list/tuple, if any item is
            neither a Parameter nor a Tensor, or if any item is already a
            ``FlatParameter``.
    """
    bad_container_msg = "An non-empty list or tuple argument is needed"
    # Only a non-empty list or tuple can be flattened.
    if not isinstance(params, (list, tuple)):
        raise ValueError(bad_container_msg)
    if len(params) == 0:
        raise ValueError(bad_container_msg)

    # Check every item's type before the nesting check, so the type error
    # takes precedence when both problems are present.
    for item in params:
        if not isinstance(item, (nn.Parameter, Tensor)):
            raise ValueError("List items need to be Parameter types")

    # Flattening means two things: a single-dimensional tensor and a flat
    # (non-hierarchical) module structure. Letting one FlatParameter contain
    # another would reintroduce hierarchy. If nesting is ever needed, the
    # outer FlatParameter should absorb the inner one so the result stays flat.
    for item in params:
        if isinstance(item, FlatParameter):
            raise ValueError("Nesting FlatParameter is not supported")

    # Parameters are detached before reshaping; plain Tensors are reshaped as-is.
    flat_pieces = [
        (item.detach() if isinstance(item, nn.Parameter) else item).reshape(-1)
        for item in params
    ]
    flat_data = torch.cat(flat_pieces, 0)
    return super(FlatParameter, cls).__new__(cls, flat_data, requires_grad=requires_grad)