def main()

in fairmotion/tasks/motion_prediction/test.py [0:0]


def _checkpoint_path(save_model_path, epoch=None):
    """Return the path of the model checkpoint to evaluate.

    Args:
        save_model_path: Directory that holds the ``*.model`` checkpoints.
        epoch: Epoch number whose ``<epoch>.model`` file should be used, or
            None to use ``best.model``.

    Returns:
        The checkpoint file path as a string.
    """
    # Compare against None explicitly: a plain truthiness test would make
    # ``epoch=0`` silently fall back to ``best.model`` instead of ``0.model``.
    name = "best" if epoch is None else epoch
    return os.path.join(save_model_path, f"{name}.model")


def _format_mae(mae):
    """Render a ``{frame: error}`` mapping as ``"frame: error | ..."``."""
    return " | ".join(f"{frame}: {error}" for frame, error in mae.items())


def main(args):
    """Evaluate a trained motion-prediction model on the test split.

    Loads the preprocessed ``train``/``test``/``validation`` pickles, restores
    a saved model checkpoint, runs it on the test split, logs the MAE per
    frame and optionally saves the predicted motions.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``preprocessed_path`` (directory containing ``train.pkl``,
            ``test.pkl`` and ``validation.pkl``; its last path component is
            passed to ``test_model`` as ``rep``), ``batch_size`` and
            ``shuffle`` (forwarded to ``utils.prepare_dataset``),
            ``save_model_path`` and ``epoch`` (select the checkpoint:
            ``<epoch>.model``, or ``best.model`` when ``epoch`` is None),
            ``max_len`` (forwarded to ``test_model``) and
            ``save_output_path`` (when truthy, results are written via
            ``save_motion_files``). ``args`` itself is also forwarded to
            ``prepare_model`` and ``save_motion_files``.

    Raises:
        ValueError: If the training split yields no batches, so the number
            of predictions per time step cannot be inferred.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logging.info("Preparing dataset")
    dataset, mean, std = utils.prepare_dataset(
        *[
            os.path.join(args.preprocessed_path, f"{split}.pkl")
            for split in ["train", "test", "validation"]
        ],
        batch_size=args.batch_size,
        device=device,
        shuffle=args.shuffle,
    )
    # number of predictions per time step = num_joints * angle representation,
    # read off the last dimension of the first training batch's inputs.
    first_batch = next(iter(dataset["train"]), None)
    if first_batch is None:
        raise ValueError(
            f"Training split under {args.preprocessed_path!r} is empty; "
            "cannot infer the number of predictions per time step"
        )
    num_predictions = first_batch[0].shape[-1]

    logging.info("Preparing model")
    model = prepare_model(
        _checkpoint_path(args.save_model_path, args.epoch),
        num_predictions,
        args,
        device,
    )

    logging.info("Running model")
    # The last component of the preprocessed-data directory is handed to
    # test_model as ``rep`` (presumably the angle representation name —
    # confirm against test_model). Slashes are stripped first so a trailing
    # "/" does not yield an empty component.
    rep = os.path.basename(args.preprocessed_path.strip("/"))
    seqs_T, mae = test_model(
        model, dataset["test"], rep, device, mean, std, args.max_len
    )
    logging.info("Test MAE: " + _format_mae(mae))

    if args.save_output_path:
        logging.info("Saving results")
        save_motion_files(seqs_T, args)