def mlp()

in automl21/scs_neural/utils/utils.py [0:0]


def mlp(input_dim,
        hidden_dim,
        output_dim,
        hidden_depth,
        output_mod=None,
        act=nn.ReLU,
        init_weight_scale=None):
    if isinstance(act, str):
        if act == 'relu':
            act = nn.ReLU
        elif act == 'elu':
            act = nn.ELU
        elif act == 'tanh':
            act = nn.Tanh
        else:
            raise NotImplementedError()
    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = [nn.Linear(input_dim, hidden_dim), act()]
        for i in range(hidden_depth - 1):
            mods += [nn.Linear(hidden_dim, hidden_dim), act()]
        mods.append(nn.Linear(hidden_dim, output_dim))
    if output_mod is not None:
        mods.append(output_mod)
    trunk = nn.Sequential(*mods)
    init_weight_scale = None if init_weight_scale == "None" else init_weight_scale
    if init_weight_scale is not None:
        for mod in trunk.modules():
            if isinstance(mod, nn.Linear):
                mod.weight.data.div_(init_weight_scale)
                mod.bias.data.zero_()
    return trunk