in main.py [0:0]
def main(args):
    """Run the configured model on the configured test dataset.

    Steps, in order:
      1. For the ``"TMK_Poullot"`` model, force ``args.normalization = "freq"``.
      2. Build a parameter string ``key-value_key-value_...`` from every
         attribute of ``args`` except ``output_dir``, ``pca_mean`` and
         ``pca_DVt``, and create the directory
         ``<args.output_dir>/<parameter string>``.
      3. Seed ``random``, NumPy and PyTorch (CPU and CUDA) with ``args.seed``.
      4. Replace the ``args.dataset_test`` and ``args.model`` name strings
         with the same-named attributes of the ``datasets`` / ``models``
         modules.
      5. Instantiate the model from ``args``, move it to CUDA when available
         (CPU otherwise), and call ``test(model, args)``.

    Args:
        args: Parsed options object; must support ``vars()`` (presumably an
            ``argparse.Namespace`` — TODO confirm against the caller). It is
            mutated in place: ``normalization``, ``dataset_test`` and
            ``model`` may be overwritten.

    Raises:
        FileExistsError: if the output directory path exists but is not a
            directory.
        AttributeError: if ``args.dataset_test`` / ``args.model`` do not name
            attributes of the ``datasets`` / ``models`` modules.
    """
    if args.model == "TMK_Poullot":
        args.normalization = "freq"

    # Built after the override above so the forced normalization shows up in
    # the directory name; the excluded keys never appear in it.
    excluded = {"output_dir", "pca_mean", "pca_DVt"}
    parameter_string = "_".join(
        f"{k}-{v!s}" for k, v in vars(args).items() if k not in excluded
    )
    output_dir = os.path.join(args.output_dir, parameter_string)
    # exist_ok=True closes the race between a separate exists() check and
    # makedirs() (e.g. two runs with identical parameters started together).
    # NOTE(review): output_dir is not used again in this function — confirm
    # whether test() is meant to receive it.
    os.makedirs(output_dir, exist_ok=True)

    print(args)
    print("Parameter string is", parameter_string)

    # Seed every RNG source used by the run for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Resolve the dataset and model names given as strings into the objects
    # (presumably classes/factories) exposed by the project modules.
    args.dataset_test = getattr(datasets, args.dataset_test)
    args.model = getattr(models, args.model)

    # Instantiate the model from the options and place it on the GPU if one
    # is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = args.model(args).to(device)
    test(model, args)