Commit 3c29b67

Add torchrun bonus code (#524)

1 parent f90bec7

4 files changed: +268 −8 lines
appendix-A/01_main-chapter-code/DDP-script-torchrun.py

Lines changed: 220 additions & 0 deletions (new file)
```python
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

# Appendix A: Introduction to PyTorch (Part 3)

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# NEW imports:
import os
import platform
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group


# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
def ddp_setup(rank, world_size):
    """
    Arguments:
        rank: a unique process ID
        world_size: total number of processes in the group
    """
    # Only set MASTER_ADDR and MASTER_PORT if not already defined by torchrun
    if "MASTER_ADDR" not in os.environ:
        os.environ["MASTER_ADDR"] = "localhost"
    if "MASTER_PORT" not in os.environ:
        os.environ["MASTER_PORT"] = "12345"

    # initialize process group
    if platform.system() == "Windows":
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"
        # Windows users may have to use "gloo" instead of "nccl" as backend
        # gloo: Facebook Collective Communication Library
        init_process_group(backend="gloo", rank=rank, world_size=world_size)
    else:
        # nccl: NVIDIA Collective Communication Library
        init_process_group(backend="nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)


class ToyDataset(Dataset):
    def __init__(self, X, y):
        self.features = X
        self.labels = y

    def __getitem__(self, index):
        one_x = self.features[index]
        one_y = self.labels[index]
        return one_x, one_y

    def __len__(self):
        return self.labels.shape[0]


class NeuralNetwork(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()

        self.layers = torch.nn.Sequential(
            # 1st hidden layer
            torch.nn.Linear(num_inputs, 30),
            torch.nn.ReLU(),

            # 2nd hidden layer
            torch.nn.Linear(30, 20),
            torch.nn.ReLU(),

            # output layer
            torch.nn.Linear(20, num_outputs),
        )

    def forward(self, x):
        logits = self.layers(x)
        return logits


def prepare_dataset():
    X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5]
    ])
    y_train = torch.tensor([0, 0, 0, 1, 1])

    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
    # factor = 4
    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
    # y_train = y_train.repeat(factor)
    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
    # y_test = y_test.repeat(factor)

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    train_loader = DataLoader(
        dataset=train_ds,
        batch_size=2,
        shuffle=False,  # NEW: False because of DistributedSampler below
        pin_memory=True,
        drop_last=True,
        # NEW: chunk batches across GPUs without overlapping samples:
        sampler=DistributedSampler(train_ds)  # NEW
    )
    test_loader = DataLoader(
        dataset=test_ds,
        batch_size=2,
        shuffle=False,
    )
    return train_loader, test_loader


# NEW: wrapper
def main(rank, world_size, num_epochs):

    ddp_setup(rank, world_size)  # NEW: initialize process groups

    train_loader, test_loader = prepare_dataset()
    model = NeuralNetwork(num_inputs=2, num_outputs=2)
    model.to(rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

    model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
    # the core model is now accessible as model.module

    for epoch in range(num_epochs):
        # NEW: Set sampler to ensure each epoch has a different shuffle order
        train_loader.sampler.set_epoch(epoch)

        model.train()
        for features, labels in train_loader:

            features, labels = features.to(rank), labels.to(rank)  # NEW: use rank
            logits = model(features)
            loss = F.cross_entropy(logits, labels)  # Loss function

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # LOGGING
            print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                  f" | Batchsize {labels.shape[0]:03d}"
                  f" | Train/Val Loss: {loss:.2f}")

    model.eval()

    try:
        train_acc = compute_accuracy(model, train_loader, device=rank)
        print(f"[GPU{rank}] Training accuracy", train_acc)
        test_acc = compute_accuracy(model, test_loader, device=rank)
        print(f"[GPU{rank}] Test accuracy", test_acc)

    ####################################################
    # NEW (not in the book):
    except ZeroDivisionError as e:
        raise ZeroDivisionError(
            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
            "torchrun --nproc_per_node=2 DDP-script-torchrun.py\n"
            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
        )
    ####################################################

    destroy_process_group()  # NEW: cleanly exit distributed mode


def compute_accuracy(model, dataloader, device):
    model = model.eval()
    correct = 0.0
    total_examples = 0

    for idx, (features, labels) in enumerate(dataloader):
        features, labels = features.to(device), labels.to(device)

        with torch.no_grad():
            logits = model(features)
        predictions = torch.argmax(logits, dim=1)
        compare = labels == predictions
        correct += torch.sum(compare)
        total_examples += len(compare)
    return (correct / total_examples).item()


if __name__ == "__main__":
    # NEW: Use environment variables set by torchrun if available, otherwise default to single-process.
    if "WORLD_SIZE" in os.environ:
        world_size = int(os.environ["WORLD_SIZE"])
    else:
        world_size = 1

    if "LOCAL_RANK" in os.environ:
        rank = int(os.environ["LOCAL_RANK"])
    elif "RANK" in os.environ:
        rank = int(os.environ["RANK"])
    else:
        rank = 0

    # Only print on rank 0 to avoid duplicate prints from each GPU process
    if rank == 0:
        print("PyTorch version:", torch.__version__)
        print("CUDA available:", torch.cuda.is_available())
        print("Number of GPUs available:", torch.cuda.device_count())

    torch.manual_seed(123)
    num_epochs = 3
    main(rank, world_size, num_epochs)
```
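The `__main__` block above depends on the environment variables that `torchrun` exports to every worker it launches. As a quick illustration (a hypothetical probe script, not part of this commit), the sketch below prints the variables a worker sees when started with `torchrun --nproc_per_node=2 DDP-script-torchrun.py`:

```python
# probe_env.py -- hypothetical helper, not part of this commit.
# torchrun launches one process per GPU and exports, among others:
#   RANK         global rank of this process (0 or 1 with --nproc_per_node=2)
#   LOCAL_RANK   rank within the local node (same as RANK on a single machine)
#   WORLD_SIZE   total number of processes in the job
#   MASTER_ADDR / MASTER_PORT  rendezvous endpoint used by init_process_group
import os

for var in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
    print(var, "=", os.environ.get(var, "<not set>"))
```

Run without `torchrun`, none of these variables exist, which is exactly why the script falls back to `world_size = 1` and `rank = 0` for plain single-process execution.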

appendix-A/01_main-chapter-code/DDP-script.py

Lines changed: 25 additions & 8 deletions

```diff
@@ -31,12 +31,11 @@ def ddp_setup(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     # any free port on the machine
     os.environ["MASTER_PORT"] = "12345"
-    if platform.system() == "Windows":
-        # Disable libuv because PyTorch for Windows isn't built with support
-        os.environ["USE_LIBUV"] = "0"
 
     # initialize process group
     if platform.system() == "Windows":
+        # Disable libuv because PyTorch for Windows isn't built with support
+        os.environ["USE_LIBUV"] = "0"
         # Windows users may have to use "gloo" instead of "nccl" as backend
         # gloo: Facebook Collective Communication Library
         init_process_group(backend="gloo", rank=rank, world_size=world_size)
@@ -99,6 +98,13 @@ def prepare_dataset():
     ])
     y_test = torch.tensor([0, 1])
 
+    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
+    # factor = 4
+    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
+    # y_train = y_train.repeat(factor)
+    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
+    # y_test = y_test.repeat(factor)
+
     train_ds = ToyDataset(X_train, y_train)
     test_ds = ToyDataset(X_test, y_test)
 
@@ -153,10 +159,22 @@ def main(rank, world_size, num_epochs):
               f" | Train/Val Loss: {loss:.2f}")
 
     model.eval()
-    train_acc = compute_accuracy(model, train_loader, device=rank)
-    print(f"[GPU{rank}] Training accuracy", train_acc)
-    test_acc = compute_accuracy(model, test_loader, device=rank)
-    print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    try:
+        train_acc = compute_accuracy(model, train_loader, device=rank)
+        print(f"[GPU{rank}] Training accuracy", train_acc)
+        test_acc = compute_accuracy(model, test_loader, device=rank)
+        print(f"[GPU{rank}] Test accuracy", test_acc)
+
+    ####################################################
+    # NEW (not in the book):
+    except ZeroDivisionError as e:
+        raise ZeroDivisionError(
+            f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
+            "CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py\n"
+            f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment the code on lines 103 to 107."
+        )
+    ####################################################
 
     destroy_process_group()  # NEW: cleanly exit distributed mode
 
@@ -184,7 +202,6 @@ def compute_accuracy(model, dataloader, device):
     print("PyTorch version:", torch.__version__)
     print("CUDA available:", torch.cuda.is_available())
     print("Number of GPUs available:", torch.cuda.device_count())
-
    torch.manual_seed(123)
 
    # NEW: spawn new processes
```
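A note on the `ZeroDivisionError` this hunk guards against: with its default settings, `DistributedSampler` pads the 5-sample training set so every rank receives `ceil(5 / world_size)` samples, and `DataLoader(..., batch_size=2, drop_last=True)` then discards each rank's trailing incomplete batch. On enough GPUs a rank therefore receives zero batches, `total_examples` stays 0 in `compute_accuracy`, and `correct / total_examples` divides by zero. A rough back-of-the-envelope sketch (an illustration, not code from this commit):

```python
import math

def train_batches_per_rank(num_samples, world_size, batch_size=2, drop_last=True):
    # DistributedSampler (default settings) pads the dataset so every rank
    # receives the same number of samples: ceil(num_samples / world_size)
    per_rank = math.ceil(num_samples / world_size)
    # DataLoader with drop_last=True discards the trailing incomplete batch
    return per_rank // batch_size if drop_last else math.ceil(per_rank / batch_size)

for world_size in (2, 4, 8):
    print(f"{world_size} GPUs -> {train_batches_per_rank(5, world_size)} train batch(es) per rank")
# 2 GPUs -> 1, 4 GPUs -> 1, 8 GPUs -> 0: with 8 processes each rank holds a single
# sample, its only (incomplete) batch is dropped, and the accuracy loop never runs.
```

Enlarging the dataset with the commented-out `factor = 4` block above keeps every rank supplied with at least one full batch.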
appendix-A/01_main-chapter-code/README.md

Lines changed: 12 additions & 0 deletions (new file)

# Appendix A: Introduction to PyTorch

### Main Chapter Code

- [code-part1.ipynb](code-part1.ipynb) contains all the section A.1 to A.8 code as it appears in the chapter
- [code-part2.ipynb](code-part2.ipynb) contains all the section A.9 GPU code as it appears in the chapter
- [DDP-script.py](DDP-script.py) contains the script to demonstrate multi-GPU usage (note that Jupyter Notebooks only support single GPUs, so this is a script, not a notebook). You can run it as `python DDP-script.py`. If your machine has more than 2 GPUs, run it as `CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py` (see the sketch after this list).
- [exercise-solutions.ipynb](exercise-solutions.ipynb) contains the exercise solutions for this chapter
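The `CUDA_VISIBLE_DEVICES` tip works because the variable controls which physical GPUs the process can see; a minimal sketch of its effect (an illustration, not part of this commit):

```python
# Restrict visibility to the first two GPUs. The variable must be set before
# CUDA is initialized, i.e. before the first CUDA call in the process.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch

# Even on a machine with more GPUs this now reports 2, and the visible devices
# are renumbered cuda:0 and cuda:1 -- matching the ranks DDP-script.py uses.
print(torch.cuda.device_count())
```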
### Optional Code

- [DDP-script-torchrun.py](DDP-script-torchrun.py) is an optional version of the `DDP-script.py` script that runs via the PyTorch `torchrun` command instead of spawning and managing multiple processes ourselves via `multiprocessing.spawn`. The `torchrun` command has the advantage of automatically handling distributed initialization, including multi-node coordination, which slightly simplifies the setup process. You can run this script via `torchrun --nproc_per_node=2 DDP-script-torchrun.py` (see the sketch below).
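For comparison, a minimal sketch (simplified from the two scripts, with `main()` as a stand-in for the training entry point they define) of the two launch styles described above:

```python
import os
import torch
import torch.multiprocessing as mp

def main(rank, world_size, num_epochs):
    """Stand-in for the training entry point defined in the scripts above."""
    ...

# Style 1 (DDP-script.py): the parent process spawns one worker per GPU;
# mp.spawn passes each worker's rank as the first argument to main().
def launch_with_spawn(num_epochs=3):
    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)

# Style 2 (DDP-script-torchrun.py): torchrun has already started this process
# (one per GPU) and exported RANK/LOCAL_RANK/WORLD_SIZE, so we only read them.
def launch_with_torchrun(num_epochs=3):
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    main(rank, world_size, num_epochs)
```

The trade-off is the one the bullet states: `mp.spawn` keeps everything in a single Python entry point, while `torchrun` handles process creation and rendezvous (including multi-node jobs) for you.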

appendix-A/README.md

Lines changed: 11 additions & 0 deletions (new file)

# Appendix A: Introduction to PyTorch

## Main Chapter Code

- [01_main-chapter-code](01_main-chapter-code) contains the main chapter code

## Bonus Materials

- [02_setup-recommendations](02_setup-recommendations) contains Python installation and setup recommendations.
