Model Export for Deployment
Published Feb 14, 2025
Model export in PyTorch converts a trained model into a format that can be loaded and executed in production systems. Common deployment targets include mobile devices, web servers, and edge devices, using tools such as TorchScript, ONNX, or PyTorch's native serialization.
Syntax
Core Export Methods
Native PyTorch (.pt/.pth):

torch.save(model.state_dict(), "model.pth")  # Saves model weights
torch.save(model, "full_model.pt")           # Saves entire model (weights + architecture)
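A saved state_dict is just a mapping from parameter names to tensors, so loading it back requires instantiating the model class first. A minimal round-trip sketch (TinyNet is a hypothetical stand-in for your own model class):

```python
import torch
import torch.nn as nn

# Hypothetical model, used only for illustration
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyNet()
torch.save(model.state_dict(), "tiny.pth")  # weights only, no architecture

# Loading needs the class definition to rebuild the architecture
restored = TinyNet()
restored.load_state_dict(torch.load("tiny.pth"))
restored.eval()
```

Because only tensors are stored, this format survives refactors of the surrounding code as long as parameter names and shapes stay the same.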
TorchScript (for optimized inference):
scripted_model = torch.jit.script(model)  # Converts model to TorchScript
scripted_model.save("model.pt")
ONNX Export (for cross-framework compatibility):
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},
)
Example
Exporting with TorchScript
Convert a trained model to TorchScript for deployment in C++/mobile:
import torch
import torchvision

# Load a pretrained model
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
model.eval()

# Convert to TorchScript via scripting
scripted_model = torch.jit.script(model)
scripted_model.save("resnet18_scripted.pt")

# Convert via tracing (alternative method)
dummy_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("resnet18_traced.pt")
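A saved TorchScript archive carries the computation graph itself, so it can be reloaded with torch.jit.load and run without the original Python class definition. A self-contained sketch using a small stand-in model (hypothetical, for illustration):

```python
import torch
import torch.nn as nn

# Small stand-in model, scripted and saved
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
model.eval()
scripted = torch.jit.script(model)
scripted.save("demo_scripted.pt")

# torch.jit.load needs no Python class -- the archive contains the graph
loaded = torch.jit.load("demo_scripted.pt")
with torch.no_grad():
    out = loaded(torch.rand(1, 8))
```

This same archive is what the C++ API loads via torch::jit::load for deployment outside Python.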
Exporting to ONNX
Convert a model to ONNX format for use with TensorRT, OpenVINO, etc.:
# Export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}},
)
Key Considerations
State Dict vs. Full Model
- torch.save(model.state_dict()) is preferred for resuming training.
- Full-model serialization (torch.save(model)) may break across PyTorch versions.
Device Compatibility
- Export and load models on the same device type (CPU/GPU), or use the map_location argument of torch.load to remap tensors at load time.
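For instance, a checkpoint whose tensors were saved from GPU memory can be remapped onto the CPU at load time. A self-contained sketch (saves from CPU here, but map_location works the same way for GPU-saved files):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "weights.pth")

# map_location remaps every stored tensor onto the given device at load time
state = torch.load("weights.pth", map_location=torch.device("cpu"))
restored = nn.Linear(4, 2)
restored.load_state_dict(state)
```

Without map_location, loading a GPU-saved checkpoint on a CPU-only machine raises a deserialization error.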
Custom Layers
- Register custom layers with torch.jit.script or define them in ONNX-compatible ways.
Optimization Tools
- Use torch.utils.mobile_optimizer for mobile deployment or ONNX Runtime for inference acceleration.
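A sketch of the mobile path: the model must be scripted first, then passed through optimize_for_mobile, which fuses and freezes operations for mobile runtimes (a small stand-in model is used here for illustration):

```python
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model.eval()
scripted = torch.jit.script(model)

# Applies mobile-oriented passes (op fusion, freezing) to the scripted module
optimized = optimize_for_mobile(scripted)
out = optimized(torch.rand(1, 3, 8, 8))
```

The optimized module can then be saved for the PyTorch Lite interpreter with its _save_for_lite_interpreter method, as shown in the PyTorch mobile tutorials.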