def scaling_factors()

in pycls/models/scaler.py [0:0]


def scaling_factors(scale_type, scale_factor):
    """
    Computes model scaling factors to allow for scaling along d, w, g, r.

    Compute scaling factors such that d * w * w * r * r == scale_factor.
    Here d is depth, w is width, g is groups, and r is resolution.
    Note that scaling along g is handled in a special manner (see paper or code).

    Examples of scale_type include "d", "dw", "d1_w2", and "d1_w2_g2_r0".
    A scale_type of the form "dw" is equivalent to "d1_w1_g0_r0". The scalar value
    after each scaling dimension gives the relative scaling along that dimension.
    For example, "d1_w2" indicates to scale twice more along width than depth.
    Finally, scale_factor indicates the absolute amount of scaling.

    The "fast compound scaling" strategy from the paper is specified via "d1_w8_g8_r1".

    Args:
        scale_type: Either a run of letters drawn from "dwgr" (shorthand, each
            listed dimension gets weight 1), or "_"-separated "<dim><weight>"
            pairs such as "d1_w2_g2_r0" (unlisted dimensions get weight 0).
        scale_factor: Target overall scaling, i.e. the value of d * w * w * r * r.

    Returns:
        Tuple (d, w, g, r) of multiplicative scaling factors.

    Raises:
        ValueError: If scale_type contains an empty segment, names a dimension
            other than d, w, g, r, or has a weight that is not a valid float.
    """
    dims = "dwgr"
    if all(s in dims for s in scale_type):
        # Shorthand form: every listed dimension gets weight 1, the rest 0.
        weights = {s: 1.0 if s in scale_type else 0.0 for s in dims}
    else:
        # Explicit form: "<dim><weight>" pairs joined by "_", e.g. "d1_w2".
        weights = dict.fromkeys(dims, 0.0)
        for sw in scale_type.split("_"):
            # Raise instead of assert so validation survives `python -O`;
            # also rejects empty segments that would otherwise hit sw[0].
            if not sw or sw[0] not in dims:
                raise ValueError(f"Invalid scale_type: {scale_type!r}")
            weights[sw[0]] = float(sw[1:])
    # g does not appear in d * w * w * r * r. If only g is weighted, normalize by
    # g / 2 so the g factor alone equals scale_factor. If every weight is zero,
    # fall back to 1.0 so that all factors become 1.
    sum_weights = weights["d"] + weights["w"] + weights["r"] or weights["g"] / 2 or 1.0
    d = scale_factor ** (weights["d"] / sum_weights)
    # w and r appear squared in d * w * w * r * r, so their exponents are halved;
    # g uses the same halving.
    w = scale_factor ** (weights["w"] / sum_weights / 2.0)
    g = scale_factor ** (weights["g"] / sum_weights / 2.0)
    r = scale_factor ** (weights["r"] / sum_weights / 2.0)
    s_actual = d * w * w * r * r
    # Internal consistency check, not input validation. It is exempt when
    # d == w == r == 1 (e.g. only g scaled or all weights zero), where s_actual is 1.
    assert d == w == r == 1.0 or isclose(s_actual, scale_factor, rel_tol=0.01)
    return d, w, g, r