in neural/linear/__main__.py [0:0]
def _label_suffix(with_init, with_forcing):
    """Return the label suffix naming which model inputs were disabled.

    Returns "" when both init and forcing are enabled, otherwise
    "(no init)", "(no forcing)" or "(no init, no forcing)".
    """
    missing = []
    if not with_init:
        missing.append("no init")
    if not with_forcing:
        missing.append("no forcing")
    return f"({', '.join(missing)})" if missing else ""


def main():
    """Evaluate linear regression / autoregression baselines over all subjects.

    Parses the command line, creates the output directories via
    ``make_repo_from_parser``, and evaluates every subject with
    ``eval_lin_models`` in a process pool of ``args.n_workers`` workers.

    Per-subject scores are stacked into numpy arrays and saved with
    ``torch.save`` as ``reference_metrics.th`` in the regression and
    autoregression output directories. When ``args.shuffle`` is set, each
    shuffled-feature score array is also saved as
    ``shuffled_<key>_metrics.th`` in the regression directory.
    """
    # Make repository
    parser = get_parser()
    args = parser.parse_args()
    out_reg, out_autoreg = make_repo_from_parser(args)

    # Prepare model labels
    label_add = _label_suffix(args.with_init, args.with_forcing)

    # Initialize result dicts
    reg_results = {"label": "lin reg " + label_add, "scores": []}
    autoreg_results = {"label": "lin autoreg " + label_add, "scores": []}
    # One list per shuffled feature. Keys not listed here that are returned by
    # eval_lin_models are added on first sight rather than raising KeyError.
    shuffled_results = {"word_freqs": [], "word_lengths": []}

    # Evaluate subjects in parallel
    with ProcessPoolExecutor(args.n_workers) as pool:
        pendings = [
            pool.submit(
                eval_lin_models,
                sub,
                args.data,
                out_reg,
                out_autoreg,
                with_forcing=args.with_forcing,
                with_init=args.with_init,
                shuffle=args.shuffle,
            )
            for sub in range(args.n_subjects)
        ]
        # Collect results in submission order, so entry i of every score list
        # corresponds to subject i.
        for pending in tqdm.tqdm(pendings):
            score_linreg, score_linautoreg, shuffled = pending.result()
            reg_results["scores"].append(score_linreg)
            autoreg_results["scores"].append(score_linautoreg)
            # NOTE(review): `shuffled` may be empty or None when shuffling is
            # off; treat both as "nothing to record".
            for key, score_shuffled in (shuffled or {}).items():
                shuffled_results.setdefault(key, []).append(score_shuffled)

    # Stack per-subject lists into numpy arrays
    reg_results["scores"] = np.array(reg_results["scores"])
    autoreg_results["scores"] = np.array(autoreg_results["scores"])
    shuffled_results = {
        key: np.array(values) for key, values in shuffled_results.items()
    }

    # Save
    torch.save(reg_results, out_reg / "reference_metrics.th")
    torch.save(autoreg_results, out_autoreg / "reference_metrics.th")
    if args.shuffle:
        for key, value in shuffled_results.items():
            torch.save({'scores': value, 'label': 'lin reg ' + label_add},
                       out_reg / f"shuffled_{key}_metrics.th")