import os
import os.path as op
import time

from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset


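# Tokenizer helper for datasets.map below; it relies on the global `tokenizer`
# created in the __main__ block. With padding=True, examples are padded to the
# longest sequence in the mapped batch.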
def tokenize_text(batch):
    return tokenizer(batch["text"], truncation=True, padding=True)


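# Plot per-epoch loss and accuracy curves from a metrics.csv file, presumably
# written by a CSV-based logger such as Lightning's CSVLogger; rows are
# averaged per epoch before plotting. Defined here but not called below.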
def plot_logs(log_dir):
    metrics = pd.read_csv(op.join(log_dir, "metrics.csv"))

    aggreg_metrics = []
    agg_col = "epoch"
    for i, dfg in metrics.groupby(agg_col):
        agg = dict(dfg.mean())
        agg[agg_col] = i
        aggreg_metrics.append(agg)

    df_metrics = pd.DataFrame(aggreg_metrics)
    df_metrics[["train_loss", "val_loss"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
    )
    plt.savefig(op.join(log_dir, "loss.pdf"))

    df_metrics[["train_acc", "val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Accuracy"
    )
    plt.savefig(op.join(log_dir, "acc.pdf"))


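# Plain PyTorch training loop: one forward/backward pass and optimizer step
# per batch, a running training accuracy tracked with torchmetrics, and a
# full validation pass at the end of every epoch.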
def train(num_epochs, model, optimizer, train_loader, val_loader, device):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            optimizer.zero_grad()
            outputs["loss"].backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(
                    f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                )

            model.eval()
            with torch.no_grad():
                predicted_labels = torch.argmax(outputs["logits"], 1)
                train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

            print(
                f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
            )


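# Script entry point: prepare the IMDB data, fine-tune distilbert-base-uncased
# for a few epochs with plain PyTorch, then evaluate on the test split.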
if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )
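
    # load_dataset("csv", ...) returns a DatasetDict whose "train",
    # "validation", and "test" splits correspond to the three CSV files.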

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
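
    # Notes: with batched=True and batch_size=None, map() passes each split to
    # tokenize_text as a single batch, so padding=True pads every example to
    # the longest sequence in that split. set_format makes the dataset return
    # PyTorch tensors, and the TOKENIZERS_PARALLELISM override silences the
    # tokenizers fork warning once the DataLoader workers start.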

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )
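
    # drop_last=True keeps all batches the same size; for the val and test
    # loaders this discards up to batch_size - 1 trailing examples, so the
    # reported accuracies come from a slightly truncated split.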

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
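
    # The pretrained checkpoint has no sequence-classification head, so
    # transformers initializes one with random weights (and warns about it);
    # fine-tuning below trains the new head together with the backbone.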

    #########################################
    ### 5 Finetuning
    #########################################

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

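    # Final evaluation on the held-out test split, mirroring the validation
    # loop inside train().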
    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")