On-Device Training
The ONNX Runtime Training API enables training directly on edge devices, mobile platforms, and embedded systems. It provides a lightweight, cross-platform solution for federated learning, model personalization, and privacy-preserving on-device learning.
Overview
Unlike ORTModule, which wraps PyTorch models, the On-Device Training API works with pre-compiled ONNX models. This approach offers:
Minimal dependencies: No PyTorch or heavy ML framework required
Small binary size: Optimized for resource-constrained devices
Cross-platform: Works on iOS, Android, Linux, Windows, and embedded systems
Fast startup: Pre-compiled models eliminate export overhead
Privacy-preserving: Train models locally without sending data to the cloud
Key Concepts
Training Artifacts
The Training API works with up to four artifacts, two of which are optional:
Training Model (training_model.onnx): Base model + loss + gradient graph
Evaluation Model (eval_model.onnx): Base model + loss (optional)
Optimizer Model (optimizer.onnx): Optimizer update graph (optional)
Checkpoint (checkpoint): Model parameters and optimizer state
These artifacts are generated offline using the generate_artifacts utility.
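Before shipping artifacts to a device, it can help to confirm that the required files are present. The following is a minimal sketch using only the Python standard library. The helper `check_artifacts` is hypothetical (it is not part of ONNX Runtime), and the file names are assumptions taken from this page, with the checkpoint name following what the generation step in the Quick Start reports.

```python
import tempfile
from pathlib import Path

# Assumed file names; adjust them if your artifacts are named differently.
REQUIRED = ("training_model.onnx", "checkpoint")
OPTIONAL = ("eval_model.onnx", "optimizer.onnx")


def check_artifacts(directory):
    """Return (missing_required, missing_optional) for an artifact directory."""
    root = Path(directory)
    missing_required = [n for n in REQUIRED if not (root / n).exists()]
    missing_optional = [n for n in OPTIONAL if not (root / n).exists()]
    return missing_required, missing_optional


# Demo on a temporary directory that holds only an empty training model file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "training_model.onnx").touch()
    missing_required, missing_optional = check_artifacts(tmp)

print("missing required:", missing_required)
print("missing optional:", missing_optional)
```

A non-empty first list means training cannot start; a non-empty second list only means evaluation or optimizer updates are unavailable.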
Quick Start
Step 1: Generate Training Artifacts
First, export your PyTorch model to ONNX and generate training artifacts:
import torch
import onnx
from onnxruntime.training import artifacts
# Export your PyTorch model (build_your_model is a placeholder for your own code)
model = build_your_model()
sample_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    sample_input,
    "base_model.onnx",
    export_params=True,
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=["input"],
    output_names=["output"],
    # Keep the batch dimension dynamic so training batches can differ from the sample
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
# Load the base model
base_model = onnx.load("base_model.onnx")
# Define which parameters require gradients
requires_grad = ["conv1.weight", "conv1.bias", "fc.weight", "fc.bias"]
frozen_params = ["conv2.weight", "conv2.bias"]  # Optional: freeze some layers
# Generate training artifacts
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
)
# This generates:
# - training_model.onnx
# - eval_model.onnx
# - optimizer.onnx
# - checkpoint (directory)
Step 2: Training Loop
Use the generated artifacts for training:
import numpy as np
from onnxruntime.training.api import Module, Optimizer, CheckpointState
# Load the checkpoint produced by generate_artifacts
state = CheckpointState.load_checkpoint("checkpoint")
# Create training module
model = Module(
    "training_model.onnx",
    state,
    eval_model_uri="eval_model.onnx",
    device="cuda",  # or "cpu"
)
# Create optimizer
optimizer = Optimizer("optimizer.onnx", model)
# Training loop (num_epochs, train_dataloader, and val_dataloader come from your own code)
model.train()
for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Prepare inputs (as numpy arrays)
        input_data = batch["data"].numpy()
        labels = batch["labels"].numpy()
        # In training mode, the call runs the forward and backward pass
        outputs = model(input_data, labels)
        # A single output may be returned directly rather than as a tuple
        loss = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        # Apply the optimizer update, then reset gradients for the next batch
        optimizer.step()
        model.lazy_reset_grad()
    print(f"Epoch {epoch}, Loss: {loss}")
# Evaluation
model.eval()
for batch in val_dataloader:
    input_data = batch["data"].numpy()
    labels = batch["labels"].numpy()
    outputs = model(input_data, labels)
    # Process evaluation outputs
# Save checkpoint
CheckpointState.save_checkpoint(state, "checkpoint_final.ckpt")
Complete Example
Here’s a complete example with a simple classifier:
import torch
import torch.nn as nn
import onnx
import numpy as np
from onnxruntime.training import artifacts
from onnxruntime.training.api import Module, Optimizer, CheckpointState
# Step 1: Define and export PyTorch model
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Create and export model
pt_model = SimpleClassifier()
sample_input = torch.randn(1, 784)
torch.onnx.export(
    pt_model,
    sample_input,
    "classifier.onnx",
    export_params=True,
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    input_names=["input"],
    output_names=["output"],
    # Keep the batch dimension dynamic: the sample has batch size 1, training uses 32
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)
# Step 2: Generate training artifacts
base_model = onnx.load("classifier.onnx")
artifacts.generate_artifacts(
    base_model,
    requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="./training_artifacts",
)
# Step 3: Train the model
state = CheckpointState.load_checkpoint(
    "./training_artifacts/checkpoint"
)
model = Module(
    "./training_artifacts/training_model.onnx",
    state,
    eval_model_uri="./training_artifacts/eval_model.onnx",
    device="cpu",
)
optimizer = Optimizer(
    "./training_artifacts/optimizer.onnx",
    model,
)
# Generate dummy training data
def generate_batch(batch_size=32):
    X = np.random.randn(batch_size, 784).astype(np.float32)
    y = np.random.randint(0, 10, size=batch_size).astype(np.int64)
    return X, y
# Training loop
model.train()
for epoch in range(5):
    epoch_loss = 0.0
    num_batches = 100
    for batch_idx in range(num_batches):
        X, y = generate_batch()
        # Forward and backward pass
        outputs = model(X, y)
        # The loss is the first output; a single output may be returned directly
        loss = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        # Update parameters, then reset gradients for the next batch
        optimizer.step()
        model.lazy_reset_grad()
        epoch_loss += float(loss)
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch + 1}, Average Loss: {avg_loss:.4f}")
# Save final checkpoint
CheckpointState.save_checkpoint(
    state,
    "./training_artifacts/checkpoint_final.ckpt",
)
print("Training completed!")
Advanced Features
Custom Loss Functions
Define custom loss functions using ONNXBlock:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts
class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()
    def build(self, loss_input_1, loss_input_2):
        # Weighted sum of two MSE losses: 0.4 * loss1 + 0.6 * loss2
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_2, target_name="target2")),
        )
# Use custom loss (base_model and requires_grad are defined as in the Quick Start)
custom_loss = WeightedAverageLoss()
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    loss=custom_loss,
    optimizer=artifacts.OptimType.AdamW,
)
Nominal Checkpoints
For on-device applications, use nominal checkpoints to reduce package size:
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    nominal_checkpoint=True,  # Generate lightweight checkpoint
)
Nominal checkpoints contain only parameter metadata, not actual values. They’re useful when:
Packaging models with mobile apps
Parameters will be loaded from a separate source
Reducing initial app download size
ORT Format
Convert models to ORT format for faster loading:
artifacts.generate_artifacts(
    base_model,
    requires_grad=requires_grad,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    ort_format=True,  # Generate .ort files instead of .onnx
)
Working with OrtValues
For better performance, use OrtValues instead of numpy arrays:
import numpy as np
from onnxruntime import OrtValue
# Create OrtValues from numpy arrays
input_data = np.random.randn(32, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=32).astype(np.int64)
ort_input = OrtValue.ortvalue_from_numpy(input_data)
ort_labels = OrtValue.ortvalue_from_numpy(labels)
# Pass to the model created earlier
outputs = model(ort_input, ort_labels)
Checkpoint Management
The CheckpointState provides parameter access and management:
from onnxruntime.training.api import CheckpointState
# Load checkpoint
state = CheckpointState.load_checkpoint("checkpoint.ckpt")
# Access parameters
for param_name in state.parameters:
    param = state.parameters[param_name]
    print(f"{param.name}: shape={param.data.shape}, requires_grad={param.requires_grad}")
    # Modify parameter
    if param.name == "fc1.weight":
        param.data = new_weights  # Update weights
    # Access gradients
    if param.grad is not None:
        print(f"Gradient: {param.grad}")
# Save modified checkpoint
CheckpointState.save_checkpoint(state, "checkpoint_modified.ckpt")
Supported Loss Functions
LossType.MSELoss: Mean squared error loss
LossType.CrossEntropyLoss: Cross-entropy loss for classification
LossType.BCEWithLogitsLoss: Binary cross-entropy with logits
LossType.L1Loss: L1 (absolute error) loss
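For reference, the four losses can be written out in a few lines of plain Python. This is a sketch of the textbook definitions, not of the graphs ONNX Runtime generates. Mean reduction over elements is an assumption, and the function names are illustrative only.

```python
import math


def mse_loss(pred, target):
    """Mean of squared differences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def l1_loss(pred, target):
    """Mean of absolute differences."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def bce_with_logits_loss(logits, target):
    """Binary cross-entropy on raw logits, in a numerically stable form."""
    total = 0.0
    for x, t in zip(logits, target):
        # Equivalent to -[t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))]
        total += max(x, 0.0) - x * t + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def cross_entropy_loss(logits, label):
    """Negative log-softmax of the true class for one sample."""
    m = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


print(mse_loss([1.0, 2.0], [1.0, 4.0]))
print(l1_loss([1.0, 2.0], [1.0, 4.0]))
```

Note that the binary and multi-class variants both take raw logits rather than probabilities, so the exported model should not end in a sigmoid or softmax layer when these definitions apply.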
Supported Optimizers
OptimType.AdamW: Adam with weight decay
OptimType.SGD: Stochastic gradient descent
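To make the difference between the two optimizers concrete, here is a sketch of one update step for a single scalar parameter. It follows the textbook SGD and AdamW rules. The hyperparameter defaults are assumptions chosen for illustration, and the optimizer graph that ONNX Runtime generates may differ in details.

```python
import math


def sgd_step(param, grad, lr=0.1):
    """Plain SGD: move against the gradient."""
    return param - lr * grad


def adamw_step(param, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; returns (param, m, v)."""
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    param = param * (1.0 - lr * weight_decay)
    # Exponential moving averages of the gradient and its square.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Bias correction compensates for the zero initialisation of m and v.
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p_sgd = sgd_step(1.0, 0.5)
p_adamw, m1, v1 = adamw_step(1.0, 1.0, m=0.0, v=0.0, step=1)
print(p_sgd, p_adamw)
```

AdamW keeps two extra values per parameter (`m` and `v`), which is the optimizer state stored in the checkpoint; SGD without momentum needs none, which can matter on memory-constrained devices.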
Mobile and Edge Deployment
iOS Example
import onnxruntime_training
// Load checkpoint
let state = try CheckpointState.loadCheckpoint("checkpoint.ckpt")
// Create module
let model = try Module(
    trainModelUri: "training_model.onnx",
    state: state,
    device: "cpu"
)
// Training loop
model.setTrainingMode(true)
for epoch in 0..<numEpochs {
    let outputs = try model.call(inputs: inputs)
    try optimizer.step()
}
Android Example
import ai.onnxruntime.training.*
// Load checkpoint
val state = CheckpointState.loadCheckpoint("checkpoint.ckpt")
// Create module
val model = Module(
    trainModelUri = "training_model.onnx",
    state = state,
    device = "cpu"
)
// Training loop
model.train()
for (epoch in 0 until numEpochs) {
    val outputs = model(inputs)
    optimizer.step()
}
Use Cases
Federated Learning
Train models across multiple devices without centralizing data:
Deploy initial model to all devices
Each device trains locally
Aggregate parameter updates on server
Distribute updated model
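The aggregation step above can be sketched with federated averaging (FedAvg), a common rule that weights each device's parameters by the number of samples it trained on. This page does not prescribe a particular aggregation method, so treat the choice as an assumption. The function `federated_average` and the flat-list parameter layout are hypothetical; in practice the values would be read from each device's checkpoint.

```python
def federated_average(client_updates):
    """Weighted average of client parameters.

    client_updates: list of (num_samples, {param_name: [float, ...]}) tuples.
    """
    total = sum(n for n, _ in client_updates)
    if total <= 0:
        raise ValueError("need at least one training sample across clients")
    averaged = {}
    for name in client_updates[0][1]:
        acc = [0.0] * len(client_updates[0][1][name])
        for n, params in client_updates:
            weight = n / total  # clients with more data count for more
            for i, value in enumerate(params[name]):
                acc[i] += weight * value
        averaged[name] = acc
    return averaged


# Two devices: one trained on 1 sample, the other on 3.
updates = [
    (1, {"fc.weight": [0.0, 2.0]}),
    (3, {"fc.weight": [4.0, 2.0]}),
]
global_params = federated_average(updates)
print(global_params)
```

The server would write the averaged values back into a checkpoint and distribute it for the next round.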
Model Personalization
Adapt pre-trained models to individual users:
Ship pre-trained model with app
Fine-tune on user’s device with their data
Keep personalized model local
Edge AI Applications
Continuous learning on edge devices:
Deploy model to edge device (IoT, robotics)
Collect local data
Train incrementally
Adapt to changing conditions
Performance Tips
Use OrtValues: Avoid numpy conversion overhead
Batch Processing: Process multiple samples together
ORT Format: Use .ort format for faster loading
Quantization: Consider quantized models for mobile
Memory Management: Reuse buffers when possible
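As an illustration of the batch processing tip, the sketch below groups locally collected samples into fixed-size batches before each training call. The helper `iter_batches` is hypothetical and uses plain Python lists; a real pipeline would typically stack each batch into a numpy array or OrtValue.

```python
def iter_batches(samples, batch_size):
    """Yield consecutive slices of `samples` with at most `batch_size` items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]


batches = list(iter_batches(list(range(5)), batch_size=2))
print(batches)
```

The final batch can be smaller than the rest, so a model exported with a fixed batch dimension would need that batch dropped or padded.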
Next Steps
ORTModule: For cloud-based PyTorch training
Training Overview: Explore all training options