# Getting Started with Distributed Checkpoint (DCP)
Created On: Oct 02, 2023 | Last Updated: Jul 10, 2025 | Last Verified: Nov 05, 2024

**Author**: Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin

> **Note**: View and edit this tutorial in github.

Prerequisites:

- FullyShardedDataParallel API documents
- torch.load API documents
Checkpointing AI models during distributed training can be challenging: parameters and gradients are partitioned across trainers, and the number of trainers available may change when you resume training. PyTorch Distributed Checkpoint (DCP) helps make this process easier. In this tutorial, we show how to use the DCP APIs with a simple FSDP-wrapped model.
## How DCP works
`torch.distributed.checkpoint()` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.

Additionally, through the use of modules in `torch.distributed.checkpoint.state_dict()`, DCP offers support for gracefully handling `state_dict` generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.

DCP differs from `torch.save()` and `torch.load()` in a few significant ways:

- It produces multiple files per checkpoint, with at least one file per rank.
- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
- It offers special handling of Stateful objects (formally defined in `torch.distributed.checkpoint.stateful`), automatically calling both `state_dict` and `load_state_dict` methods if they are defined.
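The Stateful contract can be illustrated without any distributed machinery. The sketch below uses simplified stand-ins for the real `torch.distributed.checkpoint` save/load drivers; only the shape of the protocol (a `state_dict`/`load_state_dict` pair that the checkpointer invokes automatically) matches DCP.

```python
# Minimal sketch of the Stateful contract. toy_save/toy_load are illustrative
# stand-ins for dcp.save/dcp.load: they automatically call state_dict() and
# load_state_dict() on any value that defines them.

class Counter:
    """A toy Stateful object: it knows how to export and restore its own state."""

    def __init__(self):
        self.count = 0

    def state_dict(self):
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        self.count = state_dict["count"]


def toy_save(state):
    # flatten: call state_dict() on anything that provides it
    return {k: v.state_dict() if hasattr(v, "state_dict") else v
            for k, v in state.items()}


def toy_load(state, saved):
    # restore in place: call load_state_dict() on anything that provides it
    for k, v in state.items():
        if hasattr(v, "load_state_dict"):
            v.load_state_dict(saved[k])


c = Counter()
c.count = 7
checkpoint = toy_save({"app": c})      # {"app": {"count": 7}}

fresh = Counter()
toy_load({"app": fresh}, checkpoint)
print(fresh.count)                     # 7
```

This is exactly the convenience the `AppState` wrapper in the code below relies on: because it implements the protocol, the application never calls `state_dict`/`load_state_dict` directly.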
> **Note**: The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.
## How to use DCP

Here we use a toy model wrapped with FSDP for demonstration purposes. The same APIs and logic can be applied to checkpointing larger models.
### Saving

Now, let's create a toy module, wrap it with FSDP, feed it some dummy input data, and save it.
```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```
Please go ahead and check the `checkpoint` directory. You should see a number of checkpoint files corresponding to the number of saving ranks. For example, if you have 8 devices, you should see 8 files.
### Loading
After saving, let's create the same FSDP-wrapped model and load the saved state dict from storage into the model. You can load with the same world size or a different world size.

Please note that you will have to call `model.state_dict()` prior to loading and pass it to DCP's `load_state_dict()` API. This is fundamentally different from `torch.load()`, which simply requires the path to the checkpoint prior to loading. We need the `state_dict` prior to loading because:

- DCP uses the pre-allocated storage from the model `state_dict` to load from the checkpoint directory. During loading, the `state_dict` passed in will be updated in place.
- DCP requires the sharding information from the model prior to loading, in order to support resharding.
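The first point, the in-place contract, can be sketched in plain Python: the caller supplies pre-allocated buffers, and the loader fills them rather than returning new objects. This is a simplified illustration only (plain lists stand in for tensors, and `load_in_place` stands in for `dcp.load`); real DCP additionally uses the buffers' sharding metadata to reshard the checkpoint.

```python
# Sketch of DCP's in-place loading contract: the caller allocates the
# destination buffers, and the loader copies checkpoint data into them
# instead of returning new objects.

def load_in_place(state_dict, checkpoint):
    for name, buf in state_dict.items():
        src = checkpoint[name]
        assert len(buf) == len(src), f"shape mismatch for {name}"
        buf[:] = src            # write into the caller's existing storage

# the caller pre-allocates, analogous to calling model.state_dict() first
state_dict = {"net1.weight": [0.0] * 4}
checkpoint = {"net1.weight": [1.0, 2.0, 3.0, 4.0]}

load_in_place(state_dict, checkpoint)
print(state_dict["net1.weight"])   # [1.0, 2.0, 3.0, 4.0]
```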
```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```
If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed `state_dict` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

> **Note**: Distributed checkpoint support for Multi-Program Multi-Data is still under development.
```python
import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])


if __name__ == "__main__":
    print("Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()
```
## Formats
One drawback not yet mentioned is that DCP saves checkpoints in a format that is inherently different from those generated using `torch.save`. This can be an issue when users wish to share models with users accustomed to the `torch.save` format, or in general want to add format flexibility to their applications. For this case, we provide the `format_utils` module in `torch.distributed.checkpoint.format_utils`.

A command-line utility is provided for the user's convenience, which follows the following format:
```shell
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
```
In the command above, `mode` is one of `torch_to_dcp` or `dcp_to_torch`.

Alternatively, methods are also provided for users who may wish to convert checkpoints directly.
```python
import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
```
## Conclusion
In conclusion, we have learned how to use DCP's `save()` and `load()` APIs, as well as how they differ from `torch.save()` and `torch.load()`. Additionally, we've learned how to use `get_state_dict()` and `set_state_dict()` to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.

For more information, please see the following:

- Saving and loading models tutorial
- Getting started with FullyShardedDataParallel tutorial
- [Custom C++ and CUDA Operators](https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html)
- [Double Backward with Custom Functions](https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html)
- [Fusing Convolution and Batch Norm using Custom Function](https://docs.pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html)
- [Registering a Dispatched Operator in C++](https://docs.pytorch.org/tutorials/advanced/dispatcher.html)
- [Extending dispatcher for a new backend in C++](https://docs.pytorch.org/tutorials/advanced/extend_dispatcher.html)
- [Facilitating New Backend Integration by PrivateUse1](https://docs.pytorch.org/tutorials/advanced/privateuseone.html)
- [Ecosystem](https://docs.pytorch.org/tutorials/ecosystem.html)
- [Hyperparameter tuning using Ray Tune](https://docs.pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html)
- [Serve PyTorch models at scale with Ray Serve](https://docs.pytorch.org/tutorials/beginner/serving_tutorial.html)
- [Multi-Objective NAS with Ax](https://docs.pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html)
- [PyTorch Profiler With TensorBoard](https://docs.pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html)
- [Real Time Inference on Raspberry Pi 4 and 5 (40 fps!)](https://docs.pytorch.org/tutorials/intermediate/realtime_rpi.html)
- [Mosaic: Memory Profiling for PyTorch](https://docs.pytorch.org/tutorials/beginner/mosaic_memory_profiling_tutorial.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Recipes](https://docs.pytorch.org/tutorials/recipes_index.html)
- [Defining a Neural Network in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html)
- [(beta) Using TORCH\_LOGS python API with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_logs.html)
- [What is a state\_dict in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)
- [Warmstarting model using parameters from a different model in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html)
- [Zeroing out gradients in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html)
- [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
- [Model Interpretability using Captum](https://docs.pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [Automatic Mixed Precision](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [Timer quick start](https://docs.pytorch.org/tutorials/recipes/recipes/timer_quick_start.html)
- [Shard Optimizer States with ZeroRedundancyOptimizer](https://docs.pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html)
- [Getting Started with CommDebugMode](https://docs.pytorch.org/tutorials/recipes/distributed_comm_debug_mode.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [PyTorch Benchmark](https://docs.pytorch.org/tutorials/recipes/recipes/benchmark.html)
- [Tips for Loading an nn.Module from a Checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html)
- [Reasoning about Shapes in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html)
- [Extension points in nn.Module for load\_state\_dict and tensor subclasses](https://docs.pytorch.org/tutorials/recipes/recipes/swap_tensors.html)
- [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [(beta) Utilizing Torch Function modes with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Explicit horizontal fusion with foreach\_map and torch.compile](https://docs.pytorch.org/tutorials/recipes/foreach_map.html)
- [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Compile Time Caching Configuration](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [Reducing AoT cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_aot.html)
- [Ease-of-use quantization for PyTorch with Intel® Neural Compressor](https://docs.pytorch.org/tutorials/recipes/intel_neural_compressor_for_pytorch.html)
- [Getting Started with DeviceMesh](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html)
- [Getting Started with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html)
- [Asynchronous Saving with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html)
- [DebugMode: Recording Dispatched Operations and Numerical Debugging](https://docs.pytorch.org/tutorials/recipes/debug_mode_tutorial.html)
- [Unstable](https://docs.pytorch.org/tutorials/unstable_index.html)
- [Introduction to Context Parallel](https://docs.pytorch.org/tutorials/unstable/context_parallel.html)
- [Flight Recorder for Debugging Stuck Jobs](https://docs.pytorch.org/tutorials/unstable/flight_recorder_tutorial.html)
- [TorchInductor C++ Wrapper Tutorial](https://docs.pytorch.org/tutorials/unstable/inductor_cpp_wrapper_tutorial.html)
- [How to use torch.compile on Windows CPU/XPU](https://docs.pytorch.org/tutorials/unstable/inductor_windows.html)
- [torch.vmap](https://docs.pytorch.org/tutorials/unstable/vmap_recipe.html)
- [Getting Started with Nested Tensors](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html)
- [MaskedTensor Overview](https://docs.pytorch.org/tutorials/unstable/maskedtensor_overview.html)
- [MaskedTensor Sparsity](https://docs.pytorch.org/tutorials/unstable/maskedtensor_sparsity.html)
- [MaskedTensor Advanced Semantics](https://docs.pytorch.org/tutorials/unstable/maskedtensor_advanced_semantics.html)
- [Efficiently writing “sparse” semantics for Adagrad with MaskedTensor](https://docs.pytorch.org/tutorials/unstable/maskedtensor_adagrad.html)
- [Autoloading Out-of-Tree Extension](https://docs.pytorch.org/tutorials/unstable/python_extension_autoload.html)
- [Using Max-Autotune Compilation on CPU for Better Performance](https://docs.pytorch.org/tutorials/unstable/max_autotune_on_CPU_tutorial.html)
[Go to pytorch.org](https://pytorch.org/)
- [X](https://x.com/PyTorch)
- [GitHub](https://github.com/pytorch/tutorials)
- [Discourse](https://dev-discuss.pytorch.org/)
- [PyPi](https://pypi.org/project/torch/)
Section Navigation
- [Defining a Neural Network in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html)
- [(beta) Using TORCH\_LOGS python API with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_logs.html)
- [What is a state\_dict in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)
- [Warmstarting model using parameters from a different model in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html)
- [Zeroing out gradients in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html)
- [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
- [Model Interpretability using Captum](https://docs.pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [Automatic Mixed Precision](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [Timer quick start](https://docs.pytorch.org/tutorials/recipes/recipes/timer_quick_start.html)
- [Shard Optimizer States with ZeroRedundancyOptimizer](https://docs.pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html)
- [Getting Started with `CommDebugMode`](https://docs.pytorch.org/tutorials/recipes/distributed_comm_debug_mode.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [PyTorch Benchmark](https://docs.pytorch.org/tutorials/recipes/recipes/benchmark.html)
- [Tips for Loading an `nn.Module` from a Checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html)
- [Reasoning about Shapes in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html)
- [Extension points in `nn.Module` for `load_state_dict` and tensor subclasses](https://docs.pytorch.org/tutorials/recipes/recipes/swap_tensors.html)
- [`torch.export` AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [(beta) Utilizing Torch Function modes with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Explicit horizontal fusion with foreach\_map and torch.compile](https://docs.pytorch.org/tutorials/recipes/foreach_map.html)
- [Using User-Defined Triton Kernels with `torch.compile`](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in `torch.compile`](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Compile Time Caching Configuration](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [Reducing AoT cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_aot.html)
- [Ease-of-use quantization for PyTorch with Intel® Neural Compressor](https://docs.pytorch.org/tutorials/recipes/intel_neural_compressor_for_pytorch.html)
- [Getting Started with DeviceMesh](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html)
- [Getting Started with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html)
- [Asynchronous Saving with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html)
- [DebugMode: Recording Dispatched Operations and Numerical Debugging](https://docs.pytorch.org/tutorials/recipes/debug_mode_tutorial.html)
- [Recipes](https://docs.pytorch.org/tutorials/recipes_index.html)
- Getting...
Rate this Page
★ ★ ★ ★ ★
recipes/distributed\_checkpoint\_recipe
[ Run in Google Colab Colab]()
[ Download Notebook Notebook]()
[ View on GitHub GitHub]()
# Getting Started with Distributed Checkpoint (DCP)[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#getting-started-with-distributed-checkpoint-dcp "Link to this heading")
Created On: Oct 02, 2023 \| Last Updated: Jul 10, 2025 \| Last Verified: Nov 05, 2024
**Author**: [Iris Zhang](https://github.com/wz337), [Rodrigo Kumpera](https://github.com/kumpera), [Chien-Chin Huang](https://github.com/fegin), [Lucas Pasqualin](https://github.com/lucasllc)
Note
View and edit this tutorial in [GitHub](https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst).
Prerequisites:
- [FullyShardedDataParallel API documents](https://pytorch.org/docs/master/fsdp.html)
- [torch.load API documents](https://pytorch.org/docs/stable/generated/torch.load.html)
Checkpointing AI models during distributed training can be challenging, because parameters and gradients are partitioned across trainers, and the number of trainers available can change when you resume training. PyTorch Distributed Checkpoint (DCP) helps make this process easier.
In this tutorial, we show how to use DCP APIs with a simple FSDP-wrapped model.
## How DCP works[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#how-dcp-works "Link to this heading")
`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.
Additionally, through the modules in `torch.distributed.checkpoint.state_dict`, DCP offers support for gracefully handling `state_dict` generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.
DCP is different from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save "(in PyTorch v2.11)") and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load "(in PyTorch v2.11)") in a few significant ways:
- It produces multiple files per checkpoint, with at least one per rank.
- It operates in place, meaning that the model should allocate its storage first, and DCP loads into that storage instead of allocating new tensors.
- DCP offers special handling of `Stateful` objects (formally defined in `torch.distributed.checkpoint.stateful`), automatically calling their `state_dict` and `load_state_dict` methods if they are defined.
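To make the `Stateful` point concrete, here is a minimal sketch of a hypothetical application-level object (the `StepTracker` name is illustrative, not part of the DCP API). Any entry in the dictionary passed to `dcp.save`/`dcp.load` that defines `state_dict` and `load_state_dict` will have those methods called automatically; in real code you would typically subclass `torch.distributed.checkpoint.stateful.Stateful`:

```python
class StepTracker:
    """Hypothetical example of application state that follows DCP's Stateful
    protocol. dcp.save calls state_dict() to capture the object's state, and
    dcp.load calls load_state_dict() to restore it in place."""

    def __init__(self):
        self.step = 0

    def state_dict(self):
        # called automatically by dcp.save
        return {"step": self.step}

    def load_state_dict(self, state_dict):
        # called automatically by dcp.load
        self.step = state_dict["step"]
```

With this in place, saving a dictionary such as `{"steps": StepTracker()}` would checkpoint the training step alongside the rest of the application state.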
Note
The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.
## How to use DCP[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#how-to-use-dcp "Link to this heading")
Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly, the APIs and logic can be applied to larger models for checkpointing.
### Saving[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#saving "Link to this heading")
Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
```
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this
    object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict
    methods on the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```
Go ahead and check the checkpoint directory. You should see one checkpoint file per rank; for example, if you have 8 devices, you should see 8 files.
![Generated checkpoint files](https://docs.pytorch.org/tutorials/_images/distributed_checkpoint_generated_files.png)
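As a quick sanity check, you can list the directory from plain Python. This is a small helper sketch; the exact file names DCP writes (such as per-rank data files and a metadata file) are an implementation detail and may vary by version:

```python
import os

def list_checkpoint_files(checkpoint_dir):
    """Return the files DCP wrote under checkpoint_dir, sorted by name.

    After running the saving example above with N ranks, expect N data
    files plus DCP metadata; exact names are version-dependent.
    """
    if not os.path.isdir(checkpoint_dir):
        return []
    return sorted(os.listdir(checkpoint_dir))

# e.g. list_checkpoint_files("checkpoint")
```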
### Loading[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#loading "Link to this heading")
After saving, let’s create the same FSDP-wrapped model and load the saved state dict from storage into the model. You can load with the same world size or a different world size.
Note that you must obtain the model’s `state_dict` prior to loading and pass it to DCP’s `load()` API. This is fundamentally different from [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load "(in PyTorch v2.11)"), which simply requires the path to the checkpoint. The reasons we need the `state_dict` prior to loading are:
- DCP uses the pre-allocated storage from the model `state_dict` to load from the checkpoint directory. During loading, the `state_dict` passed in will be updated in place.
- DCP requires the sharding information from the model prior to loading to support resharding.
```
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this
    object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict
    methods on the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```
If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed `state_dict` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.
Note
Distributed checkpoint support for Multi-Program Multi-Data is still under development.
```
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non-FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])


if __name__ == "__main__":
    print("Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()
```
## Formats[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#formats "Link to this heading")
One drawback not yet mentioned is that DCP saves checkpoints in a format that is inherently different from those generated by `torch.save`. This can be an issue when users wish to share models with users accustomed to the `torch.save` format, or simply want to add format flexibility to their applications. For this case, we provide the `format_utils` module in `torch.distributed.checkpoint.format_utils`.
A command-line utility is provided for convenience, with the following usage:
```
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
```
In the command above, `mode` is one of `torch_to_dcp` or `dcp_to_torch`.
Alternatively, methods are also provided for users who may wish to convert checkpoints directly.
```
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert the DCP checkpoint to torch.save format
# (assumes the checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# convert the torch.save checkpoint back to DCP format
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
```
## Conclusion[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#conclusion "Link to this heading")
In conclusion, we have learned how to use DCP’s `save()` and `load()` APIs, as well as how they are different from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save "(in PyTorch v2.11)") and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load "(in PyTorch v2.11)"). Additionally, we’ve learned how to use `get_state_dict()` and `set_state_dict()` to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.
For more information, please see the following:
- [Saving and loading models tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- [Getting started with FullyShardedDataParallel tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
| Readable Markdown | Created On: Oct 02, 2023 \| Last Updated: Jul 10, 2025 \| Last Verified: Nov 05, 2024
**Author**: [Iris Zhang](https://github.com/wz337), [Rodrigo Kumpera](https://github.com/kumpera), [Chien-Chin Huang](https://github.com/fegin), [Lucas Pasqualin](https://github.com/lucasllc)
Note
[](https://docs.pytorch.org/tutorials/_images/pencil-16.png) View and edit this tutorial in [github](https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst).
Prerequisites:
- [FullyShardedDataParallel API documents](https://pytorch.org/docs/master/fsdp.html)
- [torch.load API documents](https://pytorch.org/docs/stable/generated/torch.load.html)
Checkpointing AI models during distributed training can be challenging because parameters and gradients are partitioned across trainers, and the number of trainers available may change when you resume training. PyTorch Distributed Checkpoint (DCP) helps make this process easier.
In this tutorial, we show how to use DCP APIs with a simple FSDP-wrapped model.
## How DCP works[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#how-dcp-works "Link to this heading")
The `torch.distributed.checkpoint` module enables saving and loading models from multiple ranks in parallel. You can use it to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.
Additionally, through the modules in `torch.distributed.checkpoint.state_dict`, DCP offers support for gracefully handling `state_dict` generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.
DCP is different from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save "(in PyTorch v2.11)") and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load "(in PyTorch v2.11)") in a few significant ways:
- It produces multiple files per checkpoint, with at least one per rank.
- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
- DCP offers special handling of `Stateful` objects (formally defined in `torch.distributed.checkpoint.stateful`), automatically calling both `state_dict` and `load_state_dict` methods if they are defined.
Note
The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.
## How to use DCP[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#how-to-use-dcp "Link to this heading")
Here we use a toy model wrapped with FSDP for demonstration purposes; the same APIs and logic apply to larger models.
### Saving[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#saving "Link to this heading")
Now, let’s create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.
```
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def run_fsdp_checkpoint_save_example(rank, world_size):
print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
setup(rank, world_size)
# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = fully_shard(model)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
optimizer.zero_grad()
model(torch.rand(8, 16, device="cuda")).sum().backward()
optimizer.step()
state_dict = { "app": AppState(model, optimizer) }
dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running fsdp checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_save_example,
args=(world_size,),
nprocs=world_size,
join=True,
)
```
Go ahead and check the checkpoint directory. You should see a number of checkpoint files corresponding to the number of saving ranks; for example, if you have 8 devices, you should see 8 files.
![Checkpoint files generated by DCP](https://docs.pytorch.org/tutorials/_images/distributed_checkpoint_generated_files.png)
### Loading[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#loading "Link to this heading")
After saving, let’s create the same FSDP-wrapped model, and load the saved state dict from storage into the model. You can load with the same world size or a different world size.
Note that you will need to obtain the model’s `state_dict` prior to loading and pass it to DCP’s `load()` API. This is fundamentally different from [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load "(in PyTorch v2.11)"), which requires only the path to the checkpoint. The reasons the `state_dict` is needed prior to loading are:
- DCP uses the pre-allocated storage from the model `state_dict` to load from the checkpoint directory. During loading, the `state_dict` passed in will be updated in place.
- DCP requires the sharding information from the model prior to loading in order to support resharding.
```
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
CHECKPOINT_DIR = "checkpoint"
class AppState(Stateful):
"""This is a useful wrapper for checkpointing the Application State. Since this object is compliant
with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
dcp.save/load APIs.
Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
and optimizer.
"""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
# this line automatically manages FSDP FQNs, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {
"model": model_state_dict,
"optim": optimizer_state_dict
}
def load_state_dict(self, state_dict):
# sets our state dicts on the model and optimizer, now that we've loaded
set_state_dict(
self.model,
self.optimizer,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"]
)
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def setup(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def cleanup():
dist.destroy_process_group()
def run_fsdp_checkpoint_load_example(rank, world_size):
print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
setup(rank, world_size)
# create a model and move it to GPU with id rank
model = ToyModel().to(rank)
model = fully_shard(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
state_dict = { "app": AppState(model, optimizer)}
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
cleanup()
if __name__ == "__main__":
world_size = torch.cuda.device_count()
print(f"Running fsdp checkpoint example on {world_size} devices.")
mp.spawn(
run_fsdp_checkpoint_load_example,
args=(world_size,),
nprocs=world_size,
join=True,
)
```
If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed `state_dict` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.
Note
Distributed checkpoint support for Multi-Program Multi-Data is still under development.
```
import os
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
CHECKPOINT_DIR = "checkpoint"
class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.net1 = nn.Linear(16, 16)
self.relu = nn.ReLU()
self.net2 = nn.Linear(16, 8)
def forward(self, x):
return self.net2(self.relu(self.net1(x)))
def run_checkpoint_load_example():
# create the non FSDP-wrapped toy model
model = ToyModel()
state_dict = {
"model": model.state_dict(),
}
# since no process group is initialized, DCP will disable any collectives.
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
model.load_state_dict(state_dict["model"])
if __name__ == "__main__":
print("Running basic DCP checkpoint loading example.")
run_checkpoint_load_example()
```
## Formats[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#formats "Link to this heading")
One drawback not yet mentioned is that DCP saves checkpoints in a format that is inherently different from those generated by `torch.save`. This can be an issue when users wish to share models with others who expect the `torch.save` format, or when they simply want to add format flexibility to their applications. For this case, we provide the `format_utils` module in `torch.distributed.checkpoint.format_utils`.
A command-line utility is provided for users' convenience, and is invoked as follows:
```
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
```
In the command above, `mode` is one of `torch_to_dcp` or `dcp_to_torch`.
Alternatively, methods are also provided for users who may wish to convert checkpoints directly.
```
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp
CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"
# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)
# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
```
## Conclusion[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#conclusion "Link to this heading")
In conclusion, we have learned how to use DCP’s `save()` and `load()` APIs, as well as how they are different from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save "(in PyTorch v2.11)") and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load "(in PyTorch v2.11)"). Additionally, we’ve learned how to use `get_state_dict()` and `set_state_dict()` to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.
For more information, please see the following:
- [Saving and loading models tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- [Getting started with FullyShardedDataParallel tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)