Codeswitch preprocessing pipeline

1 Libraries
import os
import mne
import numpy as np
import matplotlib.pyplot as plt
import pickle
import pandas as pd
from scipy import signal
from scipy.io import wavfile
# from pybv import write_brainvision
from pyprep.prep_pipeline import PrepPipeline
from mne_icalabel import label_components

2 Trigger lag fix
2.1 parameters
# directory
input_dir = os.getcwd() + '/data/'
output_dir = os.getcwd() + '/preprocessed/1_trigger_lag_corrected/'
# create the output folder if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# subjects to exclude
exclude_subs = [
]

# trigger searching window (actual trigger time based on audio - trigger time in the data)
t_left = -0.01
t_right = 0.5
Here we fix the trigger lag and recode the events for item-level analysis. For each event, the StimTrak audio channel is cross-correlated with every candidate stimulus wave file sharing the event's trigger code; the best-matching file identifies the item, and the lag at the correlation peak gives the true audio onset.
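The core trick is that the lag at the peak of the cross-correlation between two signals recovers the offset of one within the other. A minimal self-contained sketch of the idea (synthetic data, not part of the pipeline; the sampling rate and offset are made up):

import numpy as np
from scipy import signal

rng = np.random.default_rng(0)
sfreq = 500                                     # hypothetical sampling rate (Hz)
stim = rng.standard_normal(sfreq)               # a 1-s "stimulus" waveform
recording = 0.1 * rng.standard_normal(5 * sfreq)
true_offset = 1037                              # embed the stimulus 1037 samples in
recording[true_offset:true_offset + stim.size] += stim

# the peak of the full cross-correlation recovers the embedded offset
corr = signal.correlate(recording, stim, mode='full')
lags = signal.correlation_lags(recording.size, stim.size, mode='full')
print(lags[np.argmax(corr)])                    # -> 1037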
#### create dictionaries for item-level codes and descriptions ####
df = pd.read_csv("mapping_file.txt", delimiter='\t')
mapping_file2code = dict(zip(df['filename'], df['item_code']))
mapping_code2description = dict(zip(df['item_code'], df['description']))
#############################################################
# get list of file names
all_files = os.listdir(input_dir)
# for each file
for file in all_files:
    if file.endswith('.vhdr') and (file.split('.')[0] not in exclude_subs) and (file.split('.')[0] + '_corr.fif' not in os.listdir(output_dir)):
        # read in vhdr files
        raw = mne.io.read_raw_brainvision(input_dir + file, preload=True)

        # extract sampling rate
        eeg_sfreq = raw.info['sfreq']

        #### get trigger code, audio data, and audio length ####
        # initialize dictionaries
        trigger_dict = {}  # marker: [filename list]
        audio = {}         # filename: data
        lengths = {}       # filename: audio length
        # read in the mapping file
        with open('mapping_file.txt', 'r') as f:
            # skip the first line (header)
            next(f)
            for line in f:
                # strip the newline from the current line
                line = line.replace('\n', '')
                # get info
                filename, marker, description, item_code = line.split('\t')
                # initialize a filename list for each trigger code
                if marker not in trigger_dict.keys():
                    trigger_dict[marker] = []
                # add the filename to the list
                trigger_dict[marker].append(filename)
                # get audio data for each file
                if filename not in audio.keys():
                    # get the sample rate and data of the audio file
                    sampleRate, data = wavfile.read('codeswitch_mystim/stimuli/{}'.format(filename))
                    # the sound file is stereo, so take only one channel's data
                    data = data[:, 0]
                    # calculate the sound file length
                    lengths[filename] = len(data)/sampleRate
                    # reduce the sampling rate of the audio file by a factor of int(sampleRate/eeg_sfreq)
                    data_downsampled = signal.decimate(data, int(sampleRate/eeg_sfreq), ftype='fir')
                    # add the audio data to the audio dictionary
                    audio[filename] = data_downsampled
        ####################################################
        #### get events ####
        # for each stimulus, mark the block info
        events_from_annot, event_dict = mne.events_from_annotations(raw, verbose='WARNING')
        # only events with trigger code 1-4 are useful
        events_from_annot = events_from_annot[events_from_annot[:, 2] <= 4]
        ######################

        #### cross-correlation to find each event's audio file and correct the trigger lag ####
        # initialize
        delays = np.array([])   # a delay list
        bad_stim = []           # a bad stim list
        corr_results = []       # list of each event's max cross-correlation coefficient
        filename_results = []   # list of each event's filename
        # loop over each event
        for i in range(len(events_from_annot)):
            # get the current event
            event = events_from_annot[i]
            # get the onset latency
            time = event[0]/eeg_sfreq
            # get the marker
            marker = str(event[2])

            # initialize a dictionary of each file and its max correlation coefficient
            singleFile_maxCorr_dict = {}
            # initialize a dictionary of each file and the lag at its max correlation coefficient
            singleFile_maxCorrLag_dict = {}
            #### find the audio file for the current event, recode the marker, and record the lag info ####
            for name in trigger_dict[marker]:
                # get the data from the sound channel
                audio_eeg = raw.get_data(
                    picks=['StimTrak'],
                    tmin=time + t_left,
                    tmax=time + lengths[name] + t_right,
                )[0]
                # get the actual stimulus data
                audio_stim = audio[name]

                # z-score normalization (subtract mean, divide by std)
                audio_eeg = (audio_eeg - np.mean(audio_eeg)) / np.std(audio_eeg)
                audio_stim = (audio_stim - np.mean(audio_stim)) / np.std(audio_stim)

                # cross-correlation
                corr = signal.correlate(audio_eeg, audio_stim, mode='full')
                # normalize for signal duration
                corr = corr / (np.linalg.norm(audio_eeg) * np.linalg.norm(audio_stim))
                # find the peak correlation value
                singleFile_maxCorr_dict[name] = np.max(corr)

                # get lags for the cross-correlation
                lags = signal.correlation_lags(
                    audio_eeg.size,
                    audio_stim.size,
                    mode="full")
                # find the lag for the peak correlation (in samples, relative to the annotated onset)
                singleFile_maxCorrLag_dict[name] = lags[np.argmax(corr)] + t_left*eeg_sfreq
            # get the file giving max correlation
            max_file = max(singleFile_maxCorr_dict, key=singleFile_maxCorr_dict.get)
            # get the maximum correlation among all files
            max_corr = singleFile_maxCorr_dict[max_file]
            # get the lag
            lag = singleFile_maxCorrLag_dict[max_file]
            # add the item-level trigger code
            events_from_annot[i][2] = mapping_file2code[max_file]

            # if the maximum correlation is less than a threshold
            if round(max_corr, 1) < 0.5:
                # mark the stim bad
                bad_stim.append(i)
            # add the maximum correlation info for the current event
            corr_results.append(max_corr)
            filename_results.append(max_file)
            delays = np.append(delays, lag)
        ##################################################
        #### plot the stimtrak eeg and the audio data of the event with the minimum correlation of the current file ####
        # get the index of the event with the minimum correlation
        min_corr = np.argmin(corr_results)
        # get the current event info
        event = events_from_annot[min_corr]
        # get the onset latency
        time = event[0]/eeg_sfreq
        # get the file name of the event
        name = filename_results[min_corr]
        # get the stimtrak data
        audio_eeg = raw.get_data(
            picks=['StimTrak'],
            tmin=time + t_left,
            tmax=time + lengths[name] + t_right,
        )[0]
        # actual stimulus data
        audio_stim = audio[name]
        # z-score normalization (subtract mean, divide by std)
        audio_eeg = (audio_eeg - np.mean(audio_eeg)) / np.std(audio_eeg)
        audio_stim = (audio_stim - np.mean(audio_stim)) / np.std(audio_stim)
        # plot
        fig, ax = plt.subplots()
        ax.plot(audio_eeg, label='StimTrak', alpha=0.6)
        ax.plot(audio_stim, label='wave', alpha=0.6)
        ax.set_title(file)
        ax.legend()
        fig.savefig(output_dir + file.split('.')[0] + "_minCor.png", dpi=300, bbox_inches='tight')
        ##########################
        # record the number of bad stims of the current file
        if len(bad_stim) > 0:
            # write the number of bad stims to a file
            with open(output_dir + 'bad_stim.txt', 'a+') as f:
                _ = f.write(file + '\t' + str(len(bad_stim)) + ' bad stims' + '\n')

        # remove events of bad stims
        events_from_annot = np.delete(events_from_annot, bad_stim, 0)

        # remove lags of bad stims
        delays = np.delete(delays, bad_stim, 0)

        # correct for trigger lag
        events_from_annot[:, 0] = events_from_annot[:, 0] + delays

        # create item-level annotations
        annot_from_events = mne.annotations_from_events(
            events=events_from_annot,
            event_desc=mapping_code2description,  # item-level mapping
            sfreq=eeg_sfreq
        )
        # set annotations
        raw.set_annotations(annot_from_events)
        # drop the audio channel in the data
        raw.drop_channels(['StimTrak'])

        # save the single-trial delay file
        np.savetxt(output_dir + file.replace('.vhdr', '_delays.txt'), delays, fmt='%i')

        # save as a .fif file
        raw.save(output_dir + file.split('.')[0] + '_corr.fif')
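Before moving on, it can help to eyeball the recovered delays. A minimal sketch, where 'sub01_delays.txt' is a hypothetical stand-in for one of the per-subject files written above:

# quick sanity check: plot the distribution of recovered delays (in samples)
d = np.loadtxt(output_dir + 'sub01_delays.txt')   # hypothetical subject file name
plt.hist(d, bins=20)
plt.xlabel('trigger delay (samples)')
plt.ylabel('count')
plt.show()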
3 Bad channel correction
- filtering
- resampling
- remove line noise
- bad channel detection & repairing
- add back reference channel TP9
3.1 parameters
#### parameters ####
# set directory
input_dir = os.getcwd() + '/preprocessed/1_trigger_lag_corrected/'
output_dir = os.getcwd() + '/preprocessed/2_bad_channel_corrected/'
# create the output folder if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# filter cutoff frequencies (low/high)
f_low = 1
f_high = 100

# resampling frequency
f_res = 250

# line frequency
line_freq = 60

# preprocessing parameters
prep_params = {
    "ref_chs": 'eeg',
    "reref_chs": 'eeg',  # average re-reference
    "line_freqs": np.arange(line_freq, f_res/2, line_freq),
}
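For concreteness, with the parameter values above the "line_freqs" entry expands as follows (a worked example, not additional pipeline code):

# with line_freq = 60 and f_res = 250:
# np.arange(60, 125.0, 60) -> array([ 60., 120.])
# i.e., the 60 Hz line frequency plus its one harmonic below the post-resampling Nyquist (125 Hz)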
# create a montage file for the pipeline
montage = mne.channels.make_standard_montage("standard_1020")

# interpolation method
# method=dict(eeg="spline")
#####################################################
#### Preprocessing (filtering, resampling, bad channel detection/interpolation, re-reference) ####
#####################################################
# get all file names in the folder
all_input = os.listdir(input_dir)
all_output = os.listdir(output_dir)
# for each file
for file in all_input:
    if file.endswith("corr.fif") and (file.split('.')[0] + '_prep.fif' not in all_output):
        # read in the file
        raw = mne.io.read_raw_fif(input_dir + file, preload=True)

        # set the channel type for EOG channels
        raw.set_channel_types({'Fp1': 'eog', 'Fp2': 'eog'})

        # filter
        raw.filter(l_freq=f_low, h_freq=f_high)

        #### cut off the beginning and ending parts ####
        # get the onset of the first and the last event
        events_from_annot, event_dict = mne.events_from_annotations(raw, verbose='WARNING')

        #### crop the file to cut off everything more than 10 s before the first event and 10 s after the last event, which may be noisy ####
        # define the beginning time (in seconds)
        crop_start = events_from_annot[0][0]/raw.info['sfreq'] - 10

        # define the ending time (in seconds)
        crop_end = events_from_annot[-1][0]/raw.info['sfreq'] + 10

        # crop the data
        raw.crop(
            tmin=max(crop_start, raw.times[0]),
            tmax=min(crop_end, raw.times[-1])
        )
        ####################################################################################
        # resample
        raw.resample(sfreq=f_res)

        # read in channel location info
        raw.set_montage(montage)

        #### use PrepPipeline to preprocess ####
        '''
        1. detect and interpolate bad channels
        2. remove line noise
        3. re-reference
        '''
        # apply pyprep
        prep = PrepPipeline(raw, prep_params, montage, random_state=42)
        prep.fit()

        # export a txt file for the interpolated channel info
        with open(output_dir + 'bad_channel.txt', 'a+') as f:
            _ = f.write(
                file + ':\n' +
                "- Bad channels original: {}".format(prep.noisy_channels_original["bad_all"]) + '\n' +
                "- Bad channels after robust average reference: {}".format(prep.interpolated_channels) + '\n' +
                "- Bad channels after interpolation: {}".format(prep.still_noisy_channels) + '\n'
            )
        # save the pyprep-preprocessed data into the raw data structure
        raw = prep.raw

        # add back the reference channel
        raw = mne.add_reference_channels(raw, 'TP9')

        # add the channel loc info (for the newly added reference channel)
        raw.set_montage(montage)

        # save
        raw.save(output_dir + file.split('.')[0] + '_prep.fif')
4 ICA artifact subtraction
4.1 parameters
# directory
input_dir = os.getcwd() + '/preprocessed/2_bad_channel_corrected/'
output_dir = os.getcwd() + '/preprocessed/3_ica/'
# create the output folder if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# up to which IC to consider
ic_upto = 15
# ic_upto = 99

# get all file names in the folder
all_input = os.listdir(input_dir)
all_output = os.listdir(output_dir)
# for each file
for file in all_input:
    if file.endswith("prep.fif") and (file.split('.')[0] + '_ica.fif' not in all_output):
        # read in the file
        raw = mne.io.read_raw_fif(input_dir + file, preload=True)

        # make a filtered copy for ICA: it works better on signals high-pass filtered at 1 Hz and low-pass filtered at 100 Hz
        raw_filt = raw.copy().filter(l_freq=1, h_freq=100)

        # apply a common average reference, to comply with the ICLabel requirements
        raw_filt.set_eeg_reference("average")

        # initialize ica parameters
        ica = mne.preprocessing.ICA(
            # n_components=0.999999,
            max_iter='auto',  # n-1
            # use the 'extended infomax' method for fitting the ICA, to comply with the ICLabel requirements
            method='infomax',
            fit_params=dict(extended=True),
            random_state=42,
        )
        #### get the ica solution ####
        ica.fit(raw_filt, picks=['eeg'])

        # save the ica solution
        ica.save(output_dir + file.split('.')[0] + '_icaSolution.fif')

        #### ICLabel ####
        ic_labels = label_components(raw_filt, ica, method="iclabel")

        # save
        with open(output_dir + file.split('.')[0] + '_icLabels.pickle', 'wb') as f:
            pickle.dump(ic_labels, f)

        #### automatically keep only components labeled "brain" or "other" ####
        labels = ic_labels["labels"]
        exclude_idx = [
            idx for idx, label in enumerate(labels) if idx < ic_upto and label not in ["brain", "other"]
        ]
        # ica.apply() changes the Raw object in-place
        ica.apply(raw, exclude=exclude_idx)

        # record the bad ICs in bad_ICs.txt
        with open(output_dir + '/bad_ICs.txt', 'a+') as f:
            _ = f.write(file + '\t' + str(exclude_idx) + '\n')

        # save the data after ICA
        raw.save(output_dir + file.split('.')[0] + '_ica.fif')
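The automatic ICLabel selection is convenient, but spot-checking it is cheap. A minimal sketch using standard MNE ICA plotting methods, meant to be run inside the loop after ica.fit() (it reuses ica, raw_filt, and exclude_idx from above):

# scalp topographies of all fitted components
ica.plot_components()
# time courses of the components flagged for exclusion
ica.plot_sources(raw_filt, picks=exclude_idx)
# signal before vs. after removing the flagged components
ica.plot_overlay(raw_filt, exclude=exclude_idx)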
5 Segmentation
Segmenting continuous EEG into epochs:
- re-reference
- segmentation
5.1 parameters
#### parameters ####
# directory
input_dir = os.getcwd() + '/preprocessed/3_ica/'
output_dir = os.getcwd() + '/preprocessed/4_erp_epochs/'  # for ERP
# create the output folder if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# subjects to reject
reject_subs = [
]

# epoch window (in seconds)
erp_t_start = -0.2; erp_t_end = 0.8
baseline = (-0.2, 0)

# criteria to reject epochs (peak-to-peak amplitude)
# reject_criteria = dict(eeg = 100e-6) # 100 µV
# reject_criteria = dict(eeg = 150e-6) # 150 µV
reject_criteria = dict(eeg=200e-6)  # 200 µV
# epochs for ERP
# initialize a list for subjects with too many bad trials
too_many_bad_trial_subjects = []

# get file names
all_input = os.listdir(input_dir)
all_output = os.listdir(output_dir)

#### re-reference and epoch ####
for file in all_input:
    if file.endswith("ica.fif") and (file.split('.')[0] + '_epo.fif' not in all_output):
        # skip rejected subjects
        if file.split('_')[1] in reject_subs:
            continue
        # read in the data
        raw = mne.io.read_raw_fif(input_dir + file, preload=True)

        # average-mastoids re-reference
        raw.set_eeg_reference(ref_channels=['TP9', 'TP10'])

        #### this is for source calculation ####
        # filter the data, optional
        # raw = raw.filter(l_freq=None, h_freq=30)
        # sphere = mne.make_sphere_model('auto', 'auto', raw.info)
        # src = mne.setup_volume_source_space(sphere=sphere, exclude=30., pos=15.)
        # forward = mne.make_forward_solution(raw.info, trans=None, src=src, bem=sphere)
        # raw = raw.set_eeg_reference('REST', forward=forward)
        ########################################

        # get event info for segmentation
        events_from_annot, event_dict = mne.events_from_annotations(raw, verbose='WARNING')

        # segmentation for ERP
        epochs = mne.Epochs(
            raw,
            events=events_from_annot, event_id=event_dict,
            tmin=erp_t_start, tmax=erp_t_end,
            # apply baseline correction
            baseline=baseline,
            # remove epochs that meet the rejection criteria
            reject=reject_criteria,
            preload=True,
        )
        # for each event, remove 0-trial events, record info, and check whether the subject is bad
        for k, v in event_dict.items():
            # good trial count
            trial_count = len(epochs[k])

            # remove 0-trial events
            if trial_count == 0:
                del epochs.event_id[k]

            # good trial rate
            goodTrial_rate = round(trial_count/sum(events_from_annot[:, 2] == v), 2)

            # record the epoch summary
            with open(output_dir + 'epoch_summary.txt', 'a+') as f:
                _ = f.write(file.split('_')[1] + '\t' + k + '\t' + str(trial_count) + '\t' + str(goodTrial_rate) + '\n')

            # mark a subject bad if any condition retains fewer than half of its trials
            if goodTrial_rate < 0.5:
                # mark the subject file as bad
                if file.split('_')[1] not in too_many_bad_trial_subjects:
                    too_many_bad_trial_subjects.append(file.split('_')[1])

        # save the single-subject file
        epochs.save(output_dir + file.split('.')[0] + '_epo.fif',
                    overwrite=True)

# export the record of bad subjects for ERP
with open(output_dir + 'too_many_bad_trial_subjects.txt', 'w') as file:
    # write each item in the list to the file
    for item in too_many_bad_trial_subjects:
        file.write(item + '\n')
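To see which channels drove the peak-to-peak rejections, MNE can plot an Epochs object's drop log. A minimal sketch, run after epoching a file of interest:

# each bar shows the fraction of epochs dropped because of that channel
epochs.plot_drop_log()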
6 ERP
6.1 parameters
# directory
input_dir = os.getcwd() + '/preprocessed/4_erp_epochs/'
output_dir = os.getcwd() + '/preprocessed/5_averaged/'
# create the output folder if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

#### get ERP ####
# get file names
all_input = os.listdir(input_dir)
all_output = os.listdir(output_dir)

# initialize a dictionary to store the data
all_evokeds = {}

# bad subjects with 0 good trials in any condition
bad_subs = [
]

# for each file
for file in all_input:
    if file.endswith("_epo.fif"):
        # extract the subject number
        subject = file.split('_')[1]

        # skip rejected subjects
        if subject in bad_subs:
            continue

        # read in the data
        epochs = mne.read_epochs(input_dir + file, preload=True)

        # average | get the ERP for each condition
        evoked = epochs.average(by_event_type=True)

        # initialize a dictionary for the single-subject ERPs
        all_evokeds[subject] = {}

        # add a key for each condition for analysis
        for cond in evoked:
            # add the evoked data to the dictionary of evoked data
            all_evokeds[subject][cond.comment] = cond

# save the ERP data
with open(output_dir + '/all_evokeds.pkl', 'wb') as f:
    pickle.dump(all_evokeds, f)
# del all_evokeds
7 Visualization
7.1 parameters
# directory
input_dir = os.getcwd() + '/preprocessed/5_averaged/'

# participants to exclude
exclude_ppts = []
7.2 single-participant, single-condition butterfly
# read in the ERP data
with open(input_dir + '/all_evokeds.pkl', 'rb') as file:
    all_evokeds = pickle.load(file)

# get the data for one participant and one item
evoked = all_evokeds['836']['local_noswitch_lunchbox']

# butterfly waveform plot
evoked.plot()

# scalp topography
times = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
evoked.plot_topomap(times=times, colorbar=True)
plt.show()
7.3 topographical subplots
# read in the ERP data
with open(input_dir + '/all_evokeds.pkl', 'rb') as file:
    all_evokeds = pickle.load(file)

# get the list of all participants that came this far
all_ppts = list(all_evokeds.keys())

# get participants that meet criteria
sub_ppts = []
for ppt in all_ppts:
    # if the participant is not in the exclusion list
    if ppt not in exclude_ppts:
        # append that participant to the list
        sub_ppts.append(ppt)

# extract ERPs for each condition
local_switched = []
local_noswitch = []
mando_switched = []
mando_noswitch = []

# for each participant
for ppt in sub_ppts:
    # extract item labels
    items = all_evokeds[ppt].keys()

    for cond in ['local_switched', 'local_noswitch', 'mando_switched', 'mando_noswitch']:
        # get the condition's item list
        cond_list = [x for x in items if x.rsplit('_', 1)[0] == cond]

        # compute the participant-level erp
        tmp = mne.combine_evoked([all_evokeds[ppt][x] for x in cond_list],
                                 weights='equal')
        # append the erp to the condition's list
        eval(cond).append(tmp)
# add the erp data to a dictionary for plotting
evokeds = {}
for cond in ['local_switched', 'local_noswitch', 'mando_switched', 'mando_noswitch']:
    evokeds[cond] = eval(cond)

################################
#### Topographical subplots ####
# figure title for the waveform
waveform_title = '(n = ' + str(len(sub_ppts)) + ')'

# waveforms across the scalp topography
# NOTE: one way to save these plots is sketched after this block
fig = mne.viz.plot_compare_evokeds(
    evokeds,
    axes='topo',
    # picks=pick_chans,
    # combine="mean",
    show_sensors=True,
    # colors=colors,
    title=waveform_title,
    # ylim=dict(eeg=[-5, 5]),
    time_unit="ms",
    show=False,
);
##############################
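On the saving note above: mne.viz.plot_compare_evokeds returns a list of matplotlib figures, so the result can be saved like any other figure. A minimal sketch (the output file name is hypothetical):

# fig is a list of Figure objects; with axes='topo' there is a single figure
fig[0].savefig(input_dir + 'topo_waveforms.png', dpi=300, bbox_inches='tight')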