|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | + |
| 4 | +''' 3 DIFFERENT METHODS TO REMEMBER: |
| 5 | + - torch.save(arg, PATH) # can be model, tensor, or dictionary |
| 6 | + - torch.load(PATH) |
| 7 | + - torch.load_state_dict(arg) |
| 8 | +''' |
| 9 | + |
| 10 | +''' 2 DIFFERENT WAYS OF SAVING |
| 11 | +# 1) lazy way: save whole model |
| 12 | +torch.save(model, PATH) |
| 13 | +
|
| 14 | +# model class must be defined somewhere |
| 15 | +model = torch.load(PATH) |
| 16 | +model.eval() |
| 17 | +
|
| 18 | +# 2) recommended way: save only the state_dict |
| 19 | +torch.save(model.state_dict(), PATH) |
| 20 | +
|
| 21 | +# model must be created again with parameters |
| 22 | +model = Model(*args, **kwargs) |
| 23 | +model.load_state_dict(torch.load(PATH)) |
| 24 | +model.eval() |
| 25 | +''' |
| 26 | + |
| 27 | + |
| 28 | +class Model(nn.Module): |
| 29 | + def __init__(self, n_input_features): |
| 30 | + super(Model, self).__init__() |
| 31 | + self.linear = nn.Linear(n_input_features, 1) |
| 32 | + |
| 33 | + def forward(self, x): |
| 34 | + y_pred = torch.sigmoid(self.linear(x)) |
| 35 | + return y_pred |
| 36 | + |
| 37 | +model = Model(n_input_features=6) |
| 38 | +# train your model... |
| 39 | + |
| 40 | +####################save all ###################################### |
| 41 | +for param in model.parameters(): |
| 42 | + print(param) |
| 43 | + |
| 44 | +# save and load entire model |
| 45 | + |
| 46 | +FILE = "model.pth" |
| 47 | +torch.save(model, FILE) |
| 48 | + |
| 49 | +loaded_model = torch.load(FILE) |
| 50 | +loaded_model.eval() |
| 51 | + |
| 52 | +for param in loaded_model.parameters(): |
| 53 | + print(param) |
| 54 | + |
| 55 | + |
| 56 | +############save only state dict ######################### |
| 57 | + |
| 58 | +# save only state dict |
| 59 | +FILE = "model.pth" |
| 60 | +torch.save(model.state_dict(), FILE) |
| 61 | + |
| 62 | +print(model.state_dict()) |
| 63 | +loaded_model = Model(n_input_features=6) |
| 64 | +loaded_model.load_state_dict(torch.load(FILE)) # it takes the loaded dictionary, not the path file itself |
| 65 | +loaded_model.eval() |
| 66 | + |
| 67 | +print(loaded_model.state_dict()) |
| 68 | + |
| 69 | + |
| 70 | +###########load checkpoint##################### |
| 71 | +learning_rate = 0.01 |
| 72 | +optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) |
| 73 | + |
| 74 | +checkpoint = { |
| 75 | +"epoch": 90, |
| 76 | +"model_state": model.state_dict(), |
| 77 | +"optim_state": optimizer.state_dict() |
| 78 | +} |
| 79 | +print(optimizer.state_dict()) |
| 80 | +FILE = "checkpoint.pth" |
| 81 | +torch.save(checkpoint, FILE) |
| 82 | + |
| 83 | +model = Model(n_input_features=6) |
| 84 | +optimizer = optimizer = torch.optim.SGD(model.parameters(), lr=0) |
| 85 | + |
| 86 | +checkpoint = torch.load(FILE) |
| 87 | +model.load_state_dict(checkpoint['model_state']) |
| 88 | +optimizer.load_state_dict(checkpoint['optim_state']) |
| 89 | +epoch = checkpoint['epoch'] |
| 90 | + |
| 91 | +model.eval() |
| 92 | +# - or - |
| 93 | +# model.train() |
| 94 | + |
| 95 | +print(optimizer.state_dict()) |
| 96 | + |
| 97 | +# Remember that you must call model.eval() to set dropout and batch normalization layers |
| 98 | +# to evaluation mode before running inference. Failing to do this will yield |
| 99 | +# inconsistent inference results. If you wish to resuming training, |
| 100 | +# call model.train() to ensure these layers are in training mode. |
| 101 | + |
| 102 | +""" SAVING ON GPU/CPU |
| 103 | +
|
| 104 | +# 1) Save on GPU, Load on CPU |
| 105 | +device = torch.device("cuda") |
| 106 | +model.to(device) |
| 107 | +torch.save(model.state_dict(), PATH) |
| 108 | +
|
| 109 | +device = torch.device('cpu') |
| 110 | +model = Model(*args, **kwargs) |
| 111 | +model.load_state_dict(torch.load(PATH, map_location=device)) |
| 112 | +
|
| 113 | +# 2) Save on GPU, Load on GPU |
| 114 | +device = torch.device("cuda") |
| 115 | +model.to(device) |
| 116 | +torch.save(model.state_dict(), PATH) |
| 117 | +
|
| 118 | +model = Model(*args, **kwargs) |
| 119 | +model.load_state_dict(torch.load(PATH)) |
| 120 | +model.to(device) |
| 121 | +
|
| 122 | +# Note: Be sure to use the .to(torch.device('cuda')) function |
| 123 | +# on all model inputs, too! |
| 124 | +
|
| 125 | +# 3) Save on CPU, Load on GPU |
| 126 | +torch.save(model.state_dict(), PATH) |
| 127 | +
|
| 128 | +device = torch.device("cuda") |
| 129 | +model = Model(*args, **kwargs) |
| 130 | +model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want |
| 131 | +model.to(device) |
| 132 | +
|
| 133 | +# This loads the model to a given GPU device. |
| 134 | +# Next, be sure to call model.to(torch.device('cuda')) to convert the model’s parameter tensors to CUDA tensors |
| 135 | +""" |
| 136 | + |
0 commit comments