in neural/extraction.py [0:0]
import os
import traceback

import mne
import numpy as np
import torch as th
from sklearn.decomposition import PCA


def extract_subject(subject):
    '''Extract MEG epochs and word-level forcing features for one subject
    of the MOUS dataset.

    Input:
        subject (int): row index of the subject in the module-level `log_files`
    '''
log_file = log_files.iloc[subject]
try:
        # build a generic output filename; the "%s" placeholder is filled
        # below with "meg" or "forcing"
        output_fname = "_".join([
            "%s",
            str(log_file["subject"]),
            str(log_file["log_id"]),
            log_file["task"],
        ]) + ".pth"
##########################
# LOAD MEG AND LOG
##########################
# get meg and log filenames
raw_fname = os.path.join(data_path, log_file['meg_file'])
log_fname = os.path.join(data_path, log_file['log_file'])
# read meg (continuous)
raw = mne.io.read_raw_ctf(raw_fname, preload=True)
        raw.filter(1., 30.)  # band-pass 1-30 Hz; this is the slow step
# preprocess annotations and add task information
log = read_log(log_fname, stimuli)
log = _add_stim_id(log, verbose=False, stimuli=stimuli) # get words
        # add the number of words before / after each word within its sentence
        log_words = log.query('condition=="word"')
        words_idx = log_words.index
        # words per sentence (sequence_pos is assumed to index sentences)
        sentence_lengths = np.bincount(log_words.sequence_pos.values.astype(int))
        n_words_before = np.concatenate(
            [np.arange(length) + 1 for length in sentence_lengths]).astype(int)
        n_words_after = np.concatenate(
            [np.ones(length) * length for length in sentence_lengths]).astype(int)
        n_words_after = n_words_after - n_words_before
log.loc[words_idx, "n_words_before"] = n_words_before
log.loc[words_idx, "n_words_after"] = n_words_after
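        # illustrative check: a 3-word sentence yields
        # n_words_before = [1, 2, 3] and n_words_after = [2, 1, 0]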
        # find trigger events lasting at least 10 ms
        events = mne.find_events(raw, min_duration=.010)
# link meg and annotations
log = get_log_times(log, events, raw.info['sfreq'])
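        # `get_log_times` is expected to add the `meg_sample` column that
        # aligns each log row with its sample in the raw recording (used below)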
##########################
# EXTRACT MEG
##########################
# select desired event
log_events = log.query('condition=="word"')
        # format events for mne: columns are [sample, previous value, event id]
        log_events_formatted = np.c_[log_events.meg_sample,
                                     np.ones((len(log_events), 2), int)]
        # keep only the first event at each sample; duplicated samples would
        # otherwise make mne.Epochs fail
        _, idx = np.unique(log_events_formatted[:, 0], return_index=True)
        log_events_formatted = log_events_formatted[idx]
        log_events = log_events.iloc[idx]
# segment meg into word-locked epochs
picks = mne.pick_types(raw.info,
meg=True,
eeg=False,
stim=False,
eog=False,
ecg=False)
        decim = 10  # downsample the epochs by a factor of 10
        tmin, tmax = -0.500, 2.0  # epoch window around each word onset (s)
epochs = mne.Epochs(
raw,
events=log_events_formatted,
metadata=log_events,
tmin=tmin,
tmax=tmax,
decim=decim,
preload=True,
picks=picks,
)
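        # `epochs` now has shape (n_epochs, n_channels, n_times); with a
        # hypothetical 1200 Hz acquisition rate, decim=10 yields 120 Hz,
        # i.e. 301 samples over the 2.5 s window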
        # throw away CTF compensation channels
        # (the first 28 channels are assumed to be reference sensors)
        bads = epochs.ch_names[:28]  # hardcoded count
        raw = raw.pick_types(meg=True, exclude=bads)
        epochs = epochs.pick_types(meg=True, exclude=bads)
# get evoked meg
evoked = epochs.average(method='mean')
        # fit pca on the baselined evoked response
        evoked_temp = evoked.apply_baseline().data.T * 1e12  # tesla -> picotesla
        duration_for_pca = int((np.abs(tmin) + 1) * epochs.info["sfreq"])
        evoked_temp = evoked_temp[:duration_for_pca]  # crop to [tmin, +1 s]
        if args.use_pca:
            pca = PCA(args.pca_dim).fit(evoked_temp)
            pca_mat = pca.components_  # (pca_dim, n_channels)
        else:
            # identity projection: keep every channel unchanged
            pca_mat = np.eye(evoked_temp.shape[1], dtype=np.float32)
##########################
# SAVE MEG
##########################
# collect
meg = epochs.get_data()
meg_evoked = evoked.apply_baseline().data[None, :, :]
        # index of the last valid sample in each epoch (constant here, but
        # kept per-epoch to support sentences of different lengths)
        meg_last_idx = (np.abs(tmin) + tmax) * epochs.info["sfreq"] * np.ones(len(epochs))
        meg_last_idx = meg_last_idx.astype(int)
        # reformat to (n_epochs, n_times, n_channels)
        meg = np.swapaxes(meg, 1, 2)
        meg_evoked = np.swapaxes(meg_evoked, 1, 2)
        # project onto the pca basis: (..., n_channels) @ (n_channels, pca_dim)
        meg_pca = meg @ pca_mat.T
        times = np.array(epochs.metadata["time"], dtype=np.float32)
# save
        output_dict = dict(
            meg=meg_pca.astype(np.float32),
            meg_last_idx=meg_last_idx,
            pca_mat=pca_mat,
            epochs_info=epochs.info,
            times=times,
            subject=log_files.subject[subject],
        )
output_path = os.path.join(output_directory, output_fname % "meg")
print("output path: ", output_path)
th.save(output_dict, output_path)
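        # the saved dict can be reloaded later with th.load(output_path)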
##########################
# LOAD FORCING
##########################
        n_epochs, _, n_times = epochs.get_data().shape
        # six forcing features per epoch: 0 word-onset boxcar, 1 word length,
        # 2 word frequency, 3 n words before, 4 n words after (+1), 5 first-word mask
        forcing_word = np.zeros((n_epochs, 6, n_times), dtype=np.float32)
for epo_idx in range(n_epochs):
# continuous time interval
on = epochs.metadata.iloc[epo_idx].time
start, end = on - np.abs(tmin), on + tmax
# corresponding words
cond = (start < epochs.metadata.time) & (epochs.metadata.time < end)
words = epochs.metadata[cond].word.values.flatten().tolist()
# recentering the time interval around the main onset
onsets = epochs.metadata[cond].time.values - on + np.abs(tmin)
            # convert the onsets from seconds to time samples
            onsets = (onsets * epochs.info["sfreq"]).astype(int)
            # recover word durations (the Duration column is assumed to be in
            # units of 0.1 ms, hence * 1e-4 to get seconds), then offsets
            durations = epochs.metadata[cond].Duration.values.astype(float) * 1e-4
            durations = (durations * epochs.info["sfreq"]).astype(int)  # in samples
offsets = onsets + durations
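            # e.g. with tmin=-0.5 s at an illustrative 120 Hz, the epoch's own
            # word starts at sample int(0.5 * 120) = 60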
# getting features
word_lengths = get_word_length(words)
word_freqs = get_word_freq(words)
word_n_before = epochs.metadata[cond].n_words_before.values.flatten().tolist()
            # shift by +1 so that 0 stays reserved for "no forcing"
            word_n_after = (epochs.metadata[cond].n_words_after.values.flatten() + 1).tolist()
            # place a boxcar of each feature over the word's presence interval
            for idx, (onset, offset) in enumerate(zip(onsets, offsets)):
                forcing_word[epo_idx, 0, onset:offset] = 1.
forcing_word[epo_idx, 1, onset:offset] = word_lengths[idx]
forcing_word[epo_idx, 2, onset:offset] = word_freqs[idx]
forcing_word[epo_idx, 3, onset:offset] = word_n_before[idx]
forcing_word[epo_idx, 4, onset:offset] = word_n_after[idx]
if idx == 0:
# mask for first forcing used to shuffle features
forcing_word[epo_idx, 5, onset:offset] = 1.
# save forcing
forcing_names = ["word_onsets", "word_lengths", "word_freqs",
"word_n_before", "word_n_after", "first_mask"]
        # split into per-feature arrays of shape (n_epochs, 1, n_times)
        forcing = [forcing_word[:, i, :][:, None, :] for i in range(6)]
        # reformat each feature to (n_epochs, n_times, 1)
        forcing = [np.swapaxes(f, 1, 2) for f in forcing]
# save
output_dict = dict(zip(forcing_names, forcing))
output_path = os.path.join(output_directory,
output_fname % "forcing")
th.save(output_dict, output_path)
except Exception as e:
print(f"Error {e} with subject {subject} {log_file}")
traceback.print_exc()
return
else:
print("SUBJECT", subject, "done")