# train.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from datetime import datetime
import argparse
import os
from data import get_loader_splits
from utils import get_trainable, validate, lr_schedule
from models import LeNet, WideResNet22, Hindus
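# data, utils, and models are project-local modules; validate is assumed to
# return (mean loss, accuracy in %) computed over a loader.
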
torch.manual_seed(42)
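# Note: manual_seed fixes torch's CPU and CUDA RNG seeds, but cuDNN's
# nondeterministic kernels can still cause run-to-run variation.
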
parser = argparse.ArgumentParser()
parser.add_argument("-n", "--name", type=str, required=True,
                    help="Name of the run. Used for creating directories "
                         "with TensorBoard files and states of the network.")
parser.add_argument("--no_augmentation", action="store_true", help="Disable data augmentation for training.")
parser.add_argument("--augment_valid", action="store_true", help="Use data augmentation for validation.")
parser.add_argument("-e", "--epochs", type=int, default=70, help="Number of epochs.")
parser.add_argument("-r", "--regularization", type=float, default=0.0002, help="Value of the L2 regularization (weight decay) parameter.")
parser.add_argument("-bs", "--batch_size", type=int, default=64, help="Size of a training batch.")
parser.add_argument("-vbs", "--valid_batch_size", type=int, default=128, help="Size of a validation batch.")
args = parser.parse_args()
print(args)
states_dir = "states/{}".format(args.name)
os.makedirs(states_dir, exist_ok=True)
trainloader, validloader = get_loader_splits(batch_size=args.batch_size,
                                             valid_batch_size=args.valid_batch_size,
                                             augment=not args.no_augmentation,
                                             augment_valid=args.augment_valid)
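# get_loader_splits is assumed to return (train, valid) DataLoaders; training
# augmentation is on unless --no_augmentation is passed, and validation
# augmentation is off unless --augment_valid is passed.
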
net = WideResNet22()
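# LeNet and Hindus are imported as alternative architectures and could be
# swapped in here.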
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Moving network parameters to device
net.to(device)
print("Network parameters moved to {}".format(device))
# TensorBoard
writer = SummaryWriter('runs/{}'.format(args.name))
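# Inspect logs with: tensorboard --logdir runs
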
# Optimization
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(get_trainable(net.parameters()), lr=0.01, weight_decay=args.regularization)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_schedule)
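# Note: get_trainable is assumed to filter for parameters with
# requires_grad=True. LambdaLR multiplies the base lr (0.01) by
# lr_schedule(epoch), so lr_schedule should return a multiplicative factor.
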
# Training loop
best_valid_accuracy = float("-inf")  # so the first epoch can also trigger a save
n_epochs = args.epochs
for epoch in range(n_epochs):
    running_loss = 0.0
    n_correct = 0
    n_samples = 0
    print("Epoch {} / {}".format(epoch + 1, n_epochs))
    print("Time", datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    for x, y in tqdm(trainloader, leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = net(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        with torch.no_grad():
            n_correct += (torch.argmax(y_pred, 1) == y).sum().item()
            n_samples += y.size(0)
    step = (epoch + 1) * len(trainloader)
    valid_loss, valid_accuracy = validate(net, criterion, validloader, device)
    train_loss = running_loss / len(trainloader)
    train_accuracy = 100 * n_correct / n_samples
    # Log epoch-level metrics to TensorBoard
    writer.add_scalar('training loss', train_loss, step)
    writer.add_scalar('training accuracy', train_accuracy, step)
    writer.add_scalar('validation loss', valid_loss, step)
    writer.add_scalar('validation accuracy', valid_accuracy, step)
    print("Training loss: {:.4f}, training accuracy: {:.2f}, validation loss: {:.4f}, validation accuracy: {:.2f}"
          .format(train_loss, train_accuracy, valid_loss, valid_accuracy))
    # Save the network state whenever validation accuracy improves
    if valid_accuracy > best_valid_accuracy:
        best_valid_accuracy = valid_accuracy
        print("Saving network state of epoch {} with validation accuracy {:.2f}".format(epoch + 1, valid_accuracy))
        torch.save(net.state_dict(), os.path.join(states_dir, "state"))
    scheduler.step()
# Log the model graph (traced with the last training batch) and flush the writer
writer.add_graph(net, x)
writer.close()
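
# Example invocation (hypothetical run name):
#   python train.py --name wrn22_baseline --epochs 70 --batch_size 64
# Pass --no_augmentation to train without data augmentation.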