🕷️ Crawler Inspector

URL Lookup

Direct Parameter Lookup

Raw Queries and Responses

1. Shard Calculation

Query:
Response:
Calculated Shard: 114 (from laksa088)
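The inspector reports a shard derived from the URL. The actual hash function, shard count, and the meaning of "laksa088" are not shown, so the following is only a hypothetical sketch of the common hash-mod pattern for deterministic shard assignment, not the crawler's real algorithm (the `NUM_SHARDS` value and `shard_for` helper are assumptions for illustration).

```python
import hashlib

NUM_SHARDS = 512  # assumed shard count; the real value is not shown in the inspector


def shard_for(key: str, num_shards: int = NUM_SHARDS) -> int:
    """Map a key (e.g. a URL or host) to a stable shard id via hash-mod."""
    # md5 gives a stable digest across processes (unlike Python's salted hash())
    digest = hashlib.md5(key.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_shards


# The same key always lands on the same shard, which is the property
# a crawler needs to route lookups for a URL to one storage node.
assert shard_for("docs.pytorch.org") == shard_for("docs.pytorch.org")
```

The point of the pattern is only determinism: any stable hash works, as long as every component of the pipeline agrees on the function and the shard count.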

2. Crawled Status Check

Query:
Response:

3. Robots.txt Check

Query:
Response:

4. Spam/Ban Check

Query:
Response:

5. Seen Status Check

ℹ️ Skipped - page is already crawled

📄
INDEXABLE
CRAWLED
1 day ago
🤖
ROBOTS ALLOWED

Page Info Filters

| Filter | Status | Condition | Details |
|---|---|---|---|
| HTTP status | PASS | download_http_code = 200 | HTTP 200 |
| Age cutoff | PASS | download_stamp > now() - 6 MONTH | 0.1 months ago |
| History drop | PASS | isNull(history_drop_reason) | No drop reason |
| Spam/ban | PASS | fh_dont_index != 1 AND ml_spam_score = 0 | ml_spam_score=0 |
| Canonical | PASS | meta_canonical IS NULL OR = '' OR = src_unparsed | Not set |
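Taken together, the filters form a single indexability predicate: a page is shown as INDEXABLE only if every row passes. A minimal sketch of that combined check, using the field names from the conditions above (`is_indexable` itself is a hypothetical helper, not the real pipeline):

```python
from datetime import datetime, timedelta


def is_indexable(page: dict, now: datetime) -> bool:
    """Apply the five filters from the table to one page record."""
    checks = [
        # HTTP status: only successful downloads qualify
        page.get("download_http_code") == 200,
        # Age cutoff: last download within roughly 6 months (~183 days)
        page.get("download_stamp") is not None
        and page["download_stamp"] > now - timedelta(days=183),
        # History drop: no drop reason recorded
        page.get("history_drop_reason") is None,
        # Spam/ban: not manually excluded and zero ML spam score
        page.get("fh_dont_index") != 1 and page.get("ml_spam_score") == 0,
        # Canonical: unset, empty, or pointing at the page itself
        page.get("meta_canonical") in (None, "")
        or page.get("meta_canonical") == page.get("src_unparsed"),
    ]
    return all(checks)


# A record matching the inspected page passes all five filters:
page = {
    "download_http_code": 200,
    "download_stamp": datetime(2026, 4, 10, 6, 14, 59),
    "history_drop_reason": None,
    "fh_dont_index": 0,
    "ml_spam_score": 0,
    "meta_canonical": None,
    "src_unparsed": "https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html",
}
assert is_indexable(page, datetime(2026, 4, 11))
```

Any single failing row (e.g. a nonzero `ml_spam_score`) flips the verdict, which matches the all-or-nothing PASS column in the table.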

Page Details

| Property | Value |
|---|---|
| URL | https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html |
| Last Crawled | 2026-04-10 06:14:59 (1 day ago) |
| First Indexed | 2025-05-27 08:20:58 (10 months ago) |
| HTTP Status Code | 200 |
| Meta Title | Getting Started with Distributed Checkpoint (DCP) — PyTorch Tutorials 2.11.0+cu130 documentation |
| Meta Description | null |
| Meta Canonical | null |
Boilerpipe Text
Created On: Oct 02, 2023 | Last Updated: Jul 10, 2025 | Last Verified: Nov 05, 2024

Author: Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang, Lucas Pasqualin

Note: View and edit this tutorial in github.

Prerequisites:

- FullyShardedDataParallel API documents
- torch.load API documents

Checkpointing AI models during distributed training can be challenging, as parameters and gradients are partitioned across trainers and the number of trainers available can change when you resume training. PyTorch Distributed Checkpointing (DCP) can help make this process easier. In this tutorial, we show how to use DCP APIs with a simple FSDP-wrapped model.

How DCP works

torch.distributed.checkpoint() enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time.

Additionally, through the use of modules in torch.distributed.checkpoint.state_dict(), DCP offers support for gracefully handling state_dict generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.

DCP is different from torch.save() and torch.load() in a few significant ways:

- It produces multiple files per checkpoint, with at least one per rank.
- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
- DCP offers special handling of Stateful objects (formally defined in torch.distributed.checkpoint.stateful), automatically calling both state_dict and load_state_dict methods if they are defined.

Note: The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.

How to use DCP

Here we use a toy model wrapped with FSDP for demonstration purposes. Similarly, the APIs and logic can be applied to larger models for checkpointing.
Saving

Now, let's create a toy module, wrap it with FSDP, feed it with some dummy input data, and save it.

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this
    object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict
    methods on the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```

Please go ahead and check the checkpoint directory. You should see a number of checkpoint files corresponding to the number of ranks. For example, if you have 8 devices, you should see 8 files.

Loading

After saving, let's create the same FSDP-wrapped model, and load the saved state dict from storage into the model. You can load with the same world size or a different world size.

Please note that you will have to call model.state_dict() prior to loading and pass it to DCP's load_state_dict() API. This is fundamentally different from torch.load(), as torch.load() simply requires the path to the checkpoint prior to loading. The reasons we need the state_dict prior to loading are:

- DCP uses the pre-allocated storage from the model state_dict to load from the checkpoint directory. During loading, the state_dict passed in will be updated in place.
- DCP requires the sharding information from the model prior to loading to support resharding.
```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this
    object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict
    methods on the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```

If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed state_dict in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

Note: Distributed checkpoint support for Multi-Program Multi-Data is still under development.

```python
import os

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non-FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])


if __name__ == "__main__":
    print("Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()
```

Formats

One drawback not yet mentioned is that DCP saves checkpoints in a format which is inherently different from those generated using torch.save. This can be an issue when users wish to share models with users accustomed to the torch.save format, or in general just want to add format flexibility to their applications. For this case, we provide the format_utils module in torch.distributed.checkpoint.format_utils.

A command line utility is provided for the user's convenience, which follows the following format:

```
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
```

In the command above, mode is one of torch_to_dcp or dcp_to_torch.

Alternatively, methods are also provided for users who may wish to convert checkpoints directly.

```python
import os

import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
```

Conclusion

In conclusion, we have learned how to use DCP's save() and load() APIs, as well as how they are different from torch.save() and torch.load(). Additionally, we've learned how to use get_state_dict() and set_state_dict() to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.

For more information, please see the following:

- Saving and loading models tutorial
- Getting started with FullyShardedDataParallel tutorial
PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html) - [(beta) Using TORCH\_LOGS python API with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_logs.html) - [What is a state\_dict in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html) - [Warmstarting model using parameters from a different model in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html) - [Zeroing out gradients in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html) - [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) - [Model Interpretability using Captum](https://docs.pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html) - [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) - [Automatic Mixed Precision](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html) - [Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html) - [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html) - [Timer quick start](https://docs.pytorch.org/tutorials/recipes/recipes/timer_quick_start.html) - [Shard Optimizer States with ZeroRedundancyOptimizer](https://docs.pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html) - [Getting Started with CommDebugMode](https://docs.pytorch.org/tutorials/recipes/distributed_comm_debug_mode.html) - [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html) - [PyTorch Benchmark](https://docs.pytorch.org/tutorials/recipes/recipes/benchmark.html) - [Tips for Loading an nn.Module from a 
Checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html) - [Reasoning about Shapes in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html) - [Extension points in nn.Module for load\_state\_dict and tensor subclasses](https://docs.pytorch.org/tutorials/recipes/recipes/swap_tensors.html) - [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html) - [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) - [(beta) Utilizing Torch Function modes with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html) - [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html) - [Explicit horizontal fusion with foreach\_map and torch.compile](https://docs.pytorch.org/tutorials/recipes/foreach_map.html) - [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html) - [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) - [Compile Time Caching Configuration](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html) - [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) - [Reducing AoT cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_aot.html) - [Ease-of-use quantization for PyTorch with Intel® Neural Compressor](https://docs.pytorch.org/tutorials/recipes/intel_neural_compressor_for_pytorch.html) - [Getting Started with 
DeviceMesh](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html) - [Getting Started with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html) - [Asynchronous Saving with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html) - [DebugMode: Recording Dispatched Operations and Numerical Debugging](https://docs.pytorch.org/tutorials/recipes/debug_mode_tutorial.html) - [Unstable](https://docs.pytorch.org/tutorials/unstable_index.html) - [Introduction to Context Parallel](https://docs.pytorch.org/tutorials/unstable/context_parallel.html) - [Flight Recorder for Debugging Stuck Jobs](https://docs.pytorch.org/tutorials/unstable/flight_recorder_tutorial.html) - [TorchInductor C++ Wrapper Tutorial](https://docs.pytorch.org/tutorials/unstable/inductor_cpp_wrapper_tutorial.html) - [How to use torch.compile on Windows CPU/XPU](https://docs.pytorch.org/tutorials/unstable/inductor_windows.html) - [torch.vmap](https://docs.pytorch.org/tutorials/unstable/vmap_recipe.html) - [Getting Started with Nested Tensors](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html) - [MaskedTensor Overview](https://docs.pytorch.org/tutorials/unstable/maskedtensor_overview.html) - [MaskedTensor Sparsity](https://docs.pytorch.org/tutorials/unstable/maskedtensor_sparsity.html) - [MaskedTensor Advanced Semantics](https://docs.pytorch.org/tutorials/unstable/maskedtensor_advanced_semantics.html) - [Efficiently writing “sparse” semantics for Adagrad with MaskedTensor](https://docs.pytorch.org/tutorials/unstable/maskedtensor_adagrad.html) - [Autoloading Out-of-Tree Extension](https://docs.pytorch.org/tutorials/unstable/python_extension_autoload.html) - [Using Max-Autotune Compilation on CPU for Better Performance](https://docs.pytorch.org/tutorials/unstable/max_autotune_on_CPU_tutorial.html) [Go to pytorch.org](https://pytorch.org/) - [X](https://x.com/PyTorch) - 
[GitHub](https://github.com/pytorch/tutorials) - [Discourse](https://dev-discuss.pytorch.org/) - [PyPi](https://pypi.org/project/torch/) Section Navigation - [Defining a Neural Network in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html) - [(beta) Using TORCH\_LOGS python API with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_logs.html) - [What is a state\_dict in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html) - [Warmstarting model using parameters from a different model in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html) - [Zeroing out gradients in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html) - [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html) - [Model Interpretability using Captum](https://docs.pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html) - [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) - [Automatic Mixed Precision](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html) - [Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html) - [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html) - [Timer quick start](https://docs.pytorch.org/tutorials/recipes/recipes/timer_quick_start.html) - [Shard Optimizer States with ZeroRedundancyOptimizer](https://docs.pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html) - [Getting Started with `CommDebugMode`](https://docs.pytorch.org/tutorials/recipes/distributed_comm_debug_mode.html) - [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html) - [PyTorch 
Benchmark](https://docs.pytorch.org/tutorials/recipes/recipes/benchmark.html) - [Tips for Loading an `nn.Module` from a Checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html) - [Reasoning about Shapes in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html) - [Extension points in `nn.Module` for `load_state_dict` and tensor subclasses](https://docs.pytorch.org/tutorials/recipes/recipes/swap_tensors.html) - [`torch.export` AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html) - [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) - [(beta) Utilizing Torch Function modes with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html) - [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html) - [Explicit horizontal fusion with foreach\_map and torch.compile](https://docs.pytorch.org/tutorials/recipes/foreach_map.html) - [Using User-Defined Triton Kernels with `torch.compile`](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html) - [Compile Time Caching in `torch.compile`](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html) - [Compile Time Caching Configuration](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html) - [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) - [Reducing AoT cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_aot.html) - [Ease-of-use quantization for PyTorch with Intel® Neural 
Compressor](https://docs.pytorch.org/tutorials/recipes/intel_neural_compressor_for_pytorch.html) - [Getting Started with DeviceMesh](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html) - [Getting Started with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html) - [Asynchronous Saving with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html) - [DebugMode: Recording Dispatched Operations and Numerical Debugging](https://docs.pytorch.org/tutorials/recipes/debug_mode_tutorial.html) - [Recipes](https://docs.pytorch.org/tutorials/recipes_index.html) - Getting... Rate this Page ★ ★ ★ ★ ★ recipes/distributed\_checkpoint\_recipe [![](https://docs.pytorch.org/tutorials/_static/img/pytorch-colab.svg) Run in Google Colab Colab]() [![](https://docs.pytorch.org/tutorials/_static/img/pytorch-download.svg) Download Notebook Notebook]() [![](https://docs.pytorch.org/tutorials/_static/img/pytorch-github.svg) View on GitHub GitHub]() # Getting Started with Distributed Checkpoint (DCP)[\#](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html#getting-started-with-distributed-checkpoint-dcp "Link to this heading") Created On: Oct 02, 2023 \| Last Updated: Jul 10, 2025 \| Last Verified: Nov 05, 2024 **Author**: [Iris Zhang](https://github.com/wz337), [Rodrigo Kumpera](https://github.com/kumpera), [Chien-Chin Huang](https://github.com/fegin), [Lucas Pasqualin](https://github.com/lucasllc) Note [![edit](https://docs.pytorch.org/tutorials/_images/pencil-16.png)](https://docs.pytorch.org/tutorials/_images/pencil-16.png) View and edit this tutorial in [github](https://github.com/pytorch/tutorials/blob/main/recipes_source/distributed_checkpoint_recipe.rst). 
Prerequisites:

- [FullyShardedDataParallel API documents](https://pytorch.org/docs/master/fsdp.html)
- [torch.load API documents](https://pytorch.org/docs/stable/generated/torch.load.html)

Checkpointing AI models during distributed training can be challenging: parameters and gradients are partitioned across trainers, and the number of trainers available may change when you resume training. PyTorch Distributed Checkpoint (DCP) helps make this process easier. In this tutorial, we show how to use the DCP APIs with a simple FSDP-wrapped model.

## How DCP works

`torch.distributed.checkpoint` enables saving and loading models from multiple ranks in parallel. You can use this module to save on any number of ranks in parallel and then re-shard across differing cluster topologies at load time. Additionally, through the modules in `torch.distributed.checkpoint.state_dict`, DCP offers support for gracefully handling `state_dict` generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.

DCP differs from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save) and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load) in a few significant ways:

- It produces multiple files per checkpoint, with at least one per rank.
- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
- It offers special handling of `Stateful` objects (formally defined in `torch.distributed.checkpoint.stateful`), automatically calling both `state_dict` and `load_state_dict` methods if they are defined.
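The `Stateful` dispatch can be illustrated with a plain-Python sketch. This is illustrative only, not DCP's actual implementation (which lives in `torch.distributed.checkpoint.stateful`); the helper names here are hypothetical:

```python
# Illustrative sketch of the Stateful dispatch DCP performs: any object in the
# checkpoint dict that defines state_dict/load_state_dict is handled automatically.
class Counter:
    def __init__(self):
        self.count = 0

    def state_dict(self):
        return {"count": self.count}

    def load_state_dict(self, sd):
        self.count = sd["count"]


def save_like_dcp(objects):
    # For every entry that defines state_dict(), call it automatically;
    # pass plain values through unchanged.
    return {k: v.state_dict() if hasattr(v, "state_dict") else v
            for k, v in objects.items()}


def load_like_dcp(objects, saved):
    # Symmetrically, hand the saved state back to load_state_dict() when defined.
    for k, v in objects.items():
        if hasattr(v, "load_state_dict"):
            v.load_state_dict(saved[k])


c = Counter()
c.count = 7
saved = save_like_dcp({"app": c})

restored = Counter()
load_like_dcp({"app": restored}, saved)
assert restored.count == 7
```

This is the same pattern the `AppState` wrapper below relies on: because it implements the `Stateful` protocol, `dcp.save`/`dcp.load` call its methods without any extra plumbing.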
**Note**: The code in this tutorial runs on an 8-GPU server, but it can be easily generalized to other environments.

## How to use DCP

Here we use a toy model wrapped with FSDP for demonstration purposes. The same APIs and logic apply when checkpointing larger models.

### Saving

Now, let's create a toy module, wrap it with FSDP, feed it some dummy input data, and save it.

```
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this
    object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict
    methods on the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```

Please go ahead and check the checkpoint directory.
You should see one checkpoint file per rank. For example, if you have 8 devices, you should see 8 files.

![Distributed Checkpoint](https://docs.pytorch.org/tutorials/_images/distributed_checkpoint_generated_files.png)

### Loading

After saving, let's create the same FSDP-wrapped model and load the saved state dict from storage into the model. You can load with the same world size or a different one.

Please note that you will have to call `model.state_dict()` prior to loading and pass it to DCP's `load()` API. This is fundamentally different from [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load), which requires only the checkpoint path at load time. The reasons we need the `state_dict` prior to loading are:

- DCP uses the pre-allocated storage from the model `state_dict` to load from the checkpoint directory. During loading, the `state_dict` passed in will be updated in place.
- DCP requires the sharding information from the model prior to loading, in order to support resharding.

```
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this
    object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict
    methods on the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```

If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed `state_dict` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

**Note**: Distributed checkpoint support for Multi-Program Multi-Data is still under development.

```
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non-FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])


if __name__ == "__main__":
    print("Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()
```

## Formats

One drawback not yet mentioned is that DCP saves checkpoints in a format inherently different from those generated by `torch.save`. This can be an issue when users wish to share models with others accustomed to the `torch.save` format, or simply want to add format flexibility to their applications. For this case, we provide the `format_utils` module in `torch.distributed.checkpoint.format_utils`.
For convenience, a command-line utility is provided, with the following format:

```
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
```

In the command above, `mode` is one of `torch_to_dcp` or `dcp_to_torch`.

Alternatively, methods are also provided for users who may wish to convert checkpoints directly.

```
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert dcp model to torch.save (assumes checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# converts the torch.save model back to DCP
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
```

## Conclusion

In conclusion, we have learned how to use DCP's `save()` and `load()` APIs, as well as how they differ from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save) and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load). Additionally, we've learned how to use `get_state_dict()` and `set_state_dict()` to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.
For more information, please see the following:

- [Saving and loading models tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- [Getting started with FullyShardedDataParallel tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
Readable Markdown
Created On: Oct 02, 2023 | Last Updated: Jul 10, 2025 | Last Verified: Nov 05, 2024

**Author**: [Iris Zhang](https://github.com/wz337), [Rodrigo Kumpera](https://github.com/kumpera), [Chien-Chin Huang](https://github.com/fegin), [Lucas Pasqualin](https://github.com/lucasllc)

Prerequisites:

- [FullyShardedDataParallel API documents](https://pytorch.org/docs/master/fsdp.html)
- [torch.load API documents](https://pytorch.org/docs/stable/generated/torch.load.html)

Checkpointing AI models during distributed training can be challenging, because parameters and gradients are partitioned across trainers and the number of trainers available can change when you resume training. PyTorch Distributed Checkpoint (DCP) helps make this process easier. In this tutorial, we show how to use the DCP APIs with a simple FSDP-wrapped model.

## How DCP works

The `torch.distributed.checkpoint` module enables saving and loading models from multiple ranks in parallel. You can use it to save on any number of ranks in parallel, and then re-shard across differing cluster topologies at load time. Additionally, through the modules in `torch.distributed.checkpoint.state_dict`, DCP offers support for gracefully handling `state_dict` generation and loading in distributed settings. This includes managing fully-qualified-name (FQN) mappings across models and optimizers, and setting default parameters for PyTorch-provided parallelisms.
DCP is different from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save) and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load) in a few significant ways:

- It produces multiple files per checkpoint, with at least one per rank.
- It operates in place, meaning that the model should allocate its data first and DCP uses that storage instead.
- It offers special handling of Stateful objects (formally defined in `torch.distributed.checkpoint.stateful`), automatically calling both `state_dict` and `load_state_dict` methods if they are defined.

Note: The code in this tutorial runs on an 8-GPU server, but it can easily be generalized to other environments.

## How to use DCP

Here we use a toy model wrapped with FSDP for demonstration purposes. The same APIs and logic apply to checkpointing larger models.

### Saving

Now, let's create a toy module, wrap it with FSDP, feed it some dummy input data, and save it.

```
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.
    Note: We take advantage of this wrapper to handle calling distributed state dict methods on
    the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQNs, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    optimizer.zero_grad()
    model(torch.rand(8, 16, device="cuda")).sum().backward()
    optimizer.step()

    state_dict = {"app": AppState(model, optimizer)}
    dcp.save(state_dict, checkpoint_id=CHECKPOINT_DIR)

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```

Go ahead and check the checkpoint directory. You should see one set of checkpoint files per rank; for example, if you have 8 devices, you should see 8 files.

[![Distributed Checkpoint](https://docs.pytorch.org/tutorials/_images/distributed_checkpoint_generated_files.png)](https://docs.pytorch.org/tutorials/_images/distributed_checkpoint_generated_files.png)

### Loading

After saving, let's create the same FSDP-wrapped model and load the saved state dict from storage into the model. You can load with the same world size or a different world size.

Please note that you will have to call `model.state_dict()` prior to loading and pass it to DCP's `load()` API. This is fundamentally different from [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load), which requires only the path to the checkpoint. The reasons we need the `state_dict` prior to loading are:

- DCP uses the pre-allocated storage from the model state_dict to load from the checkpoint directory. During loading, the state_dict passed in will be updated in place.
- DCP requires the sharding information from the model prior to loading to support resharding.

```
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp import fully_shard

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State.
    Since this object is compliant with the Stateful protocol, DCP will automatically call
    state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on
    the model and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQNs, as well as sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_load_example(rank, world_size):
    print(f"Running basic FSDP checkpoint loading example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = fully_shard(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    state_dict = {"app": AppState(model, optimizer)}
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_load_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
```

If you would like to load the saved checkpoint into a non-FSDP-wrapped model in a non-distributed setup, perhaps for inference, you can also do that with DCP. By default, DCP saves and loads a distributed `state_dict` in Single Program Multiple Data (SPMD) style. However, if no process group is initialized, DCP infers that the intent is to save or load in "non-distributed" style, meaning entirely in the current process.

Note: Distributed checkpoint support for Multi-Program Multi-Data is still under development.

```
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def run_checkpoint_load_example():
    # create the non-FSDP-wrapped toy model
    model = ToyModel()
    state_dict = {
        "model": model.state_dict(),
    }

    # since no process group is initialized, DCP will disable any collectives.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=CHECKPOINT_DIR,
    )
    model.load_state_dict(state_dict["model"])


if __name__ == "__main__":
    print("Running basic DCP checkpoint loading example.")
    run_checkpoint_load_example()
```

## Formats

One drawback not yet mentioned is that DCP saves checkpoints in a format that is inherently different from the one produced by `torch.save`. This can be an issue when users wish to share models with others who expect the `torch.save` format, or when they simply want to add format flexibility to their applications. For this case, we provide the `format_utils` module in `torch.distributed.checkpoint.format_utils`.
A command-line utility is provided for convenience; it has the following format:

```
python -m torch.distributed.checkpoint.format_utils <mode> <checkpoint location> <location to write formats to>
```

In the command above, `mode` is one of `torch_to_dcp` or `dcp_to_torch`.

Alternatively, methods are also provided for users who may wish to convert checkpoints directly.

```
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save, torch_save_to_dcp

CHECKPOINT_DIR = "checkpoint"
TORCH_SAVE_CHECKPOINT_DIR = "torch_save_checkpoint.pth"

# convert the DCP checkpoint to torch.save format (assumes the checkpoint was generated as above)
dcp_to_torch_save(CHECKPOINT_DIR, TORCH_SAVE_CHECKPOINT_DIR)

# convert the torch.save checkpoint back to DCP format
torch_save_to_dcp(TORCH_SAVE_CHECKPOINT_DIR, f"{CHECKPOINT_DIR}_new")
```

## Conclusion

In conclusion, we have learned how to use DCP's `save()` and `load()` APIs, as well as how they differ from [`torch.save()`](https://docs.pytorch.org/docs/stable/generated/torch.save.html#torch.save) and [`torch.load()`](https://docs.pytorch.org/docs/stable/generated/torch.load.html#torch.load). Additionally, we've learned how to use `get_state_dict()` and `set_state_dict()` to automatically manage parallelism-specific FQNs and defaults during state dict generation and loading.

For more information, please see the following:

- [Saving and loading models tutorial](https://pytorch.org/tutorials/beginner/saving_loading_models.html)
- [Getting started with FullyShardedDataParallel tutorial](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)