in code/experiment_synthetic/main.py [0:0]
def run_experiment(args):
    """Generate synthetic SEM datasets, fit each requested method, report rows.

    Args:
        args: mapping of experiment options. Keys read here: "seed",
            "setup_sem", "methods", "n_reps", "dim", "n_samples",
            "env_list", and (for the "chain" setup) "setup_ones",
            "setup_hidden", "setup_hetero", "setup_scramble". The whole
            mapping is also forwarded to every method constructor.

    Returns:
        A list of strings. For each repetition there is one "SEM" row holding
        the SEM's own solution (with errors 0, 0), followed by one row per
        method: "<setup> <method> <pretty(solution)> <err_causal>
        <err_noncausal>", with errors formatted to 5 decimals.

    Raises:
        NotImplementedError: if "setup_sem" is not "chain" or "icp", or if
            data generation is requested for a setup other than "chain".
        KeyError: if "methods" names a method other than ERM, ICP or IRM.
    """
    if args["seed"] >= 0:
        torch.manual_seed(args["seed"])
        numpy.random.seed(args["seed"])
        torch.set_num_threads(1)

    if args["setup_sem"] == "chain":
        setup_str = "chain_ones={}_hidden={}_hetero={}_scramble={}".format(
            args["setup_ones"],
            args["setup_hidden"],
            args["setup_hetero"],
            args["setup_scramble"])
    elif args["setup_sem"] == "icp":
        # NOTE(review): only a label exists for "icp"; the data-generation
        # loop below has no "icp" branch, so any run with n_reps > 0 raises
        # NotImplementedError there.
        setup_str = "sem_icp"
    else:
        raise NotImplementedError

    all_methods = {
        "ERM": EmpiricalRiskMinimizer,
        "ICP": InvariantCausalPrediction,
        "IRM": InvariantRiskMinimization,
    }

    if args["methods"] == "all":
        methods = all_methods
    else:
        methods = {}
        for name in args["methods"].split(","):
            # Strip so "ERM, IRM" works the same as "ERM,IRM".
            name = name.strip()
            if name not in all_methods:
                raise KeyError("unknown method {!r}; expected one of {}".format(
                    name, sorted(all_methods)))
            methods[name] = all_methods[name]

    # The environment list is the same for every repetition; parse it once.
    if args["setup_sem"] == "chain":
        env_list = [float(e) for e in args["env_list"].split(",")]

    # Pass 1: build every SEM and sample its environments before fitting any
    # method. Keeping data generation and fitting in separate passes keeps the
    # order of global RNG calls (and so seeded results) unchanged.
    all_sems = []
    all_environments = []
    for _ in range(args["n_reps"]):
        if args["setup_sem"] != "chain":
            raise NotImplementedError
        sem = ChainEquationModel(args["dim"],
                                 ones=args["setup_ones"],
                                 hidden=args["setup_hidden"],
                                 scramble=args["setup_scramble"],
                                 hetero=args["setup_hetero"])
        environments = [sem(args["n_samples"], e) for e in env_list]
        all_sems.append(sem)
        all_environments.append(environments)

    # Pass 2: fit each method on each repetition's environments and compare
    # it to the SEM's solution.
    all_solutions = []
    for sem, environments in zip(all_sems, all_environments):
        sem_solution, sem_scramble = sem.solution()
        solutions = [
            "{} SEM {} {:.5f} {:.5f}".format(setup_str,
                                             pretty(sem_solution), 0, 0)
        ]
        for method_name, method_constructor in methods.items():
            method = method_constructor(environments, args)
            # Map the method's solution back through the SEM's scrambling
            # matrix so it is comparable with sem_solution.
            method_solution = sem_scramble @ method.solution()
            err_causal, err_noncausal = errors(sem_solution, method_solution)
            solutions.append("{} {} {} {:.5f} {:.5f}".format(
                setup_str,
                method_name,
                pretty(method_solution),
                err_causal,
                err_noncausal))
        all_solutions.extend(solutions)

    return all_solutions