
Commit 8ea4d1a

quick sync

1 parent 7bf8bbc

File tree

4 files changed: +531 −0 lines


.gitignore

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+.DS_Store
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

1_pytorch-distilbert.py

Lines changed: 212 additions & 0 deletions

@@ -0,0 +1,212 @@
import os
import os.path as op
import time

from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import torch
from torch.utils.data import DataLoader
import torchmetrics
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from watermark import watermark

from local_dataset_utilities import (
    download_dataset,
    load_dataset_into_to_dataframe,
    partition_dataset,
)
from local_dataset_utilities import IMDBDataset
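local_dataset_utilities is a companion module not shown in this commit view (presumably among the other changed files). A minimal sketch of the IMDBDataset interface the script relies on — the class name and the partition_key argument come from the calls below; everything else is an assumption:

import torch

# Hypothetical sketch, not the repository's actual code: a thin Dataset
# view over one split of the tokenized Hugging Face DatasetDict.
class IMDBDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_dict, partition_key="train"):
        self.partition = dataset_dict[partition_key]

    def __getitem__(self, index):
        # A dict with "input_ids", "attention_mask", and "label" tensors.
        return self.partition[index]

    def __len__(self):
        return self.partition.num_rows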

def tokenize_text(batch):
    # Uses the module-level `tokenizer` created in the __main__ block below.
    return tokenizer(batch["text"], truncation=True, padding=True)

def plot_logs(log_dir):
    metrics = pd.read_csv(op.join(log_dir, "metrics.csv"))

    aggreg_metrics = []
    agg_col = "epoch"
    for i, dfg in metrics.groupby(agg_col):
        agg = dict(dfg.mean())
        agg[agg_col] = i
        aggreg_metrics.append(agg)

    df_metrics = pd.DataFrame(aggreg_metrics)
    df_metrics[["train_loss", "val_loss"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Loss"
    )
    plt.savefig(op.join(log_dir, "loss.pdf"))

    df_metrics[["train_acc", "val_acc"]].plot(
        grid=True, legend=True, xlabel="Epoch", ylabel="Accuracy"
    )
    plt.savefig(op.join(log_dir, "acc.pdf"))
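plot_logs is defined but never called in this file. Given a directory whose metrics.csv contains epoch, train_loss, val_loss, train_acc, and val_acc columns (e.g., written by a CSV logger), usage would look like:

plot_logs("logs/version_0")  # hypothetical path; saves loss.pdf and acc.pdf there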

def train(num_epochs, model, optimizer, train_loader, val_loader, device):
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)

        for batch_idx, batch in enumerate(train_loader):
            model.train()
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)

            ### FORWARD AND BACK PROP
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            optimizer.zero_grad()
            outputs["loss"].backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(
                    f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}"
                )

            model.eval()
            with torch.no_grad():
                predicted_labels = torch.argmax(outputs["logits"], 1)
                train_acc.update(predicted_labels, batch["label"])

        ### MORE LOGGING
        with torch.no_grad():
            model.eval()
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
            for batch in val_loader:
                for s in ["input_ids", "attention_mask", "label"]:
                    batch[s] = batch[s].to(device)
                outputs = model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["label"],
                )
                predicted_labels = torch.argmax(outputs["logits"], 1)
                val_acc.update(predicted_labels, batch["label"])

            print(
                f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%"
            )
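The accuracy bookkeeping above follows torchmetrics' accumulate-then-aggregate pattern: update() folds each batch into internal state, and compute() returns the value over everything seen so far. A minimal standalone illustration with made-up predictions and labels:

import torch
import torchmetrics

acc = torchmetrics.Accuracy(task="multiclass", num_classes=2)
acc.update(torch.tensor([0, 1]), torch.tensor([0, 1]))  # batch 1: both correct
acc.update(torch.tensor([1, 0]), torch.tensor([0, 1]))  # batch 2: both wrong
print(acc.compute())  # tensor(0.5000) -> half of all predictions were correct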

if __name__ == "__main__":
    print(watermark(packages="torch,lightning,transformers", python=True))
    print("Torch CUDA available?", torch.cuda.is_available())
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    torch.manual_seed(123)

    ##########################
    ### 1 Loading the Dataset
    ##########################
    download_dataset()
    df = load_dataset_into_to_dataframe()
    if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")):
        partition_dataset(df)

    imdb_dataset = load_dataset(
        "csv",
        data_files={
            "train": "train.csv",
            "validation": "val.csv",
            "test": "test.csv",
        },
    )

    #########################################
    ### 2 Tokenization and Numericalization
    #########################################

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    print("Tokenizer input max length:", tokenizer.model_max_length, flush=True)
    print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True)

    print("Tokenizing ...", flush=True)
    imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
    del imdb_dataset
    imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
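For intuition, the same tokenizer applied to one made-up review (a standalone sketch, separate from the script):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
enc = tok(["A surprisingly good movie!"], truncation=True, padding=True)
print(list(enc.keys()))                                # ['input_ids', 'attention_mask']
print(tok.convert_ids_to_tokens(enc["input_ids"][0]))  # ['[CLS]', 'a', ..., '[SEP]']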

    #########################################
    ### 3 Set Up DataLoaders
    #########################################

    train_dataset = IMDBDataset(imdb_tokenized, partition_key="train")
    val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation")
    test_dataset = IMDBDataset(imdb_tokenized, partition_key="test")

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )

    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=12,
        num_workers=2,
        drop_last=True,
    )
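Each loader yields a dict of stacked tensors. A quick shape sanity check (illustrative only; the sequence length depends on the tokenized data):

batch = next(iter(train_loader))
print(batch["input_ids"].shape)       # torch.Size([12, seq_len])
print(batch["attention_mask"].shape)  # same shape as input_ids
print(batch["label"].shape)           # torch.Size([12])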

    #########################################
    ### 4 Initializing the Model
    #########################################

    model = AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", num_labels=2
    )

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
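The commit uses plain Adam; AdamW with decoupled weight decay is the more common default for transformer finetuning and would be a drop-in alternative (a suggestion, not what the commit does):

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)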

    #########################################
    ### 5 Finetuning
    #########################################

    start = time.time()
    train(
        num_epochs=3,
        model=model,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
    )

    end = time.time()
    elapsed = end - start
    print(f"Time elapsed {elapsed/60:.2f} min")

    with torch.no_grad():
        model.eval()
        test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(device)
        for batch in test_loader:
            for s in ["input_ids", "attention_mask", "label"]:
                batch[s] = batch[s].to(device)
            outputs = model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["label"],
            )
            predicted_labels = torch.argmax(outputs["logits"], 1)
            test_acc.update(predicted_labels, batch["label"])

    print(f"Test accuracy {test_acc.compute()*100:.2f}%")

0 commit comments
