|
1 | 1 | import torch
|
2 | 2 | import torch.nn.functional as F
|
3 |
| -from torch.utils.data import DataLoader |
4 | 3 | from torch.utils.tensorboard import SummaryWriter
|
5 |
| -import torchvision.transforms as transforms |
6 | 4 | import torchvision.utils as utils
|
7 |
| -from dataset import CSL_Isolated |
8 |
| -from models.Conv3D import resnet18, resnet34, resnet50, r2plus1d_18 |
9 |
| -import os |
10 | 5 | import cv2
|
11 |
| -import argparse |
12 | 6 | from datetime import datetime
|
13 | 7 | import numpy as np
|
14 |
| -from numpy import savetxt |
15 | 8 | import matplotlib.pyplot as plt
|
16 |
| -from sklearn.metrics import accuracy_score, confusion_matrix |
| 9 | +from sklearn.metrics import confusion_matrix |
17 | 10 |
|
18 | 11 |
|
19 | 12 | def get_label_and_pred(model, dataloader, device):
|
@@ -70,7 +63,7 @@ def plot_confusion_matrix(model, dataloader, device, save_path='confmat.png', no
|
70 | 63 | # print(type(sorted_index[i]))
|
71 | 64 | print(test_set.label_to_word(int(sorted_index[i])), confmat[sorted_index[i]][sorted_index[i]])
|
72 | 65 | # Save to csv
|
73 |
| - savetxt('matrix.csv', confmat, delimiter=',') |
| 66 | + np.savetxt('matrix.csv', confmat, delimiter=',') |
74 | 67 |
|
75 | 68 |
|
76 | 69 | def visualize_attn(I, c):
|
@@ -150,74 +143,8 @@ def wer(r, h):
|
150 | 143 | return float(d[len(r)][len(h)]) / len(r) * 100
|
151 | 144 |
|
152 | 145 |
|
153 |
| -# Parameters manager |
154 |
| -parser = argparse.ArgumentParser(description='Visualization') |
155 |
| -parser.add_argument('--data_path', default='/home/haodong/Data/CSL_Isolated/color_video_125000', |
156 |
| - type=str, help='Path to data') |
157 |
| -parser.add_argument('--label_path', default='/home/haodong/Data/CSL_Isolated/dictionary.txt', |
158 |
| - type=str, help='Path to labels') |
159 |
| -parser.add_argument('--model', default='resnet18', |
160 |
| - type=str, help='Model to use') |
161 |
| -parser.add_argument('--checkpoint', default='/home/haodong/Data/visualize_models/resnet18.pth', |
162 |
| - type=str, help='Path to checkpoint') |
163 |
| -parser.add_argument('--device', default='0', |
164 |
| - type=str, help='CUDA visible devices') |
165 |
| -parser.add_argument('--num_classes', default=100, |
166 |
| - type=int, help='Number of classes') |
167 |
| -parser.add_argument('--batch_size', default=16, |
168 |
| - type=int, help='Batch size') |
169 |
| -parser.add_argument('--sample_size', default=128, |
170 |
| - type=int, help='Sample size') |
171 |
| -parser.add_argument('--sample_duration', default=16, |
172 |
| - type=int, help='Sample duration') |
173 |
| -parser.add_argument('--confusion_matrix', action='store_true', |
174 |
| - help='Draw confusion matrix') |
175 |
| -parser.add_argument('--attention_map', action='store_true', |
176 |
| - help='Draw attention map') |
177 |
| -parser.add_argument('--calculate_wer', action='store_true', |
178 |
| - help='Calculate Word Error Rate') |
179 |
| -args = parser.parse_args() |
180 |
| - |
181 |
| -# Use specific gpus |
182 |
| -os.environ["CUDA_VISIBLE_DEVICES"] = args.device |
183 |
| -# Device setting |
184 |
| -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
185 |
| - |
186 |
| -# Hyperparams |
187 |
| -num_classes = args.num_classes |
188 |
| -batch_size = args.batch_size |
189 |
| -sample_size = args.sample_size |
190 |
| -sample_duration = args.sample_duration |
191 |
| - |
192 | 146 | if __name__ == '__main__':
|
193 |
| - # Load data |
194 |
| - transform = transforms.Compose([transforms.Resize([sample_size, sample_size]), |
195 |
| - transforms.ToTensor(), |
196 |
| - transforms.Normalize(mean=[0.5], std=[0.5])]) |
197 |
| - test_set = CSL_Isolated(data_path=args.data_path, label_path=args.label_path, frames=sample_duration, |
198 |
| - num_classes=num_classes, train=False, transform=transform) |
199 |
| - test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) |
200 |
| - # Create model |
201 |
| - if args.model == 'resnet18': |
202 |
| - model = resnet18(pretrained=True, progress=True, sample_size=sample_size, |
203 |
| - sample_duration=sample_duration, attention=args.attention_map, num_classes=num_classes).to(device) |
204 |
| - # Run the model parallelly |
205 |
| - if torch.cuda.device_count() > 1: |
206 |
| - logger.info("Using {} GPUs".format(torch.cuda.device_count())) |
207 |
| - model = nn.DataParallel(model) |
208 |
| - # Load model |
209 |
| - model.load_state_dict(torch.load(args.checkpoint)) |
210 |
| - |
211 |
| - # Draw confusion matrix |
212 |
| - if args.confusion_matrix: |
213 |
| - plot_confusion_matrix(model, test_loader, device) |
214 |
| - |
215 |
| - # Draw attention map |
216 |
| - if args.attention_map: |
217 |
| - plot_attention_map(model, test_loader, device) |
218 |
| - |
219 | 147 | # Calculate WER
|
220 |
| - if args.calculate_wer: |
221 |
| - r = [1,2,3,4] |
222 |
| - h = [1,1,3,5,6] |
223 |
| - print(wer(r, h)) |
| 148 | + r = [1,2,3,4] |
| 149 | + h = [1,1,3,5,6] |
| 150 | + print(wer(r, h)) |
0 commit comments