@@ -28,17 +28,16 @@ def quit_button_applied():
28
28
29
29
# Variables
30
30
_five_fold = BooleanVar ()
31
- _learning_rate = IntVar (0 )
31
+ _learning_rate1 = BooleanVar ()
32
+ _learning_rate2 = BooleanVar ()
33
+ _learning_rate3 = BooleanVar ()
34
+ _learning_rate4 = BooleanVar ()
32
35
_batch_size = IntVar (0 )
33
36
_iteration = IntVar (0 )
34
37
35
38
# UI
36
39
Checkbutton (root , text = "Five-Fold (Default: Single-Fold)" , variable = _five_fold ).grid (row = 1 , column = 3 , sticky = W )
37
40
38
- Label (root , text = "Learning Rate: " ).grid (row = 2 , column = 0 )
39
- Label (root , text = "(Default: 0.01)" ).grid (row = 3 , column = 0 )
40
- Entry (root , textvariable = _learning_rate ).grid (row = 2 , column = 1 )
41
-
42
41
Label (root , text = "Batch Size: " ).grid (row = 4 , column = 0 )
43
42
Label (root , text = "(Default: 100)" ).grid (row = 5 , column = 0 )
44
43
Entry (root , textvariable = _batch_size ).grid (row = 4 , column = 1 )
@@ -47,32 +46,62 @@ def quit_button_applied():
47
46
Label (root , text = "(Default: 4000)" ).grid (row = 7 , column = 0 )
48
47
Entry (root , textvariable = _iteration ).grid (row = 6 , column = 1 )
49
48
50
- Button (root , text = "Quit" , command = quit_button_applied , width = 15 ).grid (row = 8 , column = 2 , sticky = W )
51
- Button (root , text = "Apply" , command = apply_button_applied , width = 15 ).grid (row = 8 , column = 3 , sticky = W )
49
+ Checkbutton (root , text = "Learning Rate: 0.005" , variable = _learning_rate1 ).grid (row = 8 , column = 0 , sticky = W )
50
+ Checkbutton (root , text = "Learning Rate: 0.001" , variable = _learning_rate2 ).grid (row = 9 , column = 0 , sticky = W )
51
+ Checkbutton (root , text = "Learning Rate: 0.01" , variable = _learning_rate3 ).grid (row = 10 , column = 0 , sticky = W )
52
+ Checkbutton (root , text = "Learning Rate: 0.015" , variable = _learning_rate4 ).grid (row = 11 , column = 0 , sticky = W )
53
+ Label (root , text = "If NONE or more than one is chosen, default Learning Rate is 0.01" ).grid (row = 12 , column = 0 )
54
+
55
+ Button (root , text = "Quit" , command = quit_button_applied , width = 15 ).grid (row = 14 , column = 2 , sticky = W )
56
+ Button (root , text = "Apply" , command = apply_button_applied , width = 15 ).grid (row = 14 , column = 3 , sticky = W )
52
57
53
58
root .mainloop ()
54
59
55
60
if applied :
56
61
five_fold = _five_fold .get ()
57
- learning_rate = _learning_rate .get ()
62
+
63
+ learning_rate1 = _learning_rate1 .get ()
64
+ learning_rate2 = _learning_rate2 .get ()
65
+ learning_rate3 = _learning_rate3 .get ()
66
+ learning_rate4 = _learning_rate4 .get ()
67
+
58
68
batch_size = _batch_size .get ()
59
69
iteration = _iteration .get ()
60
70
61
71
# Default Values
62
- if learning_rate is 0 :
63
- learning_rate = 0.01
64
72
if batch_size is 0 :
65
73
batch_size = 100
66
74
if iteration is 0 :
67
75
iteration = 4000
68
76
77
+ if learning_rate1 :
78
+ learning_rate = 0.005
79
+ elif learning_rate2 :
80
+ learning_rate = 0.001
81
+ elif learning_rate3 :
82
+ learning_rate = 0.01
83
+ elif learning_rate4 :
84
+ learning_rate = 0.015
85
+
86
+ learning_rate_sum = int (learning_rate1 ) + int (learning_rate2 ) + int (learning_rate3 ) + int (learning_rate4 )
87
+
88
+ if learning_rate_sum > 1 or learning_rate_sum == 0 :
89
+ learning_rate = 0.01
90
+
69
91
# Managing directories
70
92
accessories_path = os .getcwd () + "/Accessories" # same name in malNET_train_src
71
93
if tf .gfile .Exists (accessories_path ):
72
94
tf .gfile .DeleteRecursively (accessories_path )
73
95
74
96
tf .gfile .MakeDirs (accessories_path )
75
97
98
+ # Save hyper-parameters
99
+ hyper_path = accessories_path + "/hyper_parameters.csv"
100
+ hyper = open (hyper_path , 'w' )
101
+ hyper .write ('Learning Rate, Iterations, Batch Size, Five-Fold\n ' + str (learning_rate ) + ',' + str (iteration ) +
102
+ ',' + str (batch_size ) + ',' + str (five_fold ))
103
+ hyper .close ()
104
+
76
105
# Getting Train-data directory
77
106
aug_data_path = os .getcwd () + "/aug_Data"
78
107
if tf .gfile .Exists (aug_data_path ):
@@ -84,7 +113,8 @@ def quit_button_applied():
84
113
weight_path = accessories_path + "/Five_Fold_Trained/Trained_Fold_" + str (num + 1 )
85
114
tf .gfile .MakeDirs (weight_path )
86
115
if tf .gfile .Exists (train_path ):
87
- melNET_train_src .main (train_path , weight_path , learning_rate , batch_size , iteration )
116
+ melNET_train_src .main (train_path , weight_path , float (learning_rate ), int (batch_size ),
117
+ int (iteration ))
88
118
else :
89
119
print ("Training Data is NOT available!" )
90
120
else : # Single-Fold
@@ -93,7 +123,7 @@ def quit_button_applied():
93
123
weight_path = accessories_path + "/Single_Fold_Trained"
94
124
tf .gfile .MakeDirs (weight_path )
95
125
if tf .gfile .Exists (train_path ):
96
- melNET_train_src .main (train_path , weight_path , learning_rate , batch_size , iteration )
126
+ melNET_train_src .main (train_path , weight_path , float ( learning_rate ), int ( batch_size ), int ( iteration ) )
97
127
else :
98
128
print ("Training Data is NOT available!" )
99
129
else :
0 commit comments