
Commit 6f71bf1

delete useless
1 parent bf38ada commit 6f71bf1

File tree

1 file changed: +5, -78 lines changed

tools.py

Lines changed: 5 additions & 78 deletions
@@ -1,19 +1,12 @@
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torchvision.transforms as transforms
 import torchvision.utils as utils
-from dataset import CSL_Isolated
-from models.Conv3D import resnet18, resnet34, resnet50, r2plus1d_18
-import os
 import cv2
-import argparse
 from datetime import datetime
 import numpy as np
-from numpy import savetxt
 import matplotlib.pyplot as plt
-from sklearn.metrics import accuracy_score, confusion_matrix
+from sklearn.metrics import confusion_matrix
 
 
 def get_label_and_pred(model, dataloader, device):
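Note on this hunk: every removed import (DataLoader, transforms, CSL_Isolated, the Conv3D models, os, argparse, savetxt, accuracy_score) belonged to the argparse-driven entry point that the last hunk deletes, so only confusion_matrix survives from sklearn. The diff shows just the signature of get_label_and_pred; the sketch below is a guess at the usual evaluation-loop shape for such a helper, not the body actually in tools.py.

import torch

def get_label_and_pred(model, dataloader, device):
    # Hypothetical body (the commit only shows the signature): collect ground-truth
    # labels and argmax predictions over the whole dataloader for confusion_matrix.
    model.eval()
    all_labels, all_preds = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:  # batch layout is an assumption
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1).cpu()
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())
    return all_labels, all_preds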
@@ -70,7 +63,7 @@ def plot_confusion_matrix(model, dataloader, device, save_path='confmat.png', no
         # print(type(sorted_index[i]))
         print(test_set.label_to_word(int(sorted_index[i])), confmat[sorted_index[i]][sorted_index[i]])
     # Save to csv
-    savetxt('matrix.csv', confmat, delimiter=',')
+    np.savetxt('matrix.csv', confmat, delimiter=',')
 
 
 def visualize_attn(I, c):
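The one-line fix above mirrors the import change in the first hunk: with `from numpy import savetxt` gone, the call now goes through the `np` namespace. A self-contained sketch of that save step (the toy labels here are illustrative, not data from the repository; in tools.py the matrix comes from the model's predictions):

import numpy as np
from sklearn.metrics import confusion_matrix

# Illustrative labels and predictions, not data from the repository.
y_true = [0, 1, 2, 2, 1]
y_pred = [0, 1, 1, 2, 1]

confmat = confusion_matrix(y_true, y_pred)
# Same call form as the patched line: one comma-separated row per class.
np.savetxt('matrix.csv', confmat, delimiter=',')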
@@ -150,74 +143,8 @@ def wer(r, h):
     return float(d[len(r)][len(h)]) / len(r) * 100
 
 
-# Parameters manager
-parser = argparse.ArgumentParser(description='Visualization')
-parser.add_argument('--data_path', default='/home/haodong/Data/CSL_Isolated/color_video_125000',
-                    type=str, help='Path to data')
-parser.add_argument('--label_path', default='/home/haodong/Data/CSL_Isolated/dictionary.txt',
-                    type=str, help='Path to labels')
-parser.add_argument('--model', default='resnet18',
-                    type=str, help='Model to use')
-parser.add_argument('--checkpoint', default='/home/haodong/Data/visualize_models/resnet18.pth',
-                    type=str, help='Path to checkpoint')
-parser.add_argument('--device', default='0',
-                    type=str, help='CUDA visible devices')
-parser.add_argument('--num_classes', default=100,
-                    type=int, help='Number of classes')
-parser.add_argument('--batch_size', default=16,
-                    type=int, help='Batch size')
-parser.add_argument('--sample_size', default=128,
-                    type=int, help='Sample size')
-parser.add_argument('--sample_duration', default=16,
-                    type=int, help='Sample duration')
-parser.add_argument('--confusion_matrix', action='store_true',
-                    help='Draw confusion matrix')
-parser.add_argument('--attention_map', action='store_true',
-                    help='Draw attention map')
-parser.add_argument('--calculate_wer', action='store_true',
-                    help='Calculate Word Error Rate')
-args = parser.parse_args()
-
-# Use specific gpus
-os.environ["CUDA_VISIBLE_DEVICES"] = args.device
-# Device setting
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Hyperparams
-num_classes = args.num_classes
-batch_size = args.batch_size
-sample_size = args.sample_size
-sample_duration = args.sample_duration
-
 if __name__ == '__main__':
-    # Load data
-    transform = transforms.Compose([transforms.Resize([sample_size, sample_size]),
-                                    transforms.ToTensor(),
-                                    transforms.Normalize(mean=[0.5], std=[0.5])])
-    test_set = CSL_Isolated(data_path=args.data_path, label_path=args.label_path, frames=sample_duration,
-                            num_classes=num_classes, train=False, transform=transform)
-    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
-    # Create model
-    if args.model == 'resnet18':
-        model = resnet18(pretrained=True, progress=True, sample_size=sample_size,
-                         sample_duration=sample_duration, attention=args.attention_map, num_classes=num_classes).to(device)
-    # Run the model parallelly
-    if torch.cuda.device_count() > 1:
-        logger.info("Using {} GPUs".format(torch.cuda.device_count()))
-        model = nn.DataParallel(model)
-    # Load model
-    model.load_state_dict(torch.load(args.checkpoint))
-
-    # Draw confusion matrix
-    if args.confusion_matrix:
-        plot_confusion_matrix(model, test_loader, device)
-
-    # Draw attention map
-    if args.attention_map:
-        plot_attention_map(model, test_loader, device)
-
     # Calculate WER
-    if args.calculate_wer:
-        r = [1,2,3,4]
-        h = [1,1,3,5,6]
-        print(wer(r, h))
+    r = [1,2,3,4]
+    h = [1,1,3,5,6]
+    print(wer(r, h))
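After this commit the module's __main__ block is just the hard-coded WER demo; the argparse flags, GPU setup, dataset loading, and model construction are all gone. The body of wer sits outside the hunk, but the surviving return line points at the standard Levenshtein edit distance between the reference r and hypothesis h. A standalone sketch under that assumption (not the exact code in tools.py):

import numpy as np

def wer(r, h):
    # Assumed implementation: dynamic-programming edit distance, reported as a
    # percentage of the reference length, consistent with the return line above.
    d = np.zeros((len(r) + 1, len(h) + 1), dtype=np.int32)
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            if r[i - 1] == h[j - 1]:
                d[i][j] = d[i - 1][j - 1]
            else:
                substitution = d[i - 1][j - 1] + 1
                insertion = d[i][j - 1] + 1
                deletion = d[i - 1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)
    return float(d[len(r)][len(h)]) / len(r) * 100

if __name__ == '__main__':
    r = [1, 2, 3, 4]
    h = [1, 1, 3, 5, 6]
    print(wer(r, h))  # 75.0 under this sketch: 3 edits over 4 reference tokens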
