in cm/fp16_util.py [0:0]
def get_param_groups_and_shapes(named_model_params):
    """
    Split named parameters into two groups by rank and tag each group with a shape.

    :param named_model_params: an iterable of ``(name, param)`` pairs; each
                               ``param`` must expose an ``ndim`` attribute.
    :return: a two-element list ``[(low_rank, -1), (high_rank, (1, -1))]``:
             ``low_rank`` lists the pairs whose ``param.ndim <= 1`` and
             ``high_rank`` lists the pairs whose ``param.ndim > 1``, both in
             input order.
    """
    low_rank = []
    high_rank = []
    # A single pass is enough, so any iterable (including a one-shot
    # generator) can be passed in.
    for name, param in named_model_params:
        bucket = low_rank if param.ndim <= 1 else high_rank
        bucket.append((name, param))

    # NOTE: the first shape is the plain int -1, not the tuple (-1,); kept
    # as-is because callers may rely on the exact value.
    # Presumably both shapes are fed to something like ``tensor.view`` --
    # TODO confirm against the callers.
    return [(low_rank, -1), (high_rank, (1, -1))]