Skip to content

Commit 08d23d2

Browse files
committed
Confusion Mat. DONE
1 parent 79d7c4a commit 08d23d2

6 files changed

+95
-14
lines changed
3.44 KB
Binary file not shown.
87 Bytes
Binary file not shown.
-5 Bytes
Binary file not shown.

melNET_confusion_mat.py

+82-7
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
# importing libs
22
from tkinter import *
33
import os
4+
import csv
45

56

67
#######################################################################################################################
78
# Main >>>
8-
def main(classes=['cat', 'dog']):
9-
# Window Creation: Train to melNET
9+
def main(classes):
10+
# Window Creation: Create Confusion Matrix
1011
root = Tk()
11-
root.title("Save Confusion Matrix") # Title
12+
root.title("Confusion Matrix Creation") # Title
1213
global applied
1314
applied = False
1415

@@ -55,21 +56,95 @@ def quit_button_applied():
5556
exit()
5657

5758
if class_1:
58-
pos_class = classes[0]
59+
pos_class = classes[0].upper()
60+
neg_class = classes[1].upper()
5961
else:
60-
pos_class = classes[1]
62+
pos_class = classes[1].upper()
63+
neg_class = classes[0].upper()
6164

65+
fold_num = 5
6266
if five_fold:
67+
for fold in range(fold_num):
68+
root = os.getcwd() + "/aug_Data/Five_Fold_(Aug)/Fold_"+str(fold+1)+"/Test"
69+
result_path = root + "/_result.csv"
70+
conf_mat_path = root + "/conf_mat.csv"
71+
error_path = root + "/error_analysis.csv"
72+
err_thresh = 70
73+
conf_mat_make(result_path, conf_mat_path, error_path, err_thresh, pos_class, neg_class)
6374

6475
else:
65-
res_root = os.getcwd() + "/aug_Data/Single_Fold_(Aug)/Test"
66-
76+
result_path = os.getcwd() + "/aug_Data/Single_Fold_(Aug)/Test/_result.csv"
77+
conf_mat_path = os.getcwd() + "/aug_Data/Single_Fold_(Aug)/Test/conf_mat.csv"
78+
error_path = os.getcwd() + "/aug_Data/Single_Fold_(Aug)/Test/error_analysis.csv"
79+
err_thresh = 70
80+
conf_mat_make(result_path, conf_mat_path, error_path, err_thresh, pos_class, neg_class)
6781

6882
else:
6983
print("Confusion Matrix is not being SAVED!")
7084
exit()
7185

7286

87+
def conf_mat_make(result_path, conf_mat_path, error_path, err_thresh, pos_class, neg_class):
88+
# Result Evaluation
89+
conf_mat = open(conf_mat_path, 'w')
90+
error = open(error_path, 'w')
91+
92+
for thresh in range(0, 105, 5):
93+
tp = fp = tn = fn = 0
94+
pos_idx = 1
95+
96+
with open(result_path) as result:
97+
result_reader = csv.reader(result, delimiter=',')
98+
line_count = 0
99+
100+
for row in result_reader:
101+
# Setting Headers and positive class
102+
if line_count == 0:
103+
if thresh == 0:
104+
conf_mat.write('Threshold, TP, FP, TN, FN, Sensitivity, Specificity, Accuracy\n')
105+
error.write(f'{", ".join(row)}' + '\n')
106+
if pos_class == str(row[2]):
107+
pos_idx = 2
108+
line_count += 1
109+
110+
# Going through every row (starting from 2nd one) and evaluate
111+
else:
112+
truth = row[4].upper()
113+
if (float(row[pos_idx]) * 100) >= thresh:
114+
decision = pos_class
115+
if decision == truth:
116+
tp += 1
117+
else:
118+
fp += 1
119+
if thresh == err_thresh:
120+
error.write(f'{", ".join(row)}' + '\n')
121+
else:
122+
decision = neg_class
123+
if decision == truth:
124+
tn += 1
125+
else:
126+
fn += 1
127+
if thresh == 50:
128+
error.write(f'{", ".join(row)}' + '\n')
129+
130+
# Calculation
131+
acc = (tp + tn) / (tp + fp + tn + fn)
132+
sen = tp / (tp + fn)
133+
spe = tn / (tn + fp)
134+
# pre = tp / (tp+fp)
135+
136+
'''print(">>> Threshold: " + str(thresh))
137+
print('TP: ' + str(tp) + ', FP: ' + str(fp) + ', TN: ' + str(tn) + ', FN: ' + str(fn))
138+
print('Accuracy: ' + str(acc))
139+
print('Sensitivity: ' + str(sen))
140+
print('Specificity: ' + str(spe))
141+
# print('Precision: ' + str(pre))'''
142+
143+
# Storing calculations in a file
144+
conf_mat.write(str(thresh*0.01) + ',' + str(tp) + ',' + str(fp) + ',' + str(tn) + ',' + str(fn) + ',' + str(sen)
145+
+ ',' + str(spe) + ',' + str(acc) + '\n')
146+
147+
73148
#######################################################################################################################
74149
# Main Call Func. >>>
75150
if __name__ == "__main__":

