# neural/dataset.py
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch as th
import tqdm
from sklearn.preprocessing import RobustScaler


def load_torch_megs(path, n_subjects_max=None, subject=None, init=60, exclude=(), include=()):
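    """Load per-subject MEG epochs and stimulus forcings from ``path``.

    ``path`` must contain one ``.pth`` file per subject and kind, named
    ``<kind>_<subject>_*.pth`` (e.g. ``meg_1076_4_visual.pth``). ``subject``
    selects a single subject by index in the sorted list; ``n_subjects_max``
    truncates that list. ``include``/``exclude`` filter stimulus features by
    name, and ``init`` is the first sample of each epoch searched for a
    stimulus onset. Returns a ``MegDatasets`` bundle of per-subject
    train/valid/test ``MegSubject`` datasets plus scalers, means and PCA
    matrices.
    """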
    # Map each subject id to the paths of its extracted files (one file per subject and kind)
path = Path(path)
subjects = defaultdict(dict)
for child in path.iterdir():
if child.suffix == ".pth": # e.g. meg_1076_4_visual.pth
kind, sub, *_ = child.stem.split("_")
subjects[sub][kind] = child
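            # e.g. "meg_1076_4_visual.pth" -> subjects["1076"]["meg"] = that path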
    # Select subjects of interest (`subject` indexes the sorted subject list)
    to_load = sorted(subjects)
    if subject is not None:
        to_load = [to_load[subject]]
if n_subjects_max:
to_load = to_load[:n_subjects_max]
train_sets = []
valid_sets = []
test_sets = []
meg_scalers = []
means = []
pca_mats = []
iterator = tqdm.tqdm(to_load, leave=False, ncols=120, desc="Loading data...")
# Loop over subjects of interest
subjs = []
for index, subject in enumerate(iterator):
# Load meg and forcing extraction files
megdata = th.load(subjects[subject]["meg"])
forcings = th.load(subjects[subject]["forcing"])
        subjs.append(megdata.get("subject", subject))  # fall back to the filename-derived id if the key is missing
after = forcings.pop("word_n_after", None)
before = forcings.pop("word_n_before", None)
        # Derive new stimulus features (last_word, first_word) from the old
        # word-position counters (word_n_after, word_n_before): a count of 1
        # marks the last (resp. first) word of a sentence.
if after is not None:
forcings["last_word"] = (after == 1).astype(np.float32)
if before is not None:
forcings["first_word"] = (before == 1).astype(np.float32)
if "is_stop" in forcings:
last_word = "is_stop"
else:
last_word = "last_word"
# Create mask to select the first stimulus only in the 2.5s epoch
if "first_mask" not in forcings:
stim = forcings["stimulus"]
first = 0 * stim
            for row in range(len(stim)):
                # First nonzero stimulus sample at or after `init`
                start = init + stim[row, init:].nonzero()[0][0]
                # First zero sample after the onset, if any
                end = (stim[row, start:] == 0).nonzero()[0]
                if len(end):
                    first[row, start:start + end[0]] = 1
                else:
                    # The stimulus runs to the end of the epoch
                    first[row, start:] = 1
forcings["first_mask"] = first
        # Include or exclude stimulus features by name (keys of the forcing dict)
        for name in (*exclude, *include):
            if name not in forcings:
                raise ValueError(
                    f"{name} is not a valid feature name; available: {sorted(forcings)}.")
        feats = list(include) if include else list(forcings.keys())
        for name in exclude:
            if name in feats:
                feats.remove(name)
forcings = {
# just normalizes forcing and permutes to [N, 1, T]
name: _prepare_forcing(forcing)
for name, forcing in forcings.items() if name in feats
}
        # Channel dimension of each forcing after _prepare_forcing (expected: 1)
        forcing_dims = {key: value.size(1) for key, value in forcings.items()}
meg = megdata["meg"]
pca_mats.append(megdata["pca_mat"])
if "meg_last_idx" in megdata:
last_index = th.from_numpy(megdata["meg_last_idx"])
        else:
            # No per-trial lengths: assume full-length trials; the last valid
            # index is T - 1, so that length = 1 + last_index = T below.
            last_index = th.full((meg.shape[0],), meg.shape[1] - 1, dtype=th.long)
        # Robust per-channel scaling (median / IQR), fit on all trials at once
        # TODO: fit on the train split only, to avoid leaking valid/test statistics?
        meg_scaler = RobustScaler()
        meg_scalers.append(meg_scaler)
        # Flatten [N, T, C] -> [N * T, C] so scaling is per channel, then restore the shape
        meg = meg_scaler.fit_transform(meg.reshape(-1, meg.shape[-1])).reshape(*meg.shape)
meg = th.from_numpy(meg)
        # Drop trials whose scaled amplitude is implausibly high (threshold 16)
        max_amplitude = meg.abs().max(dim=1)[0].max(dim=1)[0]
        mask = max_amplitude <= 16
meg = meg[mask]
forcings = {key: value[mask] for key, value in forcings.items()}
last_index = last_index[mask]
        # Center meg data with the per-channel mean over trials and time
        mean = meg.mean(dim=0, keepdim=True).mean(dim=1, keepdim=True)
means.append(mean)
meg = meg - mean
# Change meg format: [N, T, C] -> [N, C, T]
meg = meg.permute(0, 2, 1)
meg_dim = meg.size(1) # expected: C
n_trials = meg.shape[0]
        # Split trials into train / valid / test, pushing each boundary
        # forward to the next end of sentence so no sentence is cut
        train, valid = 0.7, 0.1  # test is the remaining 0.2
        for trial in range(int(train * n_trials),
                           int((train + valid) * n_trials)):
            if forcings[last_word][trial, 0, init] > 0:  # end of sentence
                break
        idx_train = list(range(trial + 1))
        for trial in range(int((train + valid) * n_trials), n_trials):
            if forcings[last_word][trial, 0, init] > 0:
                break
        idx_valid = list(range(idx_train[-1] + 1, trial + 1))
        idx_test = list(range(idx_valid[-1] + 1, n_trials))
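        # E.g. (hypothetical) with n_trials = 100 and sentence ends at
        # trials 72 and 91: idx_train = 0..72, idx_valid = 73..91,
        # idx_test = 92..99.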
        # Instantiate train/valid/test epoched datasets for this subject
        for indices, sets in ((idx_train, train_sets),
                              (idx_valid, valid_sets),
                              (idx_test, test_sets)):
            sets.append(MegSubject(
                meg=meg[indices],
                forcings={k: v[indices] for k, v in forcings.items()},
                length=1 + last_index[indices],
                subject_id=index))
print("subjects: ", subjs)
print("Overall train size: ", sum(tr.meg.shape[0] for tr in train_sets))
print("Overall valid size: ", sum(tr.meg.shape[0] for tr in valid_sets))
print("Overall test size: ", sum(tr.meg.shape[0] for tr in test_sets))
return MegDatasets(
train_sets=train_sets,
valid_sets=valid_sets,
test_sets=test_sets,
meg_scalers=meg_scalers,
means=means,
pca_mats=pca_mats,
meg_dim=meg_dim,
forcing_dims=forcing_dims)
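

# Minimal usage sketch (hypothetical directory and feature name; the
# attributes read below come from the MegDatasets returned above):
if __name__ == "__main__":
    datasets = load_torch_megs(
        "data/meg_extracted",       # directory of <kind>_<subject>_*.pth files
        n_subjects_max=2,           # keep only the first two subjects
        exclude=["first_mask"],     # drop one stimulus feature by name
    )
    print("MEG channels:", datasets.meg_dim)
    print("forcing dims:", datasets.forcing_dims)
    print("train trials (subject 0):", datasets.train_sets[0].meg.shape[0])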