Skip to content

Commit 24dda56

Browse files
committed
created flag in feature engineer to concatenate imaginary values of the wavelet transform in to the X values, tested in muse and simulated, works well, added test
1 parent 08e0a00 commit 24dda56

File tree

4 files changed

+79
-53
lines changed

4 files changed

+79
-53
lines changed

Muse_P3example.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,19 @@
66
#python
77
from utils import *
88
data_dir = 'visual/cueing'
9-
subs = [101,102]
9+
subs = [101, 102, 103, 104, 105, 106, 108, 109, 110, 111, 112,
10+
202, 203, 204, 205, 207, 208, 209, 210, 211,
11+
301, 302, 303, 304, 305, 306, 307, 308, 309]
12+
subs = [101, 102, 103, 104]
1013
nsesh = 2
1114
event_id = {'LeftCue': 1,'RightCue': 2}
1215
#Load Data
1316
raw = LoadMuseData(subs,nsesh,data_dir)
1417
#Pre-Process EEG Data
15-
epochs = PreProcess(raw,event_id)
18+
epochs = PreProcess(raw,event_id,epoch_time=(-1,2))
1619
#Engineer Features for Model
17-
feats = FeatureEngineer(epochs)
20+
feats = FeatureEngineer(epochs,frequency_domain=True,
21+
include_phase=True)
1822
#Create Model
1923
model,_ = CreateModel(feats)
2024
#Train with validation, then Test

simulate.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@
99
# raw,event_id = SimulateRaw(amp1=10, amp2=5, freq=2., batch=1)
1010
# raw.save(raw_filename,overwrite=True)
1111

12-
raw,event_id = SimulateRaw(amp1=100, amp2=50, freq=2., batch=4)
13-
epochs = PreProcess(raw,event_id,filter_data=False,plot_erp=True)
12+
raw,event_id = SimulateRaw(amp1=100, amp2=50, freq=2., batch=1)
13+
epochs = PreProcess(raw,event_id,filter_data=False,plot_erp=False,
14+
epoch_time=(-1,2))
1415

1516
pick = 33
1617
for event in event_id.keys():
1718
fig = plt.imshow(epochs[event]._data[:,pick,:])
1819
plt.show()
1920

20-
feats = FeatureEngineer(epochs,model_type='CNN')
21-
model,_ = CreateModel(feats, units=[256,128,128,64,32,16])
21+
feats = FeatureEngineer(epochs,model_type='NN',
22+
frequency_domain=True,include_phase=True)
23+
model,_ = CreateModel(feats, units=[64,32,16])
2224
TrainTestVal(model,feats)
2325

tests.py

+22
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,28 @@ def test_simulate_raw(self):
7676

7777
self.assertLess(data['acc'], 1)
7878

79+
def test_frequencydomain_complex(self):
80+
"""
81+
Testing simulated data pipeline.
82+
"""
83+
# Simulate Data
84+
raw,event_id = SimulateRaw(amp1=50, amp2=60, freq=1.)
85+
86+
# Pre-Process EEG Data
87+
epochs = PreProcess(raw,event_id)
88+
89+
# Engineer Features for Model
90+
feats = FeatureEngineer(epochs,frequency_domain=True,
91+
include_phase=True)
92+
93+
# Create Model
94+
model, _ = CreateModel(feats, units=[16,16])
95+
96+
# Train with validation, then Test
97+
model, data = TrainTestVal(model,feats,
98+
train_epochs=1,show_plots=False)
99+
100+
self.assertLess(data['acc'], 1)
79101

80102
if __name__ == '__main__':
81103
unittest.main()

utils.py

+44-46
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,6 @@ def PreProcess(raw, event_id, plot_psd=False, filter_data=True,
377377
emcp_raw=False, emcp_epochs=False, epoch_decim=1, plot_electrodes=False,
378378
plot_erp=False):
379379

