@@ -377,8 +377,6 @@ def PreProcess(raw, event_id, plot_psd=False, filter_data=True,
377
377
emcp_raw = False , emcp_epochs = False , epoch_decim = 1 , plot_electrodes = False ,
378
378
plot_erp = False ):
379
379
380
-
381
-
382
380
sfreq = raw .info ['sfreq' ]
383
381
#create new output freq for after epoch or wavelet decim
384
382
nsfreq = sfreq / epoch_decim
@@ -459,9 +457,10 @@ def PreProcess(raw, event_id, plot_psd=False, filter_data=True,
459
457
460
458
def FeatureEngineer (epochs , model_type = 'NN' ,
461
459
frequency_domain = False ,
462
- normalization = True , electrode_median = False ,
463
- wavelet_decim = 1 ,flims = (3 ,30 ),
464
- f_bins = 20 ,wave_cycles = 3 ,
460
+ normalization = False , electrode_median = False ,
461
+ wavelet_decim = 1 , flims = (3 ,30 ), include_phase = False ,
462
+ f_bins = 20 , wave_cycles = 3 ,
463
+ wavelet_electrodes = [11 ,12 ,13 ,14 ,15 ],
465
464
spect_baseline = [- 1 ,- .5 ],
466
465
test_split = 0.2 , val_split = 0.2 ,
467
466
random_seed = 1017 , watermark = False ):
@@ -492,51 +491,50 @@ def FeatureEngineer(epochs, model_type='NN',
492
491
f_low = flims [0 ]
493
492
f_high = flims [1 ]
494
493
frequencies = np .linspace (f_low , f_high , f_bins , endpoint = True )
495
- eeg_chans = pick_types (epochs .info ,eeg = True ,eog = False )
496
494
497
- ####
498
- ## Condition0 ##
499
- print ('Computing Morlet Wavelets on ' + event_names [0 ])
500
- tfr0 = tfr_morlet (epochs [event_names [0 ]], freqs = frequencies ,
501
- n_cycles = wave_cycles , return_itc = False ,
502
- picks = eeg_chans , average = False ,
503
- decim = wavelet_decim )
504
- tfr0 = tfr0 .apply_baseline (spect_baseline ,mode = 'mean' )
505
- #reshape data
506
- stim_onset = np .argmax (tfr0 .times > 0 )
507
- feats .new_times = tfr0 .times [stim_onset :]
508
-
509
- #move electrodes last
510
- cond0_power_out = np .moveaxis (tfr0 .data [:,:,:,stim_onset :],1 ,3 )
511
- # move time second
512
- cond0_power_out = np .moveaxis (cond0_power_out ,1 ,2 )
513
- ####
514
-
515
- ####
516
- ## Condition1 ##
517
- print ('Computing Morlet Wavelets on ' + event_names [1 ])
518
- tfr1 = tfr_morlet (epochs [event_names [1 ]], freqs = frequencies ,
519
- n_cycles = wave_cycles , return_itc = False ,
520
- picks = eeg_chans , average = False ,
521
- decim = wavelet_decim )
522
- tfr1 = tfr1 .apply_baseline (spect_baseline ,mode = 'mean' )
523
- #reshape data
524
- cond1_power_out = np .moveaxis (tfr1 .data [:,:,:,stim_onset :],1 ,3 )
525
- cond1_power_out = np .moveaxis (cond1_power_out ,1 ,2 ) # move time second
526
- ####
527
-
528
- print ('Condition one trials: ' + str (len (cond1_power_out )))
529
- print (event_names [1 ] + ' Time Points: ' + str (len (feats .new_times )))
530
- print (event_names [1 ] + ' Frequencies: ' + str (len (tfr1 .freqs )))
531
- print ('Condition zero trials: ' + str (len (cond0_power_out )))
532
- print (event_names [0 ] + ' Time Points: ' + str (len (feats .new_times )))
533
- print (event_names [0 ] + ' Frequencies: ' + str (len (tfr0 .freqs )))
495
+ if wavelet_electrodes == 'all' :
496
+ wavelet_electrodes = pick_types (epochs .info ,eeg = True ,eog = False )
534
497
498
+ #type of output from wavelet analysis
499
+ if include_phase :
500
+ tfr_output_type = 'complex'
501
+ else :
502
+ tfr_output_type = 'power'
503
+
504
+ tfr_dict = {}
505
+ for event in event_names :
506
+ print ('Computing Morlet Wavelets on ' + event )
507
+ tfr_temp = tfr_morlet (epochs [event ], freqs = frequencies ,
508
+ n_cycles = wave_cycles , return_itc = False ,
509
+ picks = wavelet_electrodes , average = False ,
510
+ decim = wavelet_decim , output = tfr_output_type )
511
+ tfr_temp = tfr_temp .apply_baseline (spect_baseline ,mode = 'mean' )
512
+ stim_onset = np .argmax (tfr_temp .times > 0 )
513
+ power_out_temp = np .moveaxis (tfr_temp .data [:,:,:,stim_onset :],1 ,3 )
514
+ power_out_temp = np .moveaxis (power_out_temp ,1 ,2 )
515
+ print (event + ' trials: ' + str (len (power_out_temp )))
516
+ tfr_dict [event ] = power_out_temp
517
+
518
+ #reshape data (sloppy but just use the last temp tfr)
519
+ feats .new_times = tfr_temp .times [stim_onset :]
520
+
521
+ for event in event_names :
522
+ print (event + ' Time Points: ' + str (len (feats .new_times )))
523
+ print (event + ' Frequencies: ' + str (len (tfr_temp .freqs )))
535
524
536
525
#Construct X and Y
537
- X = np .append (cond0_power_out ,cond1_power_out ,0 );
538
- Y_class = np .append (np .zeros (len (cond0_power_out )),
539
- np .ones (len (cond1_power_out )),0 )
526
+ for ievent ,event in enumerate (event_names ):
527
+ if ievent == 0 :
528
+ X = tfr_dict [event ]
529
+ Y_class = np .zeros (len (tfr_dict [event ]))
530
+ else :
531
+ X = np .append (X ,tfr_dict [event ],0 )
532
+ Y_class = np .append (Y_class ,np .ones (len (tfr_dict [event ]))* ievent ,0 )
533
+
534
+ #concatenate real and imaginary data
535
+ if include_phase :
536
+ print ('Concatenating the real and imaginary components' )
537
+ X = np .append (np .real (X ),np .imag (X ),2 )
540
538
541
539
if electrode_median :
542
540
print ('Computing Median over electrodes' )
0 commit comments