Model Export with TorchScript
Published Feb 14, 2025
TorchScript is a PyTorch intermediate representation that allows models to be serialized and optimized for execution outside Python, for example in C++ applications or on resource-constrained mobile and embedded devices. It decouples models from Python dependencies while preserving their logic.
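As a rough illustration of both claims, the sketch below scripts a small module, prints the TorchScript code that PyTorch generated for `forward`, and then saves and reloads the archive with `torch.jit.load`. `TinyNet` is a hypothetical module used only for this example. The reload happens in the same Python process here for convenience; the point is that the archive carries the code and weights, so a loader such as LibTorch's `torch::jit::load` in C++ would not need the Python class.

```python
import os
import tempfile

import torch


class TinyNet(torch.nn.Module):
    """Hypothetical example module: one linear layer followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyNet().eval()
scripted = torch.jit.script(model)

# The intermediate representation of forward(), rendered as TorchScript source
print(scripted.code)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_net.pt")
    scripted.save(path)
    # torch.jit.load rebuilds the module from the archive alone;
    # it does not run TinyNet.__init__ or read the Python class source.
    loaded = torch.jit.load(path)

x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), loaded(x))
print(same)  # True
```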
Syntax
Scripting (for models with dynamic control flow):
```python
scripted_model = torch.jit.script(model)  # Converts model to TorchScript
scripted_model.save("model.pt")
```
Tracing (for static models):
```python
traced_model = torch.jit.trace(model, example_input)  # Records operations via example input
traced_model.save("model.pt")
```
Hybrid Approach (script specific functions or submodules and trace the rest):
```python
@torch.jit.script
def custom_function(x: torch.Tensor) -> torch.Tensor:
    return x * 2
```
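A minimal sketch of the hybrid pattern, under the assumption that only one small helper needs data-dependent control flow: the helper is scripted, and the surrounding module is traced. PyTorch's TorchScript documentation states that control flow inside a script function called from traced code is preserved. `clamp_or_scale` and `Wrapper` are hypothetical names for this example.

```python
import torch


@torch.jit.script
def clamp_or_scale(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: scripting keeps both paths in the graph
    if x.sum() > 0:
        return x * 2
    return x.clamp(min=0.0)


class Wrapper(torch.nn.Module):
    """Hypothetical feed-forward part that is safe to trace."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The call into the scripted function keeps its if/else after tracing
        return clamp_or_scale(x * 3)


# Trace with an input that only exercises the positive branch
traced = torch.jit.trace(Wrapper(), torch.ones(3))

pos = traced(torch.tensor([1.0, 2.0, 3.0]))   # sum > 0, so (x * 3) * 2
neg = traced(torch.tensor([-1.0, -2.0, 0.5]))  # sum < 0, so clamp(x * 3, min=0)
print(pos)  # tensor([ 6., 12., 18.])
print(neg)  # tensor([0.0000, 0.0000, 1.5000])
```

Even though tracing only ever saw the positive branch, the negative input still takes the clamp path, because that branch lives inside scripted code.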
Example
Scripting a Model with Conditional Logic
Export a model that uses `if` statements whose branch depends on the input values (tracing would record only the branch taken by the example input):
```python
import torch


class DynamicModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        else:
            return x - 1


model = DynamicModel()
scripted_model = torch.jit.script(model)  # Handles dynamic control flow
scripted_model.save("dynamic_model.pt")

# Testing the scripted model
x1 = torch.tensor([1.0, -0.5, 3.0])
x2 = torch.tensor([-2.0, -1.5, -0.5])
print(scripted_model(x1))
print(scripted_model(x2))
```
The output will be:
```
tensor([ 2., -1.,  6.])
tensor([-3.0000, -2.5000, -1.5000])
```
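For contrast, here is a short sketch of what happens if the same `DynamicModel` is traced instead of scripted. Tracing runs `forward` once with the example input, so the `if` is evaluated a single time and only the branch taken is recorded. PyTorch typically emits a `TracerWarning` about converting a tensor to a Python boolean at this point, but the export still succeeds.

```python
import torch


class DynamicModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.sum() > 0:
            return x * 2
        else:
            return x - 1


x1 = torch.tensor([1.0, -0.5, 3.0])    # positive sum: takes the x * 2 branch
x2 = torch.tensor([-2.0, -1.5, -0.5])  # negative sum: should take the x - 1 branch

# Tracing with x1 records only the x * 2 path
traced_model = torch.jit.trace(DynamicModel(), x1)
scripted_model = torch.jit.script(DynamicModel())

print(traced_model(x2))    # tensor([-4., -3., -1.]): x2 * 2, the wrong branch
print(scripted_model(x2))  # tensor([-3.0000, -2.5000, -1.5000]): x2 - 1
```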
Tracing a ResNet for Mobile Deployment
Convert a pretrained ResNet using tracing:
```python
import torch
import torchvision

# eval() switches BatchNorm and Dropout to inference behavior before tracing
model = torchvision.models.resnet18(weights="IMAGENET1K_V1").eval()

# Trace with example input
dummy_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, dummy_input)
traced_model.save("resnet18_traced.pt")

# Running inference with traced model
output = traced_model(dummy_input)
print(output.shape)
```
The output will be:
```
torch.Size([1, 1000])
```
Note: This confirms that the traced ResNet model accepts a single 224×224 RGB input and produces 1000 output logits, one per ImageNet class.
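The traced model above is still a standard TorchScript archive. One possible next step for mobile targets, sketched here under assumptions, is PyTorch's `torch.utils.mobile_optimizer.optimize_for_mobile` followed by `_save_for_lite_interpreter`, which writes a file for the PyTorch Mobile lite interpreter (note the leading underscore in that method name). To keep the sketch runnable without downloading weights, it swaps ResNet for a hypothetical `TinyConvNet`; with the ResNet above, `traced_model` would be passed in the same way. It assumes a standard CPU build of PyTorch with XNNPACK enabled.

```python
import os
import tempfile

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile


class TinyConvNet(torch.nn.Module):
    """Hypothetical stand-in for ResNet so the sketch runs offline."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = torch.nn.BatchNorm2d(8)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(8, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.bn(self.conv(x)))
        return self.fc(torch.flatten(self.pool(x), 1))


model = TinyConvNet().eval()  # eval mode is needed so BatchNorm can be folded
example = torch.rand(1, 3, 32, 32)
traced = torch.jit.trace(model, example)

# Applies mobile-oriented graph passes (e.g. conv/batch-norm folding)
optimized = optimize_for_mobile(traced)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_convnet.ptl")
    optimized._save_for_lite_interpreter(path)
    size_bytes = os.path.getsize(path)

with torch.no_grad():
    # Folding can change rounding slightly, so compare with a tolerance
    match = torch.allclose(model(example), optimized(example), atol=1e-4)

print(size_bytes > 0, match)
```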
Key Considerations
Scripting vs Tracing:
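- Use `torch.jit.script` for models with:
  - Loops/conditionals
  - Variable-length inputs
  - Non-tensor data dependencies
- Use `torch.jit.trace` for static models, where the same sequence of operations runs for every input. Tracing needs no type annotations and records a flat graph with no control flow.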
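The loop case can be sketched with a hypothetical `RowSum` module whose loop length depends on the number of input rows. Scripting keeps the loop, while tracing unrolls it for the row count of the example input, so a longer input is silently processed with the wrong number of iterations (PyTorch may emit a `TracerWarning` during tracing).

```python
import torch


class RowSum(torch.nn.Module):
    """Hypothetical module: sums rows with an explicit Python loop."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        total = torch.zeros_like(x[0])
        for i in range(x.size(0)):
            total = total + x[i]
        return total


example = torch.ones(3, 2)  # 3 rows at export time
longer = torch.ones(5, 2)   # 5 rows at inference time

scripted = torch.jit.script(RowSum())
traced = torch.jit.trace(RowSum(), example)

print(scripted(longer))  # tensor([5., 5.]): the loop runs once per row
print(traced(longer))    # tensor([3., 3.]): the loop was unrolled 3 times
```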
Limitations:
- TorchScript supports a subset of Python/PyTorch operations.
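- Avoid `**kwargs`, which TorchScript cannot compile, and avoid shape-dependent Python logic in traced models: values such as loop counts derived from `x.size(0)` are recorded as constants from the example input.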
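A small sketch of the `**kwargs` limitation and a common workaround: replace the variadic keyword arguments with explicit, type-annotated parameters that have defaults. The functions `scale` and `scale_explicit` are hypothetical. The exact exception class raised by the TorchScript frontend is an implementation detail, so the sketch catches `Exception` broadly and prints its name.

```python
import torch


def scale(x: torch.Tensor, **kwargs) -> torch.Tensor:
    return x * kwargs.get("factor", 1.0)


try:
    torch.jit.script(scale)
    scripted_ok = True
except Exception as err:  # TorchScript rejects variadic keyword arguments
    scripted_ok = False
    print(type(err).__name__)


@torch.jit.script
def scale_explicit(x: torch.Tensor, factor: float = 1.0) -> torch.Tensor:
    # Each former keyword argument becomes a typed parameter with a default
    return x * factor


print(scale_explicit(torch.ones(2), factor=3.0))  # tensor([3., 3.])
```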