380-
381-
382380
sfreq = raw.info['sfreq']
383381
#create new output freq for after epoch or wavelet decim
384382
nsfreq = sfreq/epoch_decim
@@ -459,9 +457,10 @@ def PreProcess(raw, event_id, plot_psd=False, filter_data=True,
459457

460458
def FeatureEngineer(epochs, model_type='NN',
461459
frequency_domain=False,
462-
normalization=True, electrode_median=False,
463-
wavelet_decim=1,flims=(3,30),
464-
f_bins=20,wave_cycles=3,
460+
normalization=False, electrode_median=False,
461+
wavelet_decim=1, flims=(3,30), include_phase=False,
462+
f_bins=20, wave_cycles=3,
463+
wavelet_electrodes = [11,12,13,14,15],
465464
spect_baseline=[-1,-.5],
466465
test_split = 0.2, val_split = 0.2,
467466
random_seed=1017, watermark = False):
@@ -492,51 +491,50 @@ def FeatureEngineer(epochs, model_type='NN',
492491
f_low = flims[0]
493492
f_high = flims[1]
494493
frequencies = np.linspace(f_low, f_high, f_bins, endpoint=True)
495-
eeg_chans = pick_types(epochs.info,eeg=True,eog=False)
496494

497-
####
498-
## Condition0 ##
499-
print('Computing Morlet Wavelets on ' + event_names[0])
500-
tfr0 = tfr_morlet(epochs[event_names[0]], freqs=frequencies,
501-
n_cycles=wave_cycles, return_itc=False,
502-
picks=eeg_chans, average=False,
503-
decim=wavelet_decim)
504-
tfr0 = tfr0.apply_baseline(spect_baseline,mode='mean')
505-
#reshape data
506-
stim_onset = np.argmax(tfr0.times>0)
507-
feats.new_times = tfr0.times[stim_onset:]
508-
509-
#move electrodes last
510-
cond0_power_out = np.moveaxis(tfr0.data[:,:,:,stim_onset:],1,3)
511-
# move time second
512-
cond0_power_out = np.moveaxis(cond0_power_out,1,2)
513-
####
514-
515-
####
516-
## Condition1 ##
517-
print('Computing Morlet Wavelets on ' + event_names[1])
518-
tfr1 = tfr_morlet(epochs[event_names[1]], freqs=frequencies,
519-
n_cycles=wave_cycles, return_itc=False,
520-
picks=eeg_chans, average=False,
521-
decim=wavelet_decim)
522-
tfr1 = tfr1.apply_baseline(spect_baseline,mode='mean')
523-
#reshape data
524-
cond1_power_out = np.moveaxis(tfr1.data[:,:,:,stim_onset:],1,3)
525-
cond1_power_out = np.moveaxis(cond1_power_out,1,2) # move time second
526-
####
527-
528-
print('Condition one trials: ' + str(len(cond1_power_out)))
529-
print(event_names[1] + ' Time Points: ' + str(len(feats.new_times)))
530-
print(event_names[1] + ' Frequencies: ' + str(len(tfr1.freqs)))
531-
print('Condition zero trials: ' + str(len(cond0_power_out)))
532-
print(event_names[0] + ' Time Points: ' + str(len(feats.new_times)))
533-
print(event_names[0] + ' Frequencies: ' + str(len(tfr0.freqs)))
495+
if wavelet_electrodes == 'all':
496+
wavelet_electrodes = pick_types(epochs.info,eeg=True,eog=False)
534497

498+
#type of output from wavelet analysis
499+
if include_phase:
500+
tfr_output_type = 'complex'
501+
else:
502+
tfr_output_type = 'power'
503+
504+
tfr_dict = {}
505+
for event in event_names:
506+
print('Computing Morlet Wavelets on ' + event)
507+
tfr_temp = tfr_morlet(epochs[event], freqs=frequencies,
508+
n_cycles=wave_cycles, return_itc=False,
509+
picks=wavelet_electrodes, average=False,
510+
decim=wavelet_decim, output=tfr_output_type)
511+
tfr_temp = tfr_temp.apply_baseline(spect_baseline,mode='mean')
512+
stim_onset = np.argmax(tfr_temp.times>0)
513+
power_out_temp = np.moveaxis(tfr_temp.data[:,:,:,stim_onset:],1,3)
514+
power_out_temp = np.moveaxis(power_out_temp,1,2)
515+
print(event + ' trials: ' + str(len(power_out_temp)))
516+
tfr_dict[event] = power_out_temp
517+
518+
#reshape data (sloppy but just use the last temp tfr)
519+
feats.new_times = tfr_temp.times[stim_onset:]
520+
521+
for event in event_names:
522+
print(event + ' Time Points: ' + str(len(feats.new_times)))
523+
print(event + ' Frequencies: ' + str(len(tfr_temp.freqs)))
535524

536525
#Construct X and Y
537-
X = np.append(cond0_power_out,cond1_power_out,0);
538-
Y_class = np.append(np.zeros(len(cond0_power_out)),
539-
np.ones(len(cond1_power_out)),0)
526+
for ievent,event in enumerate(event_names):
527+
if ievent == 0:
528+
X = tfr_dict[event]
529+
Y_class = np.zeros(len(tfr_dict[event]))
530+
else:
531+
X = np.append(X,tfr_dict[event],0)
532+
Y_class = np.append(Y_class,np.ones(len(tfr_dict[event]))*ievent,0)
533+
534+
#concatenate real and imaginary data
535+
if include_phase:
536+
print('Concatenating the real and imaginary components')
537+
X = np.append(np.real(X),np.imag(X),2)
540538

541539
if electrode_median:
542540
print('Computing Median over electrodes')

0 commit comments

Comments
 (0)