Documentation Index Fetch the complete documentation index at: https://mintlify.com/microsoft/onnxruntime/llms.txt
Use this file to discover all available pages before exploring further.
Distributed Training
ONNX Runtime Training seamlessly integrates with popular distributed training frameworks to scale training across multiple GPUs and nodes. This guide covers setup and best practices for distributed training with ORTModule.
Supported Frameworks
ORTModule works with:
PyTorch DDP (DistributedDataParallel) : Native PyTorch multi-GPU training
DeepSpeed : Memory-efficient training with ZeRO optimizer
DeepSpeed Pipeline Parallelism : Model parallelism for very large models
PyTorch FSDP : Fully Sharded Data Parallel
Horovod : Multi-framework distributed training
PyTorch DistributedDataParallel (DDP)
Basic Setup
Wrap your model with ORTModule before DDP:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from onnxruntime.training.ortmodule import ORTModule
def setup(rank, world_size):
    """Join the NCCL process group as ``rank`` out of ``world_size`` processes.

    ``init_method="env://"`` reads MASTER_ADDR and MASTER_PORT from the
    environment. ``torch.multiprocessing.spawn`` does not set them (a launcher
    such as torchrun does), so default to a single-node rendezvous when absent.

    Args:
        rank: Index of this process; on a single node it is also the GPU index.
        world_size: Total number of processes.
    """
    import os

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # Pin this process to its own GPU before any NCCL communication;
    # otherwise every rank defaults to cuda:0.
    torch.cuda.set_device(rank)
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank,
    )
def train(rank, world_size):
    """Train one DDP replica of the model on GPU ``rank``.

    ``build_model``, ``dataloader``, ``criterion`` and ``num_epochs`` are
    placeholders for your own code. ``dataloader`` should use a
    ``DistributedSampler`` (see "Use DistributedSampler" below) so each rank
    trains on a different shard; a plain DataLoader feeds every rank the
    same batches.
    """
    setup(rank, world_size)
    try:
        # Build model
        model = build_model().to(rank)
        # Wrap with ORTModule first, then with DDP
        model = ORTModule(model)
        model = DDP(model, device_ids=[rank])

        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
        model.train()
        for epoch in range(num_epochs):
            for batch in dataloader:
                optimizer.zero_grad()
                inputs = batch['input'].to(rank)
                labels = batch['label'].to(rank)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
    finally:
        # Tear down the process group even if training raises.
        dist.destroy_process_group()
if __name__ == "__main__":
    # Start one worker per visible GPU; spawn passes each worker its index
    # as the first argument (`rank`) of train().
    n_gpus = torch.cuda.device_count()
    torch.multiprocessing.spawn(train, args=(n_gpus,), nprocs=n_gpus, join=True)
Launch Script
# For scripts that read RANK / LOCAL_RANK / WORLD_SIZE from the environment
# (such as the Complete DDP Example below). Do not combine a launcher with
# torch.multiprocessing.spawn, which starts its own worker processes.
# torchrun replaces the deprecated `python -m torch.distributed.launch` and
# always exports these variables, so --use_env is no longer needed.

# Single node, 8 GPUs
torchrun \
  --nproc_per_node=8 \
  train.py

# Multi-node training (run on every node, each with its own NODE_RANK)
torchrun \
  --nproc_per_node=8 \
  --nnodes=4 \
  --node_rank=$NODE_RANK \
  --master_addr=$MASTER_ADDR \
  --master_port=$MASTER_PORT \
  train.py
Complete DDP Example
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from onnxruntime.training.ortmodule import ORTModule
class Trainer:
    """Minimal single-process trainer combining ORTModule with PyTorch DDP.

    Args:
        model: Unwrapped ``torch.nn.Module`` to train.
        train_dataset: Dataset whose items unpack as ``(data, target)``.
        rank: Local GPU index of this process (``LOCAL_RANK``), used as the
            CUDA device. The *global* rank, used for data sharding and
            rank-0 logging, is read from the process group after init.
        world_size: Total number of processes across all nodes.
    """

    def __init__(self, model, train_dataset, rank, world_size):
        self.rank = rank
        self.world_size = world_size

        # Setup distributed
        self.setup_distributed()
        # Equal to `rank` on a single node; differs on every other node.
        self.global_rank = dist.get_rank()

        # Wrap model: ORTModule first, then DDP
        self.model = model.to(rank)
        self.model = ORTModule(self.model)
        self.model = DDP(self.model, device_ids=[rank])

        # Shard by *global* rank. The local rank repeats on every node, so
        # using it would give same-index GPUs on different nodes identical data.
        self.sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=self.global_rank,
            shuffle=True,
        )
        self.dataloader = DataLoader(
            train_dataset,
            batch_size=32,
            sampler=self.sampler,
            num_workers=4,
            pin_memory=True,
        )

        # Setup optimizer
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1e-4)
        self.criterion = torch.nn.CrossEntropyLoss()

    def setup_distributed(self):
        """Bind this process to its GPU and join the NCCL process group.

        ``init_method="env://"`` reads the rendezvous settings (MASTER_ADDR,
        MASTER_PORT, RANK, WORLD_SIZE) from the environment set by the launcher.
        """
        # Must precede any NCCL op; otherwise every rank on a node uses cuda:0.
        torch.cuda.set_device(self.rank)
        dist.init_process_group(backend="nccl", init_method="env://")

    def train_epoch(self, epoch):
        """Run one epoch; global rank 0 prints the loss averaged over all ranks."""
        self.model.train()
        self.sampler.set_epoch(epoch)  # Shuffle differently each epoch
        is_main = self.global_rank == 0
        total_loss = 0.0
        for batch_idx, (data, target) in enumerate(self.dataloader):
            data = data.to(self.rank)
            target = target.to(self.rank)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            step_loss = loss.item()
            total_loss += step_loss
            if batch_idx % 10 == 0 and is_main:
                print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {step_loss:.4f}")

        # Sum the per-rank totals. With the sampler's default drop_last=False,
        # DistributedSampler pads so every rank sees the same number of batches.
        total_loss_tensor = torch.tensor(total_loss, device=self.rank)
        dist.all_reduce(total_loss_tensor, op=dist.ReduceOp.SUM)
        num_batches = len(self.dataloader) * self.world_size
        # Guard against an empty dataset (zero batches).
        avg_loss = total_loss_tensor.item() / num_batches if num_batches else 0.0
        if is_main:
            print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")

    def cleanup(self):
        """Destroy the process group created in ``setup_distributed``."""
        dist.destroy_process_group()
def main():
    """Entry point for one worker started by the launcher.

    The launcher (``torchrun``, or ``torch.distributed.launch --use_env``)
    exports LOCAL_RANK (GPU index on this node) and WORLD_SIZE (total number
    of processes) for every worker. ``build_model`` and ``load_dataset`` are
    placeholders for your own code.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    model = build_model()
    dataset = load_dataset()
    trainer = Trainer(model, dataset, local_rank, world_size)
    try:
        for epoch in range(10):
            trainer.train_epoch(epoch)
    finally:
        # Release the process group even if an epoch raises.
        trainer.cleanup()


if __name__ == "__main__":
    main()
DeepSpeed Integration
DeepSpeed provides memory-efficient training through ZeRO optimizer stages.
Basic DeepSpeed Setup
import deepspeed
from onnxruntime.training.ortmodule import ORTModule
from onnxruntime.training.optim import FusedAdam
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer
def create_model_and_optimizer():
    """Build the model, wrap it in ORTModule, and pair it with ORT's FusedAdam.

    Returns:
        ``(model, optimizer)``, ready to pass to ``deepspeed.initialize``.
    """
    # ORTModule has to wrap the model before DeepSpeed does.
    ort_model = ORTModule(build_model())
    # FusedAdam for better performance.
    return ort_model, FusedAdam(ort_model.parameters(), lr=1e-4)
def train():
    """Train an ORTModule model under DeepSpeed ZeRO stage 2 with FP16.

    ``dataloader``, ``criterion`` and ``num_epochs`` are placeholders for your
    own code.
    """
    model, optimizer = create_model_and_optimizer()

    # DeepSpeed configuration
    ds_config = {
        "train_batch_size": 32,
        "gradient_accumulation_steps": 1,
        "fp16": {
            "enabled": True,
            "loss_scale": 0,  # 0 selects dynamic loss scaling
            "initial_scale_power": 16,
        },
        "zero_optimization": {
            "stage": 2,
            # No "offload_optimizer": ZeRO CPU offload expects DeepSpeed's CPU
            # Adam and cannot run a GPU fused optimizer such as FusedAdam.
        },
        # ORT's FusedAdam is a client optimizer that DeepSpeed has not validated
        # for ZeRO; without this flag deepspeed.initialize rejects it.
        "zero_allow_untested_optimizer": True,
    }

    # Initialize DeepSpeed
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        config=ds_config,
    )

    # Optionally wrap DeepSpeed's optimizer with ORT's FP16_Optimizer
    optimizer = FP16_Optimizer(optimizer)

    # Training loop
    model_engine.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            inputs = batch['input'].to(model_engine.local_rank)
            labels = batch['label'].to(model_engine.local_rank)
            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)
            # Use the engine's backward/step rather than loss.backward() and
            # optimizer.step() so DeepSpeed can apply FP16 loss scaling and ZeRO.
            model_engine.backward(loss)
            model_engine.step()
DeepSpeed Configuration File
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 1e-4,
      "weight_decay": 0.01,
      "betas": [0.9, 0.999],
      "eps": 1e-8
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 1e-4,
      "warmup_num_steps": 1000
    }
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "wall_clock_breakdown": false
}
Launch with DeepSpeed
# Starts 8 local processes of train.py. The --deepspeed / --deepspeed_config
# flags are read by train.py itself, which typically registers them with
# deepspeed.add_config_arguments(parser) and passes the parsed args to
# deepspeed.initialize(args=...).
deepspeed --num_gpus=8 train.py \
--deepspeed \
--deepspeed_config ds_config.json
DeepSpeed Pipeline Parallelism
For models too large to fit on a single GPU, use pipeline parallelism to split the layers into stages that run on different GPUs:
from onnxruntime.training.ortmodule import DebugOptions, LogLevel
from onnxruntime.training.ortmodule.experimental.pipe import ORTPipelineModule
import deepspeed
def create_pipeline_model():
    """Build an ORTPipelineModule that splits a flat layer list into 4 stages.

    Returns:
        The pipeline module, to be passed to ``deepspeed.initialize``.
    """
    from torch import nn

    # Pipeline partitioning queries the distributed process group, so it must
    # exist before the module is constructed (no-op if already initialized).
    deepspeed.init_distributed()

    # Define model layers
    layers = [
        nn.Linear(1024, 2048),
        nn.ReLU(),
        nn.Linear(2048, 2048),
        nn.ReLU(),
        nn.Linear(2048, 1024),
        # ... more layers
    ]

    # Debug options (optional)
    debug_options = DebugOptions(
        save_onnx=True,
        log_level=LogLevel.INFO,
        onnx_prefix="pipeline_model",
    )

    # Create pipeline module
    pipeline_model = ORTPipelineModule(
        layers,
        num_stages=4,  # Partition across 4 GPUs
        # Pipeline training computes the loss on the last stage from
        # (outputs, labels); pick the loss that matches your task.
        loss_fn=nn.CrossEntropyLoss(),
        partition_method="parameters",
        base_seed=1234,
        debug_options=debug_options,
    )
    return pipeline_model
def train_pipeline():
    """Train the pipeline model with DeepSpeed's pipeline engine.

    ``dataloader`` (yielding ``(inputs, labels)`` tuples) and ``num_steps``
    are placeholders for your own code.
    """
    model = create_pipeline_model()

    # Initialize DeepSpeed with pipeline config
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=[p for p in model.parameters() if p.requires_grad],
        config="pipeline_config.json",
    )

    # The pipeline engine does not allow forward()/backward()/step() calls.
    # train_batch() runs forward, backward and the optimizer step for one
    # global batch, pulling micro-batches from the iterator.
    # RepeatingLoader restarts the dataloader instead of raising StopIteration.
    data_iter = iter(deepspeed.utils.RepeatingLoader(dataloader))
    for _ in range(num_steps):
        loss = model_engine.train_batch(data_iter=data_iter)
Data Loading Best Practices
Use DistributedSampler
Ensure each process gets different data:
from torch.utils.data import DataLoader, DistributedSampler

# Each rank draws a disjoint shard of `dataset`; drop_last=True drops the
# tail so the data divides evenly across ranks instead of padding.
sampler = DistributedSampler(
    dataset,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
    drop_last=True,
)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,  # requires num_workers > 0
)

# Update sampler epoch for proper shuffling
for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in dataloader:
        ...  # training code
Load Balancing for Variable Length Sequences
For NLP and speech tasks with variable length inputs:
from onnxruntime.training.utils.data import (
LoadBalancingDistributedSampler,
LoadBalancingDistributedBatchSampler
)
def complexity_fn(sample):
    """Score a sample by its length: the number of entries in ``sample['input_ids']``."""
    token_ids = sample['input_ids']
    return len(token_ids)
def batch_fn(samples, batch_size=32):
    """Group the sample indices assigned to this rank into fixed-size batches.

    LoadBalancingDistributedBatchSampler calls this with the sample indices
    chosen by the load-balancing sampler and expects a list of batches back,
    each a list of indices. Turning samples into tensors is the DataLoader's
    ``collate_fn`` job, not this function's.

    Args:
        samples: Iterable of sample indices.
        batch_size: Maximum number of indices per batch; the last batch may be
            smaller.

    Returns:
        List of batches (lists of indices), in input order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    samples = list(samples)
    return [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
# Create load-balanced sampler
sampler = LoadBalancingDistributedSampler(
    dataset,
    complexity_fn=complexity_fn,
)
batch_sampler = LoadBalancingDistributedBatchSampler(
    sampler,
    batch_fn=batch_fn,
)
# batch_sampler yields lists of indices; collate_fn turns each list of
# samples into a batch.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
)

for epoch in range(num_epochs):
    # Reshuffle for this epoch.
    batch_sampler.set_epoch(epoch)
    for batch in loader:
        ...  # training code
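This mitigates the “straggler problem”: with uneven sequence lengths, ranks that receive short batches finish early and then sit idle at the gradient all-reduce, waiting for the rank with the longest batch.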
Environment Variables for Distributed Training
Essential Variables
# PyTorch DDP rendezvous. Launchers such as torchrun set RANK, LOCAL_RANK and
# WORLD_SIZE per process; export these by hand only when you start each
# process yourself (e.g. single-process debugging).
export MASTER_ADDR="localhost"
export MASTER_PORT="29500"
export WORLD_SIZE=8
export RANK=0
export LOCAL_RANK=0

# NCCL tuning
export NCCL_DEBUG=INFO
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME=eth0   # network interface NCCL should use
ORTModule Distributed Settings
# Disable fallback: raise instead of silently falling back to PyTorch, so
# performance stays consistent across ranks
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
# Enable memory optimizations
export ORTMODULE_MEMORY_OPT_LEVEL=1
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1
# Cache exported models (useful for multi-node)
export ORTMODULE_CACHE_DIR="/shared/cache"
Checkpoint Saving and Loading
Save Checkpoints (Rank 0 only)
import torch.distributed as dist
def save_checkpoint(model, optimizer, epoch, path):
    """Write a training checkpoint from global rank 0. Call this on every rank.

    Args:
        model: The DDP-wrapped model; ``model.module`` is the wrapped model
            whose weights are saved.
        optimizer: Optimizer whose state is saved alongside the weights.
        epoch: Epoch number stored in the checkpoint.
        path: Destination file.
    """
    if dist.get_rank() == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.module.state_dict(),  # .module for DDP
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")
    # Hold every rank here until rank 0 has finished writing, so no rank can
    # read a partially written file or race ahead of the save.
    dist.barrier()
def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from ``path``.

    Args:
        model: The DDP-wrapped model; weights load into ``model.module``.
        optimizer: Optimizer to restore.
        path: Checkpoint file written by save_checkpoint.

    Returns:
        The epoch stored in the checkpoint.
    """
    # Load onto CPU. load_state_dict copies weights onto each parameter's own
    # device, and the optimizer moves its state to match its parameters.
    # Mapping to f"cuda:{dist.get_rank()}" used the *global* rank, which is
    # not a valid device index on multi-node jobs.
    checkpoint = torch.load(path, map_location="cpu")
    model.module.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']
Testing and Debugging
Test Distributed Setup
import os

import torch
import torch.distributed as dist


def test_distributed():
    """Smoke-test NCCL: each rank all-reduces its rank id and checks the sum."""
    # Bind to this process's GPU before any NCCL op. Otherwise every rank on
    # a node shares cuda:0, which NCCL does not support.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Initialize process group
    dist.init_process_group(backend="nccl")
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        print(f"Rank {rank}/{world_size} initialized")

        # Test all-reduce: summing 0..world_size-1 across ranks
        tensor = torch.tensor([rank], device="cuda")
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        expected = sum(range(world_size))
        assert tensor.item() == expected, f"All-reduce failed: {tensor.item()} != {expected}"
        print(f"Rank {rank} passed all-reduce test")
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    test_distributed()
Enable Detailed Logging
export ORTMODULE_LOG_LEVEL=DEVINFO          # verbose ORTModule logs
export NCCL_DEBUG=INFO                      # NCCL init/communication logs
export TORCH_DISTRIBUTED_DEBUG=DETAIL       # extra torch.distributed consistency checks
Best Practices
Wrap Order: Always wrap with ORTModule before DDP/DeepSpeed
Batch Size: Use the largest batch size that fits in memory
Gradient Accumulation: Simulate larger batches with gradient accumulation
Mixed Precision: Enable FP16 training for faster computation
Communication Backend: Use NCCL for GPU training and Gloo for CPU training
Pin Memory: Set pin_memory=True in the DataLoader
Persistent Workers: Set persistent_workers=True to avoid respawning DataLoader workers every epoch
NCCL Tuning: Tune NCCL settings for your network topology
Common Issues
Hanging on Initialization
# Check network connectivity
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL
# Use different port (in case the default one is already taken)
export MASTER_PORT=29501
Out of Memory
# Enable memory optimizations
export ORTMODULE_MEMORY_OPT_LEVEL=2
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1
# Reduce batch size or enable gradient accumulation
Gradient Synchronization Issues
# Find unused parameters: lets DDP tolerate parameters that receive no
# gradient in a given iteration (e.g. conditionally skipped branches).
# Per the PyTorch DDP docs this adds per-iteration overhead, so enable it
# only when needed. As elsewhere, `model` should already be ORTModule-wrapped.
model = DDP(model, device_ids = [rank], find_unused_parameters = True )
Next Steps
ORTModule Learn more about ORTModule features
Training Overview Explore other training options