PyTorch Model Export for Deployment

Published Feb 14, 2025

Model export in PyTorch involves converting trained models into formats that can be loaded and executed in production systems. Common deployment targets include mobile devices, web servers, and edge devices, using tools like TorchScript, ONNX, or PyTorch’s native serialization.


Syntax

Core Export Methods

  1. Native PyTorch (.pt/.pth):

    torch.save(model.state_dict(), "model.pth")  # Saves model weights
    torch.save(model, "full_model.pt")  # Saves entire model (weights + architecture)
    
  2. TorchScript (for optimized inference):

    scripted_model = torch.jit.script(model)    # Converts model to TorchScript
    scripted_model.save("model.pt")
    
  3. ONNX Export (for cross-framework compatibility):

    torch.onnx.export(
        model,
        dummy_input,
        "model.onnx",
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}}
    )
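The state_dict workflow from method 1 is a round trip: save the weights, then recreate the architecture and restore them. A minimal sketch, using a small stand-in model (any `nn.Module` follows the same pattern):

```python
import torch
import torch.nn as nn

# Small stand-in model for illustration; any nn.Module works the same way
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "model.pth")

# To restore, recreate the architecture first, then load the weights
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load("model.pth"))
restored.eval()  # Switch to inference mode before deployment
```

Note that loading a state_dict requires the model class (or an identical architecture) to be available; only the parameters are stored in the file.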
    

Example

Exporting with TorchScript

Convert a trained model to TorchScript for deployment in C++/mobile:

import torch
import torchvision
# Load a pretrained model
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.eval()
# Convert to TorchScript via scripting
scripted_model = torch.jit.script(model)
scripted_model.save("resnet18_scripted.pt")
# Convert via tracing (alternative method)
dummy_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("resnet18_traced.pt")
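A saved TorchScript module can be loaded back with `torch.jit.load` (or `torch::jit::load` in C++) and run without the original Python class definition. A minimal sketch with a small stand-in network, since the exported ResNet-18 above would load the same way:

```python
import torch
import torch.nn as nn

# Stand-in network; the exported ResNet-18 loads identically
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()
torch.jit.script(model).save("tiny_scripted.pt")

# The original class definition is not needed to load a TorchScript file
loaded = torch.jit.load("tiny_scripted.pt")
with torch.no_grad():
    output = loaded(torch.rand(1, 4))
print(output.shape)  # torch.Size([1, 2])
```

This independence from the Python source is what makes TorchScript suitable for C++ and mobile runtimes.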

Exporting to ONNX

Convert a model to ONNX format for use with TensorRT, OpenVINO, etc.:

# Export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}}
)

Key Considerations

  1. State Dict vs. Full Model

    • torch.save(model.state_dict()) is the preferred, portable approach and works well for resuming training.
    • Full-model serialization (torch.save(model)) pickles the class's import path and may break across PyTorch versions or code refactors.
  2. Device Compatibility

    • Export models on the device type (CPU/GPU) they will run on, or remap devices at load time with torch.load(..., map_location=...).
  3. Custom Layers

    • Register custom layers with torch.jit.script or define them in ONNX-compatible ways.
  4. Optimization Tools

    • Use torch.utils.mobile_optimizer for mobile deployment or ONNX Runtime for inference acceleration.
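For the last point, mobile-specific optimization can be sketched as follows, using a small stand-in model (the `.ptl` filename is a convention for the PyTorch Lite interpreter):

```python
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

# Stand-in model; a real deployment would use the trained network
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()
scripted = torch.jit.script(model)

# Apply mobile-specific graph optimizations (e.g., operator fusion)
optimized = optimize_for_mobile(scripted)
optimized._save_for_lite_interpreter("model.ptl")  # for the Lite runtime
```

The optimized module is still a regular TorchScript module and can be tested on the desktop before shipping to a device.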
