ORTModule for PyTorch Integration
ORTModule is the easiest way to accelerate your PyTorch training. It wraps an existing torch.nn.Module and runs its forward and backward passes on ONNX Runtime's optimized training backend, so the rest of your training code stays unchanged.
Quick Start
Add just 2 lines to your existing PyTorch training code:
import torch
from onnxruntime.training.ortmodule import ORTModule

model = build_model()
model = ORTModule(model)  # Wrap your model

# Rest of your training code remains the same
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for data, target in dataloader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
Complete Training Example
Here’s a complete MNIST training example showing ORTModule in action:
import torch
from torchvision import datasets, transforms
from onnxruntime.training.ortmodule import ORTModule


class NeuralNet(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return out


# Create model and wrap with ORTModule
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device)
model = ORTModule(model)

# Setup optimizer and loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# Load data
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=32, shuffle=True)

# Training loop
model.train()
for epoch in range(5):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data = data.reshape(data.shape[0], -1)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
ORTModule works seamlessly with HuggingFace transformers:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from onnxruntime.training.ortmodule import ORTModule

# Load pre-trained model
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2
)

# Wrap with ORTModule for acceleration
model = ORTModule(model)
model.to('cuda')

# Standard training setup (torch.optim.AdamW; transformers.AdamW is deprecated)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop; train_dataloader is assumed to yield tokenized batches
model.train()
for epoch in range(3):
    for batch in train_dataloader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to('cuda')
        attention_mask = batch['attention_mask'].to('cuda')
        labels = batch['labels'].to('cuda')
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
Debug Options
For development and debugging, ORTModule provides detailed logging and graph export:
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel

model = build_model()

# Enable debug options
debug_options = DebugOptions(
    save_onnx=True,              # Export ONNX graphs
    log_level=LogLevel.VERBOSE,  # Detailed logging
    onnx_prefix="model_name"     # Prefix for exported files
)
model = ORTModule(model, debug_options)
Log Levels
WARNING (default): User-facing warnings and errors
INFO: Experimental feature stats, more error details
DEVINFO: Recommended for debugging, includes all rank logs
VERBOSE: Maximum verbosity, backend and exporter logs
Environment Variables
ORTModule behavior can be customized via environment variables:
Fallback Policy
# Disable fallback to PyTorch (useful for benchmarking)
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
ONNX Opset Version
# Pin to specific ONNX opset version
export ORTMODULE_ONNX_OPSET_VERSION=14
Save ONNX Models
# Export ONNX models for inspection
export ORTMODULE_SAVE_ONNX_PATH="/path/to/output"
export ORTMODULE_LOG_LEVEL="INFO"
Memory Optimization
# Enable gradient checkpointing (level 0-2)
export ORTMODULE_MEMORY_OPT_LEVEL=1
# Enable memory-efficient gradient management
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1
Cache Exported Models
# Cache exported models to reduce startup time
export ORTMODULE_CACHE_DIR="/path/to/cache"
Computation Optimizations
# Enable/disable compute optimizer
export ORTMODULE_ENABLE_COMPUTE_OPTIMIZER=1
# Enable/disable embedding sparse optimizer
export ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER=1
# Enable/disable label sparse optimizer
export ORTMODULE_ENABLE_LABEL_SPARSE_OPTIMIZER=1
Attention Optimizations
# Enable Flash Attention (requires Triton)
export ORTMODULE_USE_FLASH_ATTENTION=1
# Enable efficient attention ATen kernel
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
# Enable scaled dot product attention fallback
export ORTMODULE_ATEN_SDPA_FALLBACK=1
Triton Integration
# Enable OpenAI Triton for kernel execution
export ORTMODULE_USE_TRITON=1
# Specify custom Triton config
export ORTMODULE_TRITON_CONFIG_FILE="triton_config.json"
# Enable kernel tuning
export ORTMODULE_ENABLE_TUNING=1
export ORTMODULE_MAX_TUNING_DURATION_MS=10000
export ORTMODULE_TUNING_RESULTS_PATH="/path/to/results"
# Enable Triton debug mode
export ORTMODULE_TRITON_DEBUG=1
Custom Autograd Functions
# Enable/disable custom autograd functions
export ORTMODULE_ENABLE_CUSTOM_AUTOGRAD=1
# Allow gradient checkpointing
export ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT=1
Debugging Options
# Print input data sparsity inspection
export ORTMODULE_PRINT_INPUT_DENSITY=1
# Print memory statistics
export ORTMODULE_PRINT_MEMORY_STATS=1
# Control deep copy before export
export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1
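The same variables can be set from Python, which keeps a run's configuration in one place. The following is a minimal sketch, not part of the ORTModule API: the helper name `apply_ortmodule_env` and its bool-to-"1"/"0" conversion are assumptions made for illustration, and only the variable names come from this page. It also assumes the variables should be set before ORTModule is imported, as the memory optimization example on this page does.

```python
import os


def apply_ortmodule_env(settings):
    """Copy ORTMODULE_* settings into os.environ and return what was set.

    Hypothetical helper, not part of onnxruntime. Booleans become "1"/"0";
    every other value is converted with str().
    """
    applied = {}
    for name, value in settings.items():
        if not name.startswith("ORTMODULE_"):
            raise ValueError(f"not an ORTModule variable: {name!r}")
        if isinstance(value, bool):
            text = "1" if value else "0"
        else:
            text = str(value)
        os.environ[name] = text
        applied[name] = text
    return applied


# Call this before importing ORTModule (assumed ordering, see above).
applied = apply_ortmodule_env({
    "ORTMODULE_FALLBACK_POLICY": "FALLBACK_DISABLE",
    "ORTMODULE_ONNX_OPSET_VERSION": 14,
    "ORTMODULE_ENABLE_COMPUTE_OPTIMIZER": True,
})
print(applied)
```

Rejecting names without the `ORTMODULE_` prefix is a design choice that catches typos early. It does not verify that a given variable is actually recognized by the installed onnxruntime version.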
FusedAdam Optimizer
Replace PyTorch’s AdamW with FusedAdam for faster parameter updates:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam

model = ORTModule(build_model())
optimizer = FusedAdam(model.parameters(), lr=1e-4)
Combined with DeepSpeed
Combine ORTModule with DeepSpeed for maximum performance:
import deepspeed
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer

# Wrap model with ORTModule first
model = ORTModule(build_model())

# Use FusedAdam
optimizer = FusedAdam(model.parameters(), lr=1e-4)

# Initialize DeepSpeed (args, lr_scheduler, and mpu come from your existing setup)
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    args=args,
    lr_scheduler=lr_scheduler,
    mpu=mpu,
    dist_init_required=False
)

# Wrap with FP16_Optimizer
optimizer = FP16_Optimizer(optimizer)
Memory Optimization
Reduce memory usage to train larger models:
import os

# Enable gradient checkpointing (level 1 or 2)
os.environ['ORTMODULE_MEMORY_OPT_LEVEL'] = '1'
# Enable memory-efficient gradient management
os.environ['ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT'] = '1'

from onnxruntime.training.ortmodule import ORTModule

model = ORTModule(build_model())
Memory Optimization Levels
Level 0 (default): No recomputation
Level 1: Recompute detected subgraphs (equivalent to PyTorch gradient checkpointing)
Level 2: Aggressive recomputation including compromised subgraphs
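Because the level reaches ORTModule as a string in `ORTMODULE_MEMORY_OPT_LEVEL`, a small guard can catch an invalid value before training starts. This is a sketch under the assumption that only the levels 0, 1, and 2 listed above are valid; `set_memory_opt_level` is a hypothetical name, not an onnxruntime function.

```python
import os

# Descriptions copied from the level list above.
MEMORY_OPT_LEVELS = {
    0: "no recomputation",
    1: "recompute detected subgraphs",
    2: "aggressive recomputation including compromised subgraphs",
}


def set_memory_opt_level(level):
    """Validate the level, export it as a string, and return its description."""
    # bool is excluded explicitly because True == 1 would otherwise pass.
    if isinstance(level, bool) or level not in MEMORY_OPT_LEVELS:
        raise ValueError(
            f"level must be one of {sorted(MEMORY_OPT_LEVELS)}, got {level!r}"
        )
    os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] = str(level)
    return MEMORY_OPT_LEVELS[level]


description = set_memory_opt_level(1)
print(description)
```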
Best Practices
Wrap Order Matters
Recommended: Wrap with ORTModule before other wrappers
# Good
model = ORTModule(model)
model = DistributedDataParallel(model)
# Also works (deepspeed.initialize returns a tuple; the engine is its first element)
model = ORTModule(model)
model, optimizer, _, _ = deepspeed.initialize( ... )
Compatibility Notes
✅ Compatible with torch.nn.parallel.DistributedDataParallel
✅ Compatible with DeepSpeed
✅ Compatible with PyTorch Lightning
❌ NOT compatible with torch.nn.DataParallel (use DDP instead)
Convergence Debugging
If you encounter convergence issues, collect activation statistics:
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.utils.hooks import (
    GlobalSubscriberManager,
    StatisticsSubscriber
)

model = ORTModule(model)
GlobalSubscriberManager.subscribe(
    model,
    [StatisticsSubscriber(
        output_dir="ort_out",
        override_output_dir=True
    )]
)
Typical speedups with ORTModule:
BERT-Large: 1.4x faster training
GPT-2: 1.5x faster training
Vision Transformers: 1.3-1.6x faster training
Memory reduction: 20-40% lower peak memory usage with memory optimization enabled
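Speedups depend on the model, hardware, and batch size, so it is worth measuring them on your own workload. The sketch below times a zero-argument training-step callable after a warm-up period, so that one-time startup cost (such as model export, which this page mentions under Cache Exported Models) is excluded. The names `time_steps`, `speedup`, and the `sync` parameter are hypothetical; for GPU training, passing `torch.cuda.synchronize` as `sync` is an assumption about your setup.

```python
import time


def time_steps(step_fn, warmup=5, iters=20, sync=None):
    """Return the mean wall-clock seconds per call of step_fn.

    step_fn: zero-argument callable that runs one training step.
    sync: optional zero-argument callable run before each clock read,
          e.g. torch.cuda.synchronize for GPU work (assumption).
    """
    for _ in range(warmup):
        step_fn()  # excluded from timing
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / iters


def speedup(baseline_seconds, candidate_seconds):
    """How many times faster the candidate is than the baseline."""
    return baseline_seconds / candidate_seconds


# Stand-in step so the sketch runs anywhere; replace it with a real training step.
calls = []
mean_seconds = time_steps(lambda: calls.append(1), warmup=2, iters=3)
print(f"{mean_seconds:.6f} s/step over {len(calls)} calls")
```

To compare, time one step function built around the plain PyTorch model and one built around the ORTModule-wrapped model with identical data, then pass both means to `speedup`.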
Next Steps
Distributed Training: Scale ORTModule across multiple GPUs
Training Overview: Learn about other training options