From 5f095a6b729e12644797b770d1c4c066bfd19605 Mon Sep 17 00:00:00 2001 From: Dmitry Tarasov Date: Sat, 19 Mar 2022 01:51:24 +0300 Subject: [PATCH 1/2] fix use_cpu option for classification eval cuda trained model --- test_classification.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test_classification.py b/test_classification.py index 79ac09723..3a322e5ef 100644 --- a/test_classification.py +++ b/test_classification.py @@ -38,11 +38,13 @@ def test(model, loader, num_class=40, vote_num=1): class_acc = np.zeros((num_class, 3)) for j, (points, target) in tqdm(enumerate(loader), total=len(loader)): + vote_pool = torch.zeros(target.size()[0], num_class) + if not args.use_cpu: - points, target = points.cuda(), target.cuda() + points, target, vote_pool = points.cuda(), target.cuda(), vote_pool.cuda() points = points.transpose(2, 1) - vote_pool = torch.zeros(target.size()[0], num_class).cuda() + for _ in range(vote_num): pred, _ = classifier(points) @@ -102,7 +104,9 @@ def log_string(str): if not args.use_cpu: classifier = classifier.cuda() - checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') + torch_load_map_location = torch.device('cpu') if args.use_cpu else torch.device('cuda') + torch.load(str(experiment_dir) + '/checkpoints/best_model.pth', map_location=torch_load_map_location) + classifier.load_state_dict(checkpoint['model_state_dict']) with torch.no_grad(): From 1616ff219cc8b606bab0fa1a0e97546d77fc4c77 Mon Sep 17 00:00:00 2001 From: "d.tarasov" Date: Sat, 19 Mar 2022 09:56:12 +0300 Subject: [PATCH 2/2] fix lost checkpoint variable --- test_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_classification.py b/test_classification.py index 3a322e5ef..5e120aed5 100644 --- a/test_classification.py +++ b/test_classification.py @@ -105,7 +105,7 @@ def log_string(str): classifier = classifier.cuda() torch_load_map_location = torch.device('cpu') if args.use_cpu else torch.device('cuda') - torch.load(str(experiment_dir) + '/checkpoints/best_model.pth', map_location=torch_load_map_location) + checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth', map_location=torch_load_map_location) classifier.load_state_dict(checkpoint['model_state_dict'])