import os
import time
import torch
import random
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd


from PIL import Image
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader

from src.models.AEresnet import AEResnet
from src.models.utils import train, evaluate
from src.models.utils import count_parameters, calculate_accuracy, epoch_time


from torchvision.transforms import Compose, ToTensor, Lambda
from torchvision.transforms import Resize, Normalize
from torchvision.datasets import ImageFolder
from sklearn.metrics import confusion_matrix


# random seed setting
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
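# optionally, torch.backends.cudnn.benchmark = False could also be set here for
# stricter run-to-run determinism on GPU (an extra precaution, not required by
# the rest of this script)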


# data directories setup
train_data_dir = os.path.join(os.curdir, 'data', 'preprocessed', 'classification', 'train')
val_data_dir = os.path.join(os.curdir, 'data', 'preprocessed', 'classification', 'val')
weights_path = os.path.join(os.curdir, 'models', 'Autoencoder-weights', 'resnet34.pt')
# ultimate_weights = os.path.join(os.curdir, 'exp10', 'AEpretrained_resnet34_weights.pt')

# defining the pretrained model
model = AEResnet(res34=True, output_dim=4)

# loading the autoencoder-pretrained weights
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')), strict=False)
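# strict=False presumably ignores checkpoint keys that have no counterpart in
# this model (e.g. decoder parameters from the autoencoder), instead of raising
# a key-mismatch error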

# classification layer definition
INPUT_DIM = model.fc.in_features
OUTPUT_DIM = model.fc.out_features

FC_layer = nn.Linear(INPUT_DIM, OUTPUT_DIM)
model.fc = FC_layer
model.fc.weight.requires_grad = True
model.fc.bias.requires_grad = True
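# note: a freshly created nn.Linear already has requires_grad=True, and the
# encoder parameters remain trainable as well; to train only the new head, the
# backbone would have to be frozen beforehand, e.g. (sketch only, not done here):
#     for param in model.parameters():
#         param.requires_grad = False
#     model.fc = nn.Linear(INPUT_DIM, OUTPUT_DIM)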

print(f'The model has {count_parameters(model):,} trainable parameters')

# hyperparameters and settings
lr = 0.001
batch_size = 1
epochs = 10
weight_decay = 0
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=0.5, step_size=1, verbose=True)
scaler = torch.cuda.amp.GradScaler()
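# with step_size=1 and gamma=0.5 the scheduler halves the learning rate after
# every epoch (presumably stepped inside train()); note that train() is called
# below with scaler=False, so this GradScaler instance is not actually used for
# mixed-precision training here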
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
criterion = criterion.to(device)

# related transformation definition
ROCT_MEANS = [0.20041628, 0.20041628, 0.20041628]
ROCT_STDEVS = [0.20288454, 0.20288454, 0.20288454]


transforms = Compose([
    Resize(224),
    Lambda(lambda x: x.convert('RGB')),
    ToTensor(),
    Normalize(ROCT_MEANS, ROCT_STDEVS)
])
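# Resize(224) with a single int scales the shorter image side to 224 while
# preserving the aspect ratio (it is not a 224x224 crop); the identical mean/std
# value repeated across the three channels suggests the OCT scans are grayscale
# and only replicated to RGB by the Lambda above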


# Data loading and labeling
train_data = ImageFolder(root=train_data_dir,
                         transform=transforms,
                         )

val_data = ImageFolder(root=val_data_dir,
                       transform=transforms,
                       )


# data iterator definition
train_iterator = DataLoader(train_data,
                            shuffle=True,
                            batch_size=batch_size)

val_iterator = DataLoader(val_data,
                          shuffle=True,
                          batch_size=batch_size)
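# shuffling the validation loader has no effect on the computed metrics; it
# could equally be created with shuffle=False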


# Model training loop definition
best_valid_loss = float('inf')
model_name = 'AEpretrained_resnet34_weights.pt'
log = pd.DataFrame(columns=['train_loss', 'train_acc', 'val_loss', 'val_acc'])

for epoch in range(epochs):

    start_time = time.monotonic()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device, scheduler, scaler=False)
    val_loss, val_acc = evaluate(model, val_iterator, criterion, device)

    # keep the weights that achieve the lowest validation loss seen so far
    if val_loss < best_valid_loss:
        best_valid_loss = val_loss
        torch.save(model.state_dict(), model_name)

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # append this epoch's metrics and rewrite the CSV log
    log.loc[len(log.index)] = [train_loss, train_acc, val_loss, val_acc]
    log.to_csv('log.csv')

    # print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s, current time: {time.ctime()}')
    # print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    # print(f'\t Val. Loss: {val_loss:.3f} | Val. Acc: {val_acc*100:.2f}%')
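

# ---------------------------------------------------------------------------
# Optional post-training check: a minimal sketch, assuming AEResnet's forward
# pass returns class logits (as implied by the CrossEntropyLoss training above)
# and that the best checkpoint saved under model_name matches the current
# architecture. It reuses the seaborn / sklearn imports declared at the top to
# plot a confusion matrix for the validation set.
# ---------------------------------------------------------------------------
model.load_state_dict(torch.load(model_name, map_location=device))
model.eval()

all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in val_iterator:
        images = images.to(device)
        logits = model(images)
        all_preds.extend(logits.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d',
            xticklabels=val_data.classes, yticklabels=val_data.classes)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.savefig('confusion_matrix.png')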