def main()

in main.py [0:0]


def main(args):
    """Set up a run from parsed options, build the model and evaluate it.

    The function performs these steps in order:

    1. Forces ``args.normalization`` to ``"freq"`` when ``args.model`` is
       ``"TMK_Poullot"``.
    2. Builds a parameter string from every option except ``output_dir``,
       ``pca_mean`` and ``pca_DVt``, and creates a directory with that name
       under ``args.output_dir``.
    3. Seeds the ``random``, NumPy and PyTorch (CPU and CUDA) generators
       with ``args.seed``.
    4. Replaces the names stored in ``args.dataset_test`` and ``args.model``
       with the attributes of the same name looked up on ``datasets`` and
       ``models``.
    5. Instantiates the model with ``args``, moves it to CUDA when available
       (CPU otherwise) and calls ``test(model, args)``.

    Args:
        args: Options namespace (anything ``vars()`` accepts). Must provide
            ``model``, ``dataset_test``, ``output_dir`` and ``seed``.
            It is mutated in place: ``normalization`` may be overwritten, and
            ``dataset_test`` / ``model`` are replaced by the resolved objects.

    Raises:
        AttributeError: If ``args.dataset_test`` or ``args.model`` does not
            name an attribute of ``datasets`` / ``models``.
        FileExistsError: If the run directory path exists but is not a
            directory.
    """
    if args.model == "TMK_Poullot":
        args.normalization = "freq"

    # Options left out of the directory name.
    # NOTE(review): presumably excluded because they are paths or large
    # values -- confirm against the argument parser.
    excluded = {"output_dir", "pca_mean", "pca_DVt"}
    parameter_string = "_".join(
        "%s-%s" % (key, str(value))
        for key, value in vars(args).items()
        if key not in excluded
    )
    output_dir = os.path.join(args.output_dir, parameter_string)
    # exist_ok avoids the check-then-create race when several runs with the
    # same parameters start at the same time.
    os.makedirs(output_dir, exist_ok=True)

    print(args)
    print("Parameter string is", parameter_string)

    # Seed every random number generator used here for repeatable runs.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Resolve the string names into the objects they refer to. This must
    # happen after the parameter string is built, which uses the plain names.
    args.dataset_test = getattr(datasets, args.dataset_test)
    args.model = getattr(models, args.model)

    # Build the model on the GPU when one is available, then evaluate it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = args.model(args).to(device)
    test(model, args)