
1 Saving and Loading Tensors

python
import torch

def tensor_save():
    tensor = torch.ones(5, 5)
    torch.save(tensor, "tensor.t")
    tensor_new = torch.load("tensor.t")
    print(tensor_new)

def load_to_gpu():
    # map_location moves the loaded object onto the chosen device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load('mnist.pt', map_location=device)
    print(f"model loaded onto {device}: {model}")
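
torch.save is not limited to a single tensor: any picklable object can be saved, so related tensors can be bundled into one file. A minimal sketch:

python
def save_tensor_dict():
    tensors = {"weight": torch.ones(5, 5), "bias": torch.zeros(5)}
    torch.save(tensors, "tensors.t")  # one file holding a dict of tensors
    loaded = torch.load("tensors.t")
    print(loaded["weight"].shape, loaded["bias"].shape)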

2 Saving the Model State

2.1 Define a Model

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 spatial = 9216 after the convs and pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
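
A quick sanity check on the architecture above; a short sketch counting trainable parameters:

python
model = Net()
n_params = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {n_params}")  # about 1.2 million for this Net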

2.2 Save the Model State

python
def save_para_demo():
    model = Net()
    torch.save(model.state_dict(), "mnist_para.pth")

2.3 Load the Model State

python
def load_para_demo():
    param = torch.load("mnist_para.pth")
    model = Net()
    model.load_state_dict(param)
    input = torch.rand(1, 1, 28, 28)
    output = model(input)   
    print(f"output shape: {output.shape}")

2.4 Think and Try

  • Does the file extension (e.g. .pt vs .pth) affect how the model is saved or loaded?
  • After saving only the model state, can the file be handed directly to someone else? (Assume they do not have the model definition locally.)
  • What exactly does the model state contain? (See the sketch below.)
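
A quick way to explore the last question is to print the state_dict directly; a minimal sketch:

python
def inspect_state_dict():
    model = Net()
    # a state_dict maps parameter/buffer names to tensors; no network
    # structure or code is stored in it
    for name, tensor in model.state_dict().items():
        print(name, tuple(tensor.shape))

Only tensors appear in the output (conv1.weight, fc2.bias, ...), which is why loading a state_dict requires the Net class definition on the receiving side.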

3 Saving and Loading the Whole Model

3.1 Save the Model

python
def save_demo_v1():
    model = Net()
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    torch.save(model, "mnist.pt")  # saves the whole model object (~4.6 MB on disk)

3.2 Load the Model

python
def load_demo_v1():
    model = torch.load("mnist.pt")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load('mnist.pt', map_location=device)  # load the model directly onto the device

    input = torch.rand(1, 1, 28, 28).to(device)  # the input must live on the model's device
    output = model(input)
    print(f"output shape: {output.shape}")

3.3 Think and Try

  • Can a model saved this way be handed directly to someone else? (Assume they do not have the model definition locally.)
  • What exactly does the saved file contain? (See the sketch below.)
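
One way to probe the second question: a file written by torch.save is a zip archive whose data.pkl entry holds the pickled object. For a whole-model file the pickle stores a reference to the Net class rather than its code, so the loading side still needs the class definition. A sketch (assuming mnist.pt from 3.1 exists):

python
import zipfile

with zipfile.ZipFile("mnist.pt") as zf:
    print(zf.namelist())  # e.g. data.pkl plus one entry per tensor payload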

4 Saving and Loading During Training

Question: if training is interrupted, what do we need to save so that we can resume from where we left off instead of restarting the whole run?

4.1 Save the Training State

python
def save_ckpt_demo():
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss = torch.tensor([0.25])
    epoch = 10
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item(),  # load_ckpt_demo below reads this key
        # any other training state can be added here
    }

    torch.save(checkpoint, 'mnist.ckpt')

4.2 Load the Training State

python
def load_ckpt_demo():
    checkpoint = torch.load('mnist.ckpt')  # the file written by save_ckpt_demo
    model = Net()  # an instance of Net must be constructed first
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    print("output shape: ", output.shape)

5 Saving and Loading the Model's Static Graph

Question: how can the model be delivered safely, without handing over its source code?

5.1 Save the Static Graph

python
def save_trace_model():
    model = Net().eval()  # eval mode so dropout is disabled during tracing
    # tracing records the ops run on the example input and produces a new,
    # self-contained module; this traced module is what we save
    traced_model = torch.jit.trace(model, torch.randn(1, 1, 28, 28))
    traced_model.save("traced_model.pt")
    # note: torch.save(traced_model, "mnist_trace.pt") is not supported for
    # ScriptModules; use traced_model.save or torch.jit.save instead

5.2 Load the Static Graph

python
def load_trace_model():
    mm = torch.jit.load("traced_model.pt")
    output = mm(torch.randn(1, 1, 28, 28))
    print("load model successfully!")
    print("output: ", output)

6 Saving in the Universal ONNX Format

6.1 Save an ONNX Static-Graph Model

python
import torch
import torchvision

model = torchvision.models.resnet18().eval()  # eval mode for a deterministic export
input_tensor = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, input_tensor, "model.onnx")
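
In practice it also helps to name the graph inputs and outputs and to mark the batch dimension as dynamic, so the exported model accepts any batch size. A sketch of the same export with these extras (the names "input" and "output" are arbitrary choices):

python
torch.onnx.export(
    model,
    input_tensor,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
)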

6.2 Run the ONNX Model

  • Option 1: run a model file from disk, looking up the input name from the session:
python
import numpy as np
import onnxruntime

def run(model_path, image_path):
    session = onnxruntime.InferenceSession(model_path)
    # image_process is a user-supplied preprocessing helper; wrapping its
    # result in a list adds the batch axis (ONNX Runtime expects numpy arrays)
    input_data = np.asarray([image_process(image_path)], dtype=np.float32)
    input_name_1 = session.get_inputs()[0].name
    outputs = session.run([], {input_name_1: input_data})  # [] returns all outputs
    return outputs
  • Option 2: run an in-memory onnx model object:
python
def onnx_model_infer(input_data_list: list, onnx_model):
    session = onnxruntime.InferenceSession(onnx_model.SerializeToString())
    input_names = [item.name for item in session.get_inputs()]
    assert len(input_data_list) == len(input_names)
    # pair each input array with the corresponding graph input name
    input_dict = dict()
    for i, data in enumerate(input_data_list):
        input_dict[input_names[i]] = data

    outputs = session.run([], input_feed=input_dict)
    print("onnx_model_infer ran successfully!")
    return outputs
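
A possible call site for the helper above, assuming model.onnx from 6.1 (ONNX Runtime expects float32 numpy inputs):

python
import numpy as np
import onnx

onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)  # sanity-check the graph first
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = onnx_model_infer([dummy], onnx_model)
print(outputs[0].shape)  # (1, 1000) for resnet18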

6.3 Shape Inference

python
import onnx

def shape_infer(onnx_model):
    # annotate every value in the graph with its inferred shape
    model = onnx.shape_inference.infer_shapes(onnx_model)
    print("shape_infer ran successfully!")
    return model
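
Typical usage: load the exported model, run shape inference, and save the annotated graph back out (a sketch; model_shaped.onnx is an arbitrary output name):

python
import onnx

onnx_model = onnx.load("model.onnx")
inferred = shape_infer(onnx_model)
onnx.save(inferred, "model_shaped.onnx")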