Back to course overview

Saving And Loading Models - PyTorch Beginner 17

Learn all the basics you need to get started with this deep learning framework! In this part we will learn how to save and load our model. I will show you the functions you have to remember and the different ways of saving a model. I will also show you what you must consider when using a GPU.

Functions you must know: - torch.save() - torch.load() - torch.nn.Module.load_state_dict()

All code from this course can be found on GitHub.

Saving and Loading in PyTorch

import torch
import torch.nn as nn

# 3 DIFFERENT FUNCTIONS TO REMEMBER:
#  - torch.save(obj, PATH)        # obj can be a model, tensor, or dictionary
#  - torch.load(PATH)             # deserializes whatever torch.save wrote
#  - model.load_state_dict(arg)   # nn.Module method; arg is the loaded dict,
#                                 # not the file path
#
# 2 DIFFERENT WAYS OF SAVING
#
# 1) lazy way: save the whole model. This pickles the object, so the model
#    class must be importable when loading. Since PyTorch 2.6, torch.load
#    defaults to weights_only=True, which refuses arbitrary pickled objects;
#    loading a whole model therefore needs weights_only=False. Only do that
#    for files you trust, because unpickling can execute arbitrary code.
#        torch.save(model, PATH)
#        model = torch.load(PATH, weights_only=False)
#        model.eval()
#
# 2) recommended way: save only the state_dict
#        torch.save(model.state_dict(), PATH)
#        # the model must be created again with its constructor arguments
#        model = Model(*args, **kwargs)
#        model.load_state_dict(torch.load(PATH, weights_only=True))
#        model.eval()


class Model(nn.Module):
    """Single linear layer followed by a sigmoid (logistic regression).

    Args:
        n_input_features: size of the last dimension of the input tensor.
    """

    def __init__(self, n_input_features):
        super().__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        # Maps (..., n_input_features) -> (..., 1); values lie in (0, 1).
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred


model = Model(n_input_features=6)
# train your model...

#################### save all ####################
for param in model.parameters():
    print(param)

# save and load the entire model
FILE = "model.pth"
torch.save(model, FILE)

# weights_only=False is required to unpickle a full nn.Module (the default
# became True in PyTorch 2.6). Never do this with untrusted files.
loaded_model = torch.load(FILE, weights_only=False)
loaded_model.eval()

for param in loaded_model.parameters():
    print(param)

#################### save only state dict ####################
FILE = "model.pth"
torch.save(model.state_dict(), FILE)
print(model.state_dict())

loaded_model = Model(n_input_features=6)
# load_state_dict takes the loaded dictionary, not the file path itself.
# A state_dict holds only tensors, so the safe weights_only loader suffices.
loaded_model.load_state_dict(torch.load(FILE, weights_only=True))
loaded_model.eval()
print(loaded_model.state_dict())

#################### save / load checkpoint ####################
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
print(optimizer.state_dict())

FILE = "checkpoint.pth"
torch.save(checkpoint, FILE)

# Recreate the model and optimizer. lr=0 is only a placeholder: loading the
# saved optimizer state below restores lr to the checkpointed value (0.01).
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)

# The checkpoint holds only plain Python values and tensors.
checkpoint = torch.load(FILE, weights_only=True)
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optim_state"])
epoch = checkpoint["epoch"]

model.eval()
# - or -
# model.train()

print(optimizer.state_dict())

# Remember that you must call model.eval() to set dropout and batch
# normalization layers to evaluation mode before running inference. Failing
# to do this will yield inconsistent inference results. If you wish to resume
# training, call model.train() to ensure these layers are in training mode.

# SAVING ON GPU/CPU
#
# 1) Save on GPU, Load on CPU
#        device = torch.device("cuda")
#        model.to(device)
#        torch.save(model.state_dict(), PATH)
#
#        device = torch.device('cpu')
#        model = Model(*args, **kwargs)
#        model.load_state_dict(torch.load(PATH, map_location=device))
#
# 2) Save on GPU, Load on GPU
#        device = torch.device("cuda")
#        model.to(device)
#        torch.save(model.state_dict(), PATH)
#
#        model = Model(*args, **kwargs)
#        model.load_state_dict(torch.load(PATH))
#        model.to(device)
#
#    Note: Be sure to use the .to(torch.device('cuda')) function
#    on all model inputs, too!
#
# 3) Save on CPU, Load on GPU
#        torch.save(model.state_dict(), PATH)
#
#        device = torch.device("cuda")
#        model = Model(*args, **kwargs)
#        # Choose whatever GPU device number you want
#        model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
#        model.to(device)
#
#    This loads the model to a given GPU device. Next, be sure to call
#    model.to(torch.device('cuda')) to convert the model's parameter tensors
#    to CUDA tensors.

FREE VS Code / PyCharm Extensions I Use

🪁 Code faster with Kite, AI-powered autocomplete: Link *

✅ Write cleaner code with Sourcery, instant refactoring suggestions: Link *

* These are affiliate links. Clicking on them will not cost you anything extra; instead, you will support me and my project. Thank you! 🙏

Check out my Courses