def main()

in complex_shift_operator/__main__.py [0:0]


def main(params):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"running on {device}")

    args = parser.parse_args(params)

    SEED = int(args.seed)
    random.seed(SEED)
    torch.manual_seed(SEED)
    np.random.seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    if args.dataset == "simpleshapes":
        data = datasets.SimpleShapes(
            batch_size=args.bs,
            n_x_translations=args.data_n_x,
            n_y_translations=args.data_n_y,
            n_rotations=args.data_n_rot,
            n_classes=args.n_classes,
            n_pixels=28,
        )
    elif args.dataset == "mnist":
        data = datasets.ProjectiveMNIST(
            batch_size=args.bs,
            n_x_translations=args.data_n_x,
            n_y_translations=args.data_n_y,
            n_rotations=args.data_n_rot,
            train_set_proportion=args.tr_prop,
            test_set_proportion=args.te_prop,
            valid_set_proportion=args.val_prop,
        )
    if args.mode == "train":
        print("Training")
    if args.mode == "test":
        print("Testing")

    # automatically set z_dim to image size
    image_size = data.n_pixels ** 2
    if not os.path.exists(args.output_directory):
        os.mkdir(args.output_directory)
    dict_args = vars(args)
    save_name = "_".join(
        [
            "{0}_{1}".format(key, dict_args[key])
            for key in dict_args
            if key not in ["output_directory", "mode"]
        ]
    )

    if args.supervised:
        transformation_types = []
        indexes = []
        if args.n_rot > 0:
            transformation_types.append("ComplexShiftOperator")
            indexes.append(0)
        if args.n_x > 0:
            transformation_types.append("ComplexShiftOperator")
            indexes.append(1)
        if args.n_y > 0:
            transformation_types.append("ComplexShiftOperator")
            indexes.append(2)
        model_with_rotation = ComplexAutoEncoder(
            data,
            transformation_types=transformation_types,
            indexes=indexes,
            device=device,
            z_dim=image_size,
            seed=SEED,
            output_directory=args.output_directory,
            save_name=save_name,
            n_rotations=args.n_rot,
            n_x_translations=args.n_x,
            n_y_translations=args.n_y,
        )
        n_transfos = len(indexes)
    else:
        model_with_rotation = WeaklyComplexAutoEncoder(
            data,
            transformation_type="ComplexShiftOperator",
            device=device,
            z_dim=image_size,
            seed=SEED,
            temperature=args.tau,
            output_directory=args.output_directory,
            save_name=save_name,
            use_softmax=args.sftmax,
            n_rotations=args.n_rot,
            n_x_translations=args.n_x,
            n_y_translations=args.n_y,
        )

    if args.mode == "train":
        (
            train_loss,
            valid_loss,
            train_mse,
            valid_mse,
            test_mse,
        ) = model_with_rotation.run(n_epochs=args.n_epochs, learning_rate=args.lr)

        perf = np.array([train_mse, valid_mse, test_mse])

        torch.save(perf, os.path.join(args.output_directory, "final_mse_" + save_name))
        torch.save(
            train_loss, os.path.join(args.output_directory, "train_loss_" + save_name)
        )

        torch.save(
            valid_loss, os.path.join(args.output_directory, "valid_loss_" + save_name)
        )

        file_name = "best_checkpoint_{}.pth.tar".format(model_with_rotation.save_name)
        path_to_model = os.path.join(args.output_directory, file_name)
        best_mse, best_epoch = model_with_rotation.load_model(path_to_model)

        ##### Plots train reconstructions
        samples_pairs = np.random.randint(
            0, len(model_with_rotation.data.X_train), size=(10,)
        ).tolist()
        model_with_rotation.plot_x2_reconstructions(
            indices=samples_pairs,
            train_set=True,
            save_name=os.path.join(args.output_directory, "plots_train_reconstructions_" + save_name),
        )
        ##### Plots train rotations of samples
        train_indices = np.random.randint(
            0, len(model_with_rotation.data.X_orig_train), size=(10,)
        ).tolist()
        figsave_name=os.path.join(args.output_directory, "plots_train_rotations_" + save_name + '.png')
        if args.supervised:
            if n_transfos == 1:
                if args.data_n_x > 0:
                    param_name = 'tx'
                elif args.data_n_y > 0:
                    param_name = 'ty'
                if args.data_n_rot > 0:
                    param_name = 'angle'
                model_with_rotation.plot_multiple_transformations(indices=train_indices, train_set = True, 
                    param_name=param_name, save_name=figsave_name
                )
            else:
                model_with_rotation.plot_multiple_transformations_stacked(indices=train_indices, train_set = True, 
                    n_plots = 10, save_name=figsave_name
                )
        else:
            if args.data_n_x > 0:
                param_name = 'tx'
            elif args.data_n_y > 0:
                param_name = 'ty'
            if args.data_n_rot > 0:
                param_name = 'angle'
            model_with_rotation.plot_multiple_transformations(indices=train_indices, train_set = True, 
                param_name=param_name,save_name=figsave_name
            )

        ##### Plots test reconstructions
        samples_pairs = np.random.randint(
            0, len(model_with_rotation.data.X_test), size=(10,)
        ).tolist()
        model_with_rotation.plot_x2_reconstructions(
            indices=samples_pairs,
            train_set=False,
            save_name=os.path.join(args.output_directory, "plots_test_reconstructions_" + save_name),
        )

        ##### Plots test rotations of samples
        test_indices = np.random.randint(
            0, len(model_with_rotation.data.X_orig_test), size=(10,)
        ).tolist()
        figsave_name=os.path.join(args.output_directory, "plots_test_rotations_" + save_name + '.png')
        if args.supervised:
            if n_transfos == 1:
                if args.data_n_x > 0:
                    param_name = 'tx'
                elif args.data_n_y > 0:
                    param_name = 'ty'
                if args.data_n_rot > 0:
                    param_name = 'angle'
                model_with_rotation.plot_multiple_transformations(indices=test_indices, train_set = False, 
                    param_name=param_name, save_name=figsave_name
                )
            else:
                model_with_rotation.plot_multiple_transformations_stacked(indices=test_indices, train_set = False, 
                    n_plots = 10, save_name=figsave_name
                )
        else:
            if args.data_n_x > 0:
                param_name = 'tx'
            elif args.data_n_y > 0:
                param_name = 'ty'
            if args.data_n_rot > 0:
                param_name = 'angle'
            model_with_rotation.plot_multiple_transformations(indices=test_indices, train_set = False, 
                param_name=param_name, save_name=figsave_name
            )

        
    elif args.mode == "test":
        file_name = "best_checkpoint_{}.pth.tar".format(model_with_rotation.save_name)
        path_to_model = os.path.join(args.output_directory, file_name)
        model_with_rotation.load_model(path_to_model)
        if args.supervised:
            loss_func = model_with_rotation.reconstruction_mse_transformed_z1
        else:
            loss_func = model_with_rotation.reconstruction_mse_transformed_z1_weak
        test_mse = model_with_rotation.compute_test_loss(
            loss_func, model_with_rotation.data.test_loader_batch_100
        )
        torch.save(
            torch.FloatTensor([test_mse]),
            os.path.join(
                args.output_directory, "test_mse_" + model_with_rotation.save_name
            ),
        )