melNET_test.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# importing scripts
22
import melNET_test_src
3+
import melNET_confusion_mat
34

45
# importing libs
56
import os
@@ -66,8 +67,8 @@ def quit_button_applied():
6667
for folder in folder_names:
6768
ground_truth = re.sub(r'[^a-z0-9]+', ' ', str(folder).lower())
6869
class_path = test_path + "/" + folder
69-
melNET_test_src.main(class_path, weight_path, save_performance, show_detection,
70-
ground_truth, test_path)
70+
classes = melNET_test_src.main(class_path, weight_path, save_performance, show_detection,
71+
ground_truth, test_path)
7172
if save_performance:
7273
merge_csv(test_path)
7374
else:
@@ -84,7 +85,8 @@ def quit_button_applied():
8485
for folder in folder_names:
8586
ground_truth = re.sub(r'[^a-z0-9]+', ' ', str(folder).lower())
8687
class_path = test_path + "/" + folder
87-
melNET_test_src.main(class_path, weight_path, save_performance, show_detection, ground_truth, test_path)
88+
classes = melNET_test_src.main(class_path, weight_path, save_performance, show_detection,
89+
ground_truth, test_path)
8890
if save_performance:
8991
merge_csv(test_path)
9092
else:
@@ -98,6 +100,9 @@ def quit_button_applied():
98100
else:
99101
print("melNET has not been TESTED!")
100102
exit()
103+
104+
if save_performance and len(classes) == 2:
105+
melNET_confusion_mat.main(classes)
101106
#######################################################################################################################
102107

103108

melNET_test_src.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def main(test_path, weight_path, save_performance, show_detection, ground_truth,
6565
f.write(',')
6666
f.write('Prediction, Truth, Match\n')
6767

68-
6968
with tf.Session() as sess:
7069
# for each file in the test images directory . . .
7170
for fileName in os.listdir(TEST_IMAGES_DIR):
@@ -124,8 +123,8 @@ def main(test_path, weight_path, save_performance, show_detection, ground_truth,
124123
strClassification = classifications[prediction]
125124

126125
# if the classification (obtained from the directory name) ends with the letter "s", remove the "s" to change from plural to singular
127-
if strClassification.endswith("s"):
128-
strClassification = strClassification[:-1]
126+
#if strClassification.endswith("s"):
127+
#strClassification = strClassification[:-1]
129128
# end if
130129

131130
# get confidence, then get confidence rounded to 2 places after the decimal
@@ -145,7 +144,7 @@ def main(test_path, weight_path, save_performance, show_detection, ground_truth,
145144
if show_detection:
146145
cv2.imshow(fileName, openCVImage)
147146
# pause
148-
cv2.waitKey(1000)
147+
cv2.waitKey(0)
149148
# after a key is pressed, close the current window to prep for the next time around
150149
cv2.destroyAllWindows()
151150
# mark that we've show the most likely prediction at this point so the additional information in
@@ -161,6 +160,8 @@ def main(test_path, weight_path, save_performance, show_detection, ground_truth,
161160
if save_performance:
162161
f.close()
163162

163+
return classifications
164+
164165
# write the graph to file so we can view with TensorBoard
165166
'''tfFileWriter = tf.summary.FileWriter(os.getcwd())
166167
tfFileWriter.add_graph(sess.graph)

0 commit comments

Comments
 (0)