Skip to content

Commit aa2573f

Browse files
committed
added 17_save_load tutotial
1 parent 6bef936 commit aa2573f

File tree

1 file changed

+136
-0
lines changed

1 file changed

+136
-0
lines changed

17_save_load.py

+136
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
''' 3 DIFFERENT METHODS TO REMEMBER:
5+
- torch.save(arg, PATH) # can be model, tensor, or dictionary
6+
- torch.load(PATH)
7+
- torch.load_state_dict(arg)
8+
'''
9+
10+
''' 2 DIFFERENT WAYS OF SAVING
11+
# 1) lazy way: save whole model
12+
torch.save(model, PATH)
13+
14+
# model class must be defined somewhere
15+
model = torch.load(PATH)
16+
model.eval()
17+
18+
# 2) recommended way: save only the state_dict
19+
torch.save(model.state_dict(), PATH)
20+
21+
# model must be created again with parameters
22+
model = Model(*args, **kwargs)
23+
model.load_state_dict(torch.load(PATH))
24+
model.eval()
25+
'''
26+
27+
28+
class Model(nn.Module):
29+
def __init__(self, n_input_features):
30+
super(Model, self).__init__()
31+
self.linear = nn.Linear(n_input_features, 1)
32+
33+
def forward(self, x):
34+
y_pred = torch.sigmoid(self.linear(x))
35+
return y_pred
36+
37+
model = Model(n_input_features=6)
38+
# train your model...
39+
40+
####################save all ######################################
41+
for param in model.parameters():
42+
print(param)
43+
44+
# save and load entire model
45+
46+
FILE = "model.pth"
47+
torch.save(model, FILE)
48+
49+
loaded_model = torch.load(FILE)
50+
loaded_model.eval()
51+
52+
for param in loaded_model.parameters():
53+
print(param)
54+
55+
56+
############save only state dict #########################
57+
58+
# save only state dict
59+
FILE = "model.pth"
60+
torch.save(model.state_dict(), FILE)
61+
62+
print(model.state_dict())
63+
loaded_model = Model(n_input_features=6)
64+
loaded_model.load_state_dict(torch.load(FILE)) # it takes the loaded dictionary, not the path file itself
65+
loaded_model.eval()
66+
67+
print(loaded_model.state_dict())
68+
69+
70+
###########load checkpoint#####################
71+
learning_rate = 0.01
72+
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
73+
74+
checkpoint = {
75+
"epoch": 90,
76+
"model_state": model.state_dict(),
77+
"optim_state": optimizer.state_dict()
78+
}
79+
print(optimizer.state_dict())
80+
FILE = "checkpoint.pth"
81+
torch.save(checkpoint, FILE)
82+
83+
model = Model(n_input_features=6)
84+
optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=0)
85+
86+
checkpoint = torch.load(FILE)
87+
model.load_state_dict(checkpoint['model_state'])
88+
optimizer.load_state_dict(checkpoint['optim_state'])
89+
epoch = checkpoint['epoch']
90+
91+
model.eval()
92+
# - or -
93+
# model.train()
94+
95+
print(optimizer.state_dict())
96+
97+
# Remember that you must call model.eval() to set dropout and batch normalization layers
98+
# to evaluation mode before running inference. Failing to do this will yield
99+
# inconsistent inference results. If you wish to resuming training,
100+
# call model.train() to ensure these layers are in training mode.
101+
102+
""" SAVING ON GPU/CPU
103+
104+
# 1) Save on GPU, Load on CPU
105+
device = torch.device("cuda")
106+
model.to(device)
107+
torch.save(model.state_dict(), PATH)
108+
109+
device = torch.device('cpu')
110+
model = Model(*args, **kwargs)
111+
model.load_state_dict(torch.load(PATH, map_location=device))
112+
113+
# 2) Save on GPU, Load on GPU
114+
device = torch.device("cuda")
115+
model.to(device)
116+
torch.save(model.state_dict(), PATH)
117+
118+
model = Model(*args, **kwargs)
119+
model.load_state_dict(torch.load(PATH))
120+
model.to(device)
121+
122+
# Note: Be sure to use the .to(torch.device('cuda')) function
123+
# on all model inputs, too!
124+
125+
# 3) Save on CPU, Load on GPU
126+
torch.save(model.state_dict(), PATH)
127+
128+
device = torch.device("cuda")
129+
model = Model(*args, **kwargs)
130+
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
131+
model.to(device)
132+
133+
# This loads the model to a given GPU device.
134+
# Next, be sure to call model.to(torch.device('cuda')) to convert the model’s parameter tensors to CUDA tensors
135+
"""
136+

0 commit comments

Comments
 (0)