Skip to content

Commit 29ef3b5

Browse files
author
nauyan
committed
minor changes
1 parent d34e78d commit 29ef3b5

File tree

3 files changed

+21
-15
lines changed

3 files changed

+21
-15
lines changed

config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
checkpoints_dir = "checkpoints"
1515
tensorboard_logs = "run_logs"
1616

17-
encoder = "tu-tf_efficientnet_b5_ns"
17+
# encoder = "tu-tf_efficientnet_b5_ns"
18+
encoder = "tu-tf_efficientnet_b0_ns"
1819
encoder_weights = None
1920

2021
# CoNSep Weights
@@ -24,7 +25,7 @@
2425
inference_weights = "NC-Net_all_metric.pth"
2526

2627
device = "cuda"
27-
batch_size = 16
28+
batch_size = 32
2829
epochs = 1000
2930
learning_rate = 0.01
3031
watershed_threshold = 0.2

src/metrics.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,24 @@ def dice_score(pred, true):
1111
pred = F.softmax(pred.permute(0, 2, 3, 1).contiguous(), dim=-1)
1212
pred = torch.argmax(pred, dim=-1, keepdim=False)
1313

14-
true = np.copy(true.cpu().detach().numpy())
15-
pred = np.copy(pred.cpu().detach().numpy())
14+
# true = np.copy(true.cpu().detach().numpy())
15+
# pred = np.copy(pred.cpu().detach().numpy())
1616

1717
inter = true * pred
1818
denom = true + pred
19-
return (2.0 * np.sum(inter)) / np.sum(denom)
19+
20+
score = (2.0 * torch.sum(inter)) / torch.sum(denom)
21+
return score.cpu().detach().numpy()
2022

2123

2224
def mse_metric(pred, true):
2325

2426
pred = pred[:, 2, :, :]
2527
true = true[:, :, :, 1]
2628

27-
true = np.copy(true.cpu().detach().numpy())
28-
pred = np.copy(pred.cpu().detach().numpy())
29+
# true = np.copy(true.cpu().detach().numpy())
30+
# pred = np.copy(pred.cpu().detach().numpy())
2931

3032
loss = pred - true
3133
loss = (loss * loss).mean()
32-
return loss
34+
return loss.cpu().detach().numpy()

train.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def train(train_dir, test_dir, dataset_name):
4444
train_loader = DataLoader(
4545
train_dataset,
4646
batch_size=batch_size,
47-
shuffle=True,
47+
shuffle=False,
4848
num_workers=multiprocessing.cpu_count(),
4949
pin_memory=True,
5050
persistent_workers=True,
51-
prefetch_factor=8,
51+
prefetch_factor=32,
5252
)
5353
valid_loader = DataLoader(
5454
valid_dataset,
@@ -88,8 +88,9 @@ def train(train_dir, test_dir, dataset_name):
8888
verbose=True,
8989
)
9090

91-
writer_path = "./{}/NC-Net_{}".format(config.tensorboard_logs,
92-
dataset_name)
91+
writer_path = "./{}/NC-Net_{}_{}".format(config.tensorboard_logs,
92+
config.encoder,
93+
dataset_name)
9394
writer = SummaryWriter(writer_path)
9495
min_loss = 9999
9596
max_score = 0
@@ -110,7 +111,8 @@ def train(train_dir, test_dir, dataset_name):
110111
min_loss = valid_logs[loss_fn.__name__]
111112
torch.save(
112113
model.state_dict(),
113-
"./{}/NC-Net_{}.pth".format(config.checkpoints_dir,
114+
"./{}/NC-Net_{}_{}.pth".format(config.checkpoints_dir,
115+
config.encoder,
114116
dataset_name),
115117
)
116118
last_save = i
@@ -120,8 +122,9 @@ def train(train_dir, test_dir, dataset_name):
120122
max_score = valid_logs[metrics[0].__name__]
121123
torch.save(
122124
model.state_dict(),
123-
"./{}/NC-Net_{}_metric.pth".format(config.checkpoints_dir,
124-
dataset_name),
125+
"./{}/NC-Net__{}_{}_metric.pth".format(config.checkpoints_dir,
126+
config.encoder,
127+
dataset_name),
125128
)
126129
last_save = i
127130
print("Model saved Metric!")

0 commit comments

Comments
 (0)