# Writing Distributed Applications with PyTorch
Created On: Oct 06, 2017 | Last Updated: Sep 05, 2025 | Last Verified: Nov 05, 2024

**Author**: Séb Arnold

> **Note**: You can view and edit this tutorial on GitHub.

**Prerequisites**: [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
In this short tutorial, we will go over the distributed package of PyTorch. We'll see how to set up the distributed setting, use the different communication strategies, and go over some of the internals of the package.
## Setup

The distributed package included in PyTorch (i.e., `torch.distributed`) enables researchers and practitioners to easily parallelize their computations across processes and clusters of machines. To do so, it leverages message-passing semantics, allowing each process to communicate data to any of the other processes. As opposed to the multiprocessing (`torch.multiprocessing`) package, processes can use different communication backends and are not restricted to being executed on the same machine.
To get started, we need the ability to run multiple processes simultaneously. If you have access to a compute cluster, you should check with your local sysadmin or use your favorite coordination tool (e.g., `pdsh`, `clustershell`, or Slurm). For the purposes of this tutorial, we will use a single machine and spawn multiple processes using the following template.
```python
"""run.py:"""
#!/usr/bin/env python
import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, size):
    """ Distributed function to be implemented later. """
    pass


def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    world_size = 2
    processes = []
    if "google.colab" in sys.modules:
        print("Running in Google Colab")
        mp.get_context("spawn")
    else:
        mp.set_start_method("spawn")
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```
The above script spawns two processes that will each set up the distributed environment, initialize the process group (`dist.init_process_group`), and finally execute the given `run` function.

Let's have a look at the `init_process` function. It ensures that every process will be able to coordinate through a master, using the same IP address and port. Note that we used the `gloo` backend, but other backends are available (cf. the Communication Backends section below). We will go over the magic happening in `dist.init_process_group` at the end of this tutorial, but it essentially allows processes to communicate with each other by sharing their locations.
## Point-to-Point Communication

### Send and Recv

A transfer of data from one process to another is called a point-to-point communication. These are achieved through the `send` and `recv` functions or their *immediate* counterparts, `isend` and `irecv`.
```python
"""Blocking point-to-point communication."""

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])
```
In the above example, both processes start with a zero tensor, then
process 0 increments the tensor and sends it to process 1 so that they
both end up with 1.0. Notice that process 1 needs to allocate memory in
order to store the data it will receive.
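This pre-allocation requirement mirrors ordinary message passing: the receiver hands the transport a buffer, and the data lands in it in place. A plain `socketpair` sketch (just an analogy, not `torch.distributed` internals) shows the same shape:

```python
import socket

# Two connected endpoints stand in for "process 0" and "process 1".
sender, receiver = socket.socketpair()
sender.sendall(b"\x01")     # "process 0" sends one byte
buf = bytearray(1)          # "process 1" allocated storage up front
receiver.recv_into(buf)     # the transport fills the existing buffer
print(buf[0])               # the received value: 1
sender.close()
receiver.close()
```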
Also notice that `send`/`recv` are **blocking**: both processes block until the communication is completed. The immediates, on the other hand, are **non-blocking**; the script continues its execution and the methods return a `Work` object upon which we can choose to `wait()`.
```python
"""Non-blocking point-to-point communication."""

def run(rank, size):
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        req = dist.isend(tensor=tensor, dst=1)
        print('Rank 0 started sending')
    else:
        # Receive tensor from process 0
        req = dist.irecv(tensor=tensor, src=0)
        print('Rank 1 started receiving')
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0])
```
When using immediates we have to be careful about how we use the sent and received tensors. Since we do not know when the data will be communicated to the other process, we should not modify the sent tensor nor access the received tensor before `req.wait()` has completed. In other words:

- writing to `tensor` after `dist.isend()` will result in undefined behaviour.
- reading from `tensor` after `dist.irecv()` will result in undefined behaviour, until `req.wait()` has been executed.

However, after `req.wait()` has been executed we are guaranteed that the communication took place, and that the value stored in `tensor[0]` is 1.0.
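The same contract can be pictured with plain Python futures. In the sketch below, a `concurrent.futures.Future` plays the role of the `Work` handle returned by `isend`/`irecv` (this is an analogy only; `fake_irecv` and its timing are made up for illustration):

```python
import time
from concurrent.futures import ThreadPoolExecutor

def fake_irecv(buffer, pool):
    """Start an asynchronous 'receive' that fills `buffer` later."""
    def transfer():
        time.sleep(0.05)          # data arrives at some unknown later time
        buffer[0] = 1.0
    return pool.submit(transfer)  # the Future stands in for the Work object

pool = ThreadPoolExecutor(max_workers=1)
buf = [0.0]                       # pre-allocated receive buffer
req = fake_irecv(buf, pool)
# Reading buf[0] here would race with the in-flight transfer.
req.result()                      # analogous to req.wait()
print(buf[0])                     # safe now: prints 1.0
pool.shutdown()
```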
Point-to-point communication is useful when we want fine-grained control over the communication between our processes. It can be used to implement fancy algorithms, such as the one used in Baidu's DeepSpeech or Facebook's large-scale experiments (cf. the Ring-Allreduce section below).
## Collective Communication

As opposed to point-to-point communication, collectives allow for communication patterns across all processes in a **group**. A group is a subset of all our processes. To create a group, we can pass a list of ranks to `dist.new_group(group)`. By default, collectives are executed on all processes, also known as the **world**. For example, in order to obtain the sum of all tensors on all processes, we can use the `dist.all_reduce(tensor, op, group)` collective.
```python
""" All-Reduce example."""
def run(rank, size):
    """ Simple collective communication. """
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])
```
Since we want the sum of all tensors in the group, we use `dist.ReduceOp.SUM` as the reduce operator. Generally speaking, any commutative mathematical operation can be used as an operator. Out of the box, PyTorch comes with many such operators, all working at the element-wise level: `dist.ReduceOp.SUM`, `dist.ReduceOp.PRODUCT`, `dist.ReduceOp.MAX`, `dist.ReduceOp.MIN`, `dist.ReduceOp.BAND`, `dist.ReduceOp.BOR`, `dist.ReduceOp.BXOR`, and `dist.ReduceOp.PREMUL_SUM`. The full list of supported operators is in the `torch.distributed` documentation.
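To make the element-wise, commutative nature of these operators concrete, here is a plain-Python model of a cross-rank reduction (an illustration of the semantics, not the `torch.distributed` implementation):

```python
from functools import reduce
import operator

# Plain-Python stand-ins for a few of the ReduceOp operators above.
REDUCE_OPS = {
    "SUM": operator.add,
    "PRODUCT": operator.mul,
    "MAX": max,
    "MIN": min,
    "BAND": operator.and_,   # bitwise ops apply to integer tensors
    "BOR": operator.or_,
    "BXOR": operator.xor,
}

def elementwise_reduce(tensors, op_name):
    """Combine one 'tensor' (list) per rank, element by element."""
    op = REDUCE_OPS[op_name]
    return [reduce(op, elems) for elems in zip(*tensors)]

rank_tensors = [[1, 2, 3], [4, 5, 6]]            # one tensor per rank
print(elementwise_reduce(rank_tensors, "SUM"))   # [5, 7, 9]
print(elementwise_reduce(rank_tensors, "MAX"))   # [4, 5, 6]
```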
In addition to `dist.all_reduce(tensor, op, group)`, there are many additional collectives currently implemented in PyTorch. Here are a few supported collectives:

- `dist.broadcast(tensor, src, group)`: Copies `tensor` from `src` to all other processes.
- `dist.reduce(tensor, dst, op, group)`: Applies `op` to every `tensor` and stores the result on `dst`.
- `dist.all_reduce(tensor, op, group)`: Same as reduce, but the result is stored on all processes.
- `dist.scatter(tensor, scatter_list, src, group)`: Copies the i-th tensor `scatter_list[i]` to the i-th process.
- `dist.gather(tensor, gather_list, dst, group)`: Copies `tensor` from all processes into `gather_list` on `dst`.
- `dist.all_gather(tensor_list, tensor, group)`: Copies `tensor` from all processes to `tensor_list`, on all processes.
- `dist.barrier(group)`: Blocks all processes in `group` until each one has entered this function.
- `dist.all_to_all(output_tensor_list, input_tensor_list, group)`: Scatters the list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.

The full list of supported collectives can be found in the latest documentation for PyTorch Distributed.
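The data movement of a few of these collectives can be sketched in a single process by modelling each rank's buffer as one list entry (purely illustrative helper functions, not the real API):

```python
def broadcast(buffers, src):
    """Every rank's buffer becomes a copy of the buffer on `src`."""
    value = buffers[src]
    for r in range(len(buffers)):
        buffers[r] = value

def scatter(buffers, scatter_list, src):
    """Rank i receives scatter_list[i] (held by `src` in the real API)."""
    for r in range(len(buffers)):
        buffers[r] = scatter_list[r]

def all_gather(buffers):
    """Every rank ends up with the full list of all ranks' buffers."""
    return [list(buffers) for _ in buffers]

world = ["w", None, None]            # three ranks' buffers
broadcast(world, src=0)
print(world)                         # ['w', 'w', 'w']
scatter(world, ["a", "b", "c"], src=0)
print(world)                         # ['a', 'b', 'c']
print(all_gather(world)[1])          # rank 1 sees ['a', 'b', 'c']
```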
## Distributed Training

> **Note**: You can find the example script of this section in this GitHub repository.

Now that we understand how the distributed module works, let us write something useful with it. Our goal will be to replicate the functionality of `DistributedDataParallel`. Of course, this will be a didactic example and in a real-world situation you should use the official, well-tested and well-optimized version linked above.

Quite simply, we want to implement a distributed version of stochastic gradient descent. Our script will let all processes compute the gradients of their model on their batch of data and then average those gradients. In order to ensure similar convergence results when changing the number of processes, we will first have to partition our dataset. (You could also use `torch.utils.data.random_split` instead of the snippet below.)
```python
""" Dataset partitioning helper """
from random import Random


class Partition(object):

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]


class DataPartitioner(object):

    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])
```
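To see the index-splitting arithmetic concretely, here is a condensed restatement of the partitioner's core logic applied to a toy dataset of 10 items (the fractions are only examples):

```python
from random import Random

def split_indexes(data_len, sizes, seed=1234):
    """Shuffle 0..data_len-1 deterministically, then cut into fractions."""
    rng = Random()
    rng.seed(seed)
    indexes = list(range(data_len))
    rng.shuffle(indexes)
    partitions = []
    for frac in sizes:
        part_len = int(frac * data_len)
        partitions.append(indexes[0:part_len])
        indexes = indexes[part_len:]
    return partitions

parts = split_indexes(10, [0.7, 0.2, 0.1])
print([len(p) for p in parts])   # [7, 2, 1]: disjoint train/val/test splits
```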
With the above snippet, we can now simply partition any dataset using
the following few lines:
```python
""" Partitioning MNIST """
def partition_dataset():
    dataset = datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]))
    size = dist.get_world_size()
    bsz = 128 // size
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(dataset, partition_sizes)
    partition = partition.use(dist.get_rank())
    train_set = torch.utils.data.DataLoader(
        partition, batch_size=bsz, shuffle=True)
    return train_set, bsz
```
Assuming we have 2 replicas, then each process will have a `train_set` of 60000 / 2 = 30000 samples. We also divide the batch size by the number of replicas in order to maintain the *overall* batch size of 128.
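The batch-size arithmetic can be checked directly: for a few hypothetical world sizes, the local batch sizes always multiply back to the overall 128.

```python
global_bsz = 128
shares = {}
for world_size in (1, 2, 4, 8):
    local_bsz = global_bsz // world_size       # what partition_dataset computes
    samples_per_rank = 60000 // world_size     # each rank's share of MNIST
    shares[world_size] = (local_bsz, samples_per_rank)
    print(world_size, local_bsz, samples_per_rank)
```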
We can now write our usual forward-backward-optimize training code, and add a function call to average the gradients of our models. (The following is largely inspired by the official PyTorch MNIST example.)
```python
""" Distributed Synchronous SGD Example """
def run(rank, size):
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()
    model = Net()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ',
              epoch, ': ', epoch_loss / num_batches)
```
It remains to implement the `average_gradients(model)` function, which simply takes in a model and averages its gradients across the whole world.
```python
""" Gradient averaging. """
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
```
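Numerically, the all-reduce-then-divide step is just an element-wise mean across ranks, which can be checked in plain Python (a model of the semantics, not the collective itself):

```python
def average_across_ranks(per_rank_grads):
    """Element-wise mean of one gradient vector per rank: a pure-Python
    model of all_reduce(SUM) followed by division by the world size."""
    size = float(len(per_rank_grads))
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [v / size for v in summed]

# Two ranks, each holding a two-element gradient.
result = average_across_ranks([[0.2, 0.4], [0.6, 0.8]])
print(result)
```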
Et voilà! We successfully implemented distributed synchronous SGD and could train any model on a large computer cluster.

> **Note**: While the last sentence is *technically* true, there are a lot more tricks required to implement a production-level version of synchronous SGD. Again, use what has been tested and optimized.
## Our Own Ring-Allreduce

As an additional challenge, imagine that we wanted to implement DeepSpeech's efficient ring allreduce. This is fairly easy to implement using point-to-point communication primitives.
```python
""" Implementation of a ring-reduce with addition. """
def allreduce(send, recv):
    rank = dist.get_rank()
    size = dist.get_world_size()
    send_buff = send.clone()
    recv_buff = send.clone()
    accum = send.clone()

    left = ((rank - 1) + size) % size
    right = (rank + 1) % size

    for i in range(size - 1):
        if i % 2 == 0:
            # Send send_buff
            send_req = dist.isend(send_buff, right)
            dist.recv(recv_buff, left)
            accum[:] += recv_buff[:]
        else:
            # Send recv_buff
            send_req = dist.isend(recv_buff, right)
            dist.recv(send_buff, left)
            accum[:] += send_buff[:]
        send_req.wait()
    recv[:] = accum[:]
```
In the above script, the `allreduce(send, recv)` function has a slightly different signature than the ones in PyTorch. It takes a `recv` tensor and will store the sum of all `send` tensors in it. As an exercise left to the reader, there is still one difference between our version and the one in DeepSpeech: their implementation divides the gradient tensor into *chunks*, so as to optimally utilize the communication bandwidth. (Hint: `torch.chunk`)
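The ring pattern can be sanity-checked in plain Python by modelling every rank at once and replaying the same `size - 1` rounds; this sketch compresses the two alternating buffers of the code above into a single shift per round:

```python
def ring_allreduce_sim(values):
    """Simulate the ring all-reduce: after size-1 rounds of passing a
    buffer to the right and accumulating what arrives from the left,
    every rank holds the total sum."""
    size = len(values)
    accum = list(values)          # each rank starts from its own value
    in_flight = list(values)      # the buffer currently circling the ring
    for _ in range(size - 1):
        # rank r receives what its left neighbour (r-1) sent last round
        received = [in_flight[(r - 1) % size] for r in range(size)]
        for r in range(size):
            accum[r] += received[r]
        in_flight = received      # forward what was just received
    return accum

print(ring_allreduce_sim([1, 2, 3, 4]))   # [10, 10, 10, 10]
```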
## Advanced Topics

We are now ready to discover some of the more advanced functionalities of `torch.distributed`. Since there is a lot to cover, this section is divided into two subsections:

1. **Communication Backends**: where we learn how to use MPI and Gloo for GPU-GPU communication.
2. **Initialization Methods**: where we understand how to best set up the initial coordination phase in `dist.init_process_group()`.
## Communication Backends

One of the most elegant aspects of `torch.distributed` is its ability to abstract and build on top of different backends. As mentioned before, there are multiple backends implemented in PyTorch. These backends can be easily selected using the Accelerator API, which provides an interface for working with different accelerator types. Some of the most popular backends are Gloo, NCCL, and MPI. They each have different specifications and tradeoffs, depending on the desired use case. A comparative table of supported functions can be found in the `torch.distributed` documentation.
### Gloo Backend

So far we have made extensive usage of the Gloo backend. It is quite handy as a development platform, as it is included in the pre-compiled PyTorch binaries and works on both Linux (since 0.2) and macOS (since 1.3). It supports all point-to-point and collective operations on CPU, and all collective operations on GPU. The implementation of the collective operations for CUDA tensors is not as optimized as the ones provided by the NCCL backend.
As you have surely noticed, our distributed SGD example does not work if you put `model` on the GPU. In order to use multiple GPUs, let us also make the following modifications:

1. Use the Accelerator API: `device_type = torch.accelerator.current_accelerator()`
2. Use `device = torch.device(f"{device_type}:{rank}")`
3. `model = Net()` → `model = Net().to(device)`
4. Use `data, target = data.to(device), target.to(device)`
With these modifications, your model will now train across two GPUs. You can monitor GPU utilization using `watch nvidia-smi` if you are running on NVIDIA hardware.
### MPI Backend

The Message Passing Interface (MPI) is a standardized tool from the field of high-performance computing. It allows point-to-point and collective communications and was the main inspiration for the API of `torch.distributed`. Several implementations of MPI exist (e.g. Open-MPI, MVAPICH2, Intel MPI), each optimized for different purposes. The advantage of using the MPI backend lies in MPI's wide availability, and high level of optimization, on large computer clusters. Some recent implementations are also able to take advantage of CUDA IPC and GPU Direct technologies in order to avoid memory copies through the CPU.

Unfortunately, PyTorch's binaries cannot include an MPI implementation and we'll have to recompile it by hand. Fortunately, this process is fairly simple given that, upon compilation, PyTorch will look by itself for an available MPI implementation. The following steps install the MPI backend by installing PyTorch from source.
1. Create and activate your Anaconda environment and install all the prerequisites following the guide, but do **not** run `python setup.py install` yet.
2. Choose and install your favorite MPI implementation. Note that enabling CUDA-aware MPI might require some additional steps. In our case, we'll stick to Open-MPI *without* GPU support: `conda install -c conda-forge openmpi`
3. Now, go to your cloned PyTorch repo and execute `python setup.py install`.
In order to test our newly installed backend, a few modifications are required.

1. Replace the content under `if __name__ == '__main__':` with `init_process(0, 0, run, backend='mpi')`.
2. Run `mpirun -n 4 python myscript.py`.
The reason for these changes is that MPI needs to create its own environment before spawning the processes. MPI will also spawn its own processes and perform the handshake described in Initialization Methods, making the `rank` and `size` arguments of `init_process_group` superfluous. This is actually quite powerful, as you can pass additional arguments to `mpirun` in order to tailor computational resources for each process (things like the number of cores per process, hand-assigning machines to specific ranks, and more). Doing so, you should obtain the same familiar output as with the other communication backends.
### NCCL Backend

The NCCL backend provides an optimized implementation of collective operations against CUDA tensors. If you only use CUDA tensors for your collective operations, consider using this backend for best-in-class performance. The NCCL backend is included in the pre-built binaries with CUDA support.
### XCCL Backend

The XCCL backend offers an optimized implementation of collective operations for XPU tensors. If your workload uses only XPU tensors for collective operations, this backend provides best-in-class performance. The XCCL backend is included in the pre-built binaries with XPU support.
## Initialization Methods

To conclude this tutorial, let's examine the initial function we invoked: `dist.init_process_group(backend, init_method)`. Specifically, we will discuss the various initialization methods responsible for the preliminary coordination step between the processes. These methods enable you to define how this coordination is accomplished. The choice of initialization method depends on your hardware setup, and one method may be more suitable than the others. In addition to the following sections, please refer to the official documentation for further information.
### Environment Variable

We have been using the environment variable initialization method throughout this tutorial. By setting the following four environment variables on all machines, all processes will be able to properly connect to the master, obtain information about the other processes, and finally handshake with them.

- `MASTER_PORT`: A free port on the machine that will host the process with rank 0.
- `MASTER_ADDR`: IP address of the machine that will host the process with rank 0.
- `WORLD_SIZE`: The total number of processes, so that the master knows how many workers to wait for.
- `RANK`: Rank of each process, so it will know whether it is the master or a worker.
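A minimal sketch of this method (the addresses and sizes below are placeholders): every process sets the same four variables, then calls `dist.init_process_group` with no explicit `init_method`, whose default `env://` reads them.

```python
import os

os.environ["MASTER_ADDR"] = "127.0.0.1"   # where rank 0 listens
os.environ["MASTER_PORT"] = "29500"       # a free port on that machine
os.environ["WORLD_SIZE"] = "4"            # total number of processes
os.environ["RANK"] = "0"                  # this particular process's rank

# dist.init_process_group("gloo") would now rendezvous via these variables.
is_master = int(os.environ["RANK"]) == 0
print(is_master)                          # True for rank 0
```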
### Shared File System

The shared-filesystem method requires all processes to have access to a shared file system, and will coordinate them through a shared file. This means that each process will open the file, write its information, and wait until everybody has done so. After that, all required information will be readily available to all processes. In order to avoid race conditions, the file system must support locking through `fcntl`.
```python
dist.init_process_group(
    init_method='file:///mnt/nfs/sharedfile',
    rank=args.rank,
    world_size=4)
```
### TCP

Initializing via TCP can be achieved by providing the IP address of the process with rank 0 and a reachable port number. Here, all workers will be able to connect to the process with rank 0 and exchange information on how to reach each other.
```python
dist.init_process_group(
    init_method='tcp://10.1.1.20:23456',
    rank=args.rank,
    world_size=4)
```
## Acknowledgements

I'd like to thank the PyTorch developers for doing such a good job on their implementation, documentation, and tests. When the code was unclear, I could always count on the docs or the tests to find an answer. In particular, I'd like to thank Soumith Chintala, Adam Paszke, and Natalia Gimelshein for providing insightful comments and answering questions on early drafts.
- [A guide on good usage of non\_blocking and pin\_memory() in PyTorch](https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html)
- [Visualizing Gradients](https://docs.pytorch.org/tutorials/intermediate/visualizing_gradients_tutorial.html)
- [Compilers](https://docs.pytorch.org/tutorials/compilers_index.html)
- [Introduction to torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
- [torch.compile End-to-End Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_compile_full_example.html)
- [Compiled Autograd: Capturing a larger backward graph for torch.compile](https://docs.pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html)
- [Inductor CPU backend debugging and profiling](https://docs.pytorch.org/tutorials/intermediate/inductor_debug_cpu.html)
- [Dynamic Compilation Control with torch.compiler.set\_stance](https://docs.pytorch.org/tutorials/recipes/torch_compiler_set_stance_tutorial.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Using Variable Length Attention in PyTorch](https://docs.pytorch.org/tutorials/intermediate/variable_length_attention_tutorial.html)
- [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [torch.export Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html)
- [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [Introduction to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/intro_onnx.html)
- [Export a PyTorch model to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html)
- [Extending the ONNX Exporter Operator Support](https://docs.pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html)
- [Export a model with control flow to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html)
- [Building a Convolution/Batch Norm fuser with torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_conv_bn_fuser.html)
- [(beta) Building a Simple CPU Performance Profiler with FX](https://docs.pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html)
- [Domains](https://docs.pytorch.org/tutorials/domains.html)
- [TorchVision Object Detection Finetuning Tutorial](https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- [Transfer Learning for Computer Vision Tutorial](https://docs.pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)
- [Adversarial Example Generation](https://docs.pytorch.org/tutorials/beginner/fgsm_tutorial.html)
- [DCGAN Tutorial](https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
- [Spatial Transformer Networks Tutorial](https://docs.pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html)
- [Reinforcement Learning (DQN) Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)
- [Reinforcement Learning (PPO) with TorchRL Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_ppo.html)
- [Train a Mario-playing RL Agent](https://docs.pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)
- [Pendulum: Writing your environment and transforms with TorchRL](https://docs.pytorch.org/tutorials/advanced/pendulum.html)
- [Introduction to TorchRec](https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html)
- [Exploring TorchRec sharding](https://docs.pytorch.org/tutorials/advanced/sharding.html)
- [Distributed](https://docs.pytorch.org/tutorials/distributed.html)
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [Distributed Data Parallel in PyTorch - Video Tutorials](https://docs.pytorch.org/tutorials/beginner/ddp_series_intro.html)
- [Getting Started with Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Writing Distributed Applications with PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html)
- [Getting Started with Fully Sharded Data Parallel (FSDP2)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Introduction to Libuv TCPStore Backend](https://docs.pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html)
- [Large Scale Transformer model training with Tensor Parallel (TP)](https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html)
- [Introduction to Distributed Pipeline Parallelism](https://docs.pytorch.org/tutorials/intermediate/pipelining_tutorial.html)
- [Customize Process Group Backends Using Cpp Extensions](https://docs.pytorch.org/tutorials/intermediate/process_group_cpp_extension_tutorial.html)
- [Getting Started with Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html)
- [Implementing a Parameter Server Using Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html)
- [Implementing Batch RPC Processing Using Asynchronous Executions](https://docs.pytorch.org/tutorials/intermediate/rpc_async_execution.html)
- [Interactive Distributed Applications with Monarch](https://docs.pytorch.org/tutorials/intermediate/monarch_distributed_tutorial.html)
- [Combining Distributed DataParallel with Distributed RPC Framework](https://docs.pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html)
- [Distributed Training with Uneven Inputs Using the Join Context Manager](https://docs.pytorch.org/tutorials/advanced/generic_join.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Deep Dive](https://docs.pytorch.org/tutorials/deep-dive.html)
- [Profiling your PyTorch Module](https://docs.pytorch.org/tutorials/beginner/profiler.html)
- [Parametrizations Tutorial](https://docs.pytorch.org/tutorials/intermediate/parametrizations.html)
- [Pruning Tutorial](https://docs.pytorch.org/tutorials/intermediate/pruning_tutorial.html)
- [Inductor CPU backend debugging and profiling](https://docs.pytorch.org/tutorials/intermediate/inductor_debug_cpu.html)
- [(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
- [Knowledge Distillation Tutorial](https://docs.pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html)
- [Channels Last Memory Format in PyTorch](https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
- [Forward-mode Automatic Differentiation (Beta)](https://docs.pytorch.org/tutorials/intermediate/forward_ad_usage.html)
- [Jacobians, Hessians, hvp, vhp, and more: composing function transforms](https://docs.pytorch.org/tutorials/intermediate/jacobians_hessians.html)
- [Model ensembling](https://docs.pytorch.org/tutorials/intermediate/ensembling.html)
- [Per-sample-gradients](https://docs.pytorch.org/tutorials/intermediate/per_sample_grads.html)
- [Using the PyTorch C++ Frontend](https://docs.pytorch.org/tutorials/advanced/cpp_frontend.html)
- [Autograd in C++ Frontend](https://docs.pytorch.org/tutorials/advanced/cpp_autograd.html)
- [Extension](https://docs.pytorch.org/tutorials/extension.html)
- [PyTorch Custom Operators](https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
- [Custom Python Operators](https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html)
- [Custom C++ and CUDA Operators](https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html)
- [Double Backward with Custom Functions](https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html)
- [Fusing Convolution and Batch Norm using Custom Function](https://docs.pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html)
- [Registering a Dispatched Operator in C++](https://docs.pytorch.org/tutorials/advanced/dispatcher.html)
- [Extending dispatcher for a new backend in C++](https://docs.pytorch.org/tutorials/advanced/extend_dispatcher.html)
- [Facilitating New Backend Integration by PrivateUse1](https://docs.pytorch.org/tutorials/advanced/privateuseone.html)
- [Ecosystem](https://docs.pytorch.org/tutorials/ecosystem.html)
- [Hyperparameter tuning using Ray Tune](https://docs.pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html)
- [Serve PyTorch models at scale with Ray Serve](https://docs.pytorch.org/tutorials/beginner/serving_tutorial.html)
- [Multi-Objective NAS with Ax](https://docs.pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html)
- [PyTorch Profiler With TensorBoard](https://docs.pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html)
- [Real Time Inference on Raspberry Pi 4 and 5 (40 fps!)](https://docs.pytorch.org/tutorials/intermediate/realtime_rpi.html)
- [Mosaic: Memory Profiling for PyTorch](https://docs.pytorch.org/tutorials/beginner/mosaic_memory_profiling_tutorial.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Recipes](https://docs.pytorch.org/tutorials/recipes_index.html)
- [Defining a Neural Network in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html)
- [(beta) Using TORCH\_LOGS python API with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_logs.html)
- [What is a state\_dict in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)
- [Warmstarting model using parameters from a different model in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html)
- [Zeroing out gradients in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html)
- [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
- [Model Interpretability using Captum](https://docs.pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [Automatic Mixed Precision](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [Timer quick start](https://docs.pytorch.org/tutorials/recipes/recipes/timer_quick_start.html)
- [Shard Optimizer States with ZeroRedundancyOptimizer](https://docs.pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html)
- [Getting Started with CommDebugMode](https://docs.pytorch.org/tutorials/recipes/distributed_comm_debug_mode.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [PyTorch Benchmark](https://docs.pytorch.org/tutorials/recipes/recipes/benchmark.html)
- [Tips for Loading an nn.Module from a Checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html)
- [Reasoning about Shapes in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html)
- [Extension points in nn.Module for load\_state\_dict and tensor subclasses](https://docs.pytorch.org/tutorials/recipes/recipes/swap_tensors.html)
- [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [(beta) Utilizing Torch Function modes with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Explicit horizontal fusion with foreach\_map and torch.compile](https://docs.pytorch.org/tutorials/recipes/foreach_map.html)
- [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Compile Time Caching Configuration](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [Reducing AoT cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_aot.html)
- [Ease-of-use quantization for PyTorch with IntelĀ® Neural Compressor](https://docs.pytorch.org/tutorials/recipes/intel_neural_compressor_for_pytorch.html)
- [Getting Started with DeviceMesh](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html)
- [Getting Started with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html)
- [Asynchronous Saving with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html)
- [DebugMode: Recording Dispatched Operations and Numerical Debugging](https://docs.pytorch.org/tutorials/recipes/debug_mode_tutorial.html)
- [Unstable](https://docs.pytorch.org/tutorials/unstable_index.html)
- [Introduction to Context Parallel](https://docs.pytorch.org/tutorials/unstable/context_parallel.html)
- [Flight Recorder for Debugging Stuck Jobs](https://docs.pytorch.org/tutorials/unstable/flight_recorder_tutorial.html)
- [TorchInductor C++ Wrapper Tutorial](https://docs.pytorch.org/tutorials/unstable/inductor_cpp_wrapper_tutorial.html)
- [How to use torch.compile on Windows CPU/XPU](https://docs.pytorch.org/tutorials/unstable/inductor_windows.html)
- [torch.vmap](https://docs.pytorch.org/tutorials/unstable/vmap_recipe.html)
- [Getting Started with Nested Tensors](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html)
- [MaskedTensor Overview](https://docs.pytorch.org/tutorials/unstable/maskedtensor_overview.html)
- [MaskedTensor Sparsity](https://docs.pytorch.org/tutorials/unstable/maskedtensor_sparsity.html)
- [MaskedTensor Advanced Semantics](https://docs.pytorch.org/tutorials/unstable/maskedtensor_advanced_semantics.html)
- [Efficiently writing āsparseā semantics for Adagrad with MaskedTensor](https://docs.pytorch.org/tutorials/unstable/maskedtensor_adagrad.html)
- [Autoloading Out-of-Tree Extension](https://docs.pytorch.org/tutorials/unstable/python_extension_autoload.html)
- [Using Max-Autotune Compilation on CPU for Better Performance](https://docs.pytorch.org/tutorials/unstable/max_autotune_on_CPU_tutorial.html)
[Go to pytorch.org](https://pytorch.org/)
- [X](https://x.com/PyTorch)
- [GitHub](https://github.com/pytorch/tutorials)
- [Discourse](https://dev-discuss.pytorch.org/)
- [PyPi](https://pypi.org/project/torch/)
Section Navigation
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [Distributed Data Parallel in PyTorch - Video Tutorials](https://docs.pytorch.org/tutorials/beginner/ddp_series_intro.html)
- [Getting Started with Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Writing Distributed Applications with PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html)
- [Getting Started with Fully Sharded Data Parallel (FSDP2)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Introduction to Libuv TCPStore Backend](https://docs.pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html)
- [Large Scale Transformer model training with Tensor Parallel (TP)](https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html)
- [Introduction to Distributed Pipeline Parallelism](https://docs.pytorch.org/tutorials/intermediate/pipelining_tutorial.html)
- [Customize Process Group Backends Using Cpp Extensions](https://docs.pytorch.org/tutorials/intermediate/process_group_cpp_extension_tutorial.html)
- [Getting Started with Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html)
- [Implementing a Parameter Server Using Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html)
- [Implementing Batch RPC Processing Using Asynchronous Executions](https://docs.pytorch.org/tutorials/intermediate/rpc_async_execution.html)
- [Interactive Distributed Applications with Monarch](https://docs.pytorch.org/tutorials/intermediate/monarch_distributed_tutorial.html)
- [Combining Distributed DataParallel with Distributed RPC Framework](https://docs.pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html)
- [Distributed Training with Uneven Inputs Using the Join Context Manager](https://docs.pytorch.org/tutorials/advanced/generic_join.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Distributed](https://docs.pytorch.org/tutorials/distributed.html)
- Writing...
Rate this Page
ā
ā
ā
ā
ā
intermediate/dist\_tuto
[ Run in Google Colab Colab]()
[ Download Notebook Notebook]()
[ View on GitHub GitHub]()
# Writing Distributed Applications with PyTorch
Created On: Oct 06, 2017 \| Last Updated: Sep 05, 2025 \| Last Verified: Nov 05, 2024
**Author**: [SƩb Arnold](https://seba1511.com/)
Prerequisites:
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
In this short tutorial, we will go over the distributed package of PyTorch. We'll see how to set up the distributed environment, use the different communication strategies, and explore some of the internals of the package.
## Setup
The distributed package included in PyTorch (i.e., `torch.distributed`) enables researchers and practitioners to easily parallelize their computations across processes and clusters of machines. To do so, it leverages message passing semantics allowing each process to communicate data to any of the other processes. As opposed to the multiprocessing (`torch.multiprocessing`) package, processes can use different communication backends and are not restricted to being executed on the same machine.
In order to get started, we need the ability to run multiple processes simultaneously. If you have access to a compute cluster, you should check with your local sysadmin or use your favorite coordination tool (e.g., [pdsh](https://linux.die.net/man/1/pdsh), [clustershell](https://cea-hpc.github.io/clustershell/), or [slurm](https://slurm.schedmd.com/)). For the purpose of this tutorial, we will use a single machine and spawn multiple processes using the following template.
```
"""run.py:"""
#!/usr/bin/env python
import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, size):
    """ Distributed function to be implemented later. """
    pass


def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    world_size = 2
    processes = []
    if "google.colab" in sys.modules:
        print("Running in Google Colab")
        mp.get_context("spawn")
    else:
        mp.set_start_method("spawn")
    for rank in range(world_size):
        p = mp.Process(target=init_process, args=(rank, world_size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
```
The above script spawns two processes which will each set up the distributed environment, initialize the process group (`dist.init_process_group`), and finally execute the given `run` function.
Letās have a look at the `init_process` function. It ensures that every process will be able to coordinate through a master, using the same IP address and port. Note that we used the `gloo` backend, but other backends are available (cf. [Section 5.1](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#communication-backends)). We will go over the magic happening in `dist.init_process_group` at the end of this tutorial; it essentially allows processes to communicate with each other by sharing their locations.
## Point-to-Point Communication

*(Figure: Send and Recv)*
A transfer of data from one process to another is called point-to-point communication. This is achieved through the `send` and `recv` functions or their *immediate* counterparts, `isend` and `irecv`.
```
"""Blocking point-to-point communication."""

def run(rank, size):
    tensor = torch.zeros(1)
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])
```
In the above example, both processes start with a zero tensor, then process 0 increments the tensor and sends it to process 1 so that they both end up with 1.0. Notice that process 1 needs to allocate memory in order to store the data it will receive.
Also notice that `send`/`recv` are **blocking**: both processes block until the communication is completed. Immediates, on the other hand, are **non-blocking**: the script continues its execution and the methods return a `Work` object upon which we can choose to `wait()`.
```
"""Non-blocking point-to-point communication."""

def run(rank, size):
    tensor = torch.zeros(1)
    req = None
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        req = dist.isend(tensor=tensor, dst=1)
        print('Rank 0 started sending')
    else:
        # Receive tensor from process 0
        req = dist.irecv(tensor=tensor, src=0)
        print('Rank 1 started receiving')
    req.wait()
    print('Rank ', rank, ' has data ', tensor[0])
```
When using immediates we have to be careful about how we use the sent and received tensors. Since we do not know when the data will be communicated to the other process, we should not modify the sent tensor nor access the received tensor before `req.wait()` has completed. In other words,
- writing to `tensor` after `dist.isend()` will result in undefined behaviour.
- reading from `tensor` after `dist.irecv()` will result in undefined behaviour, until `req.wait()` has been executed.
However, after `req.wait()` has been executed we are guaranteed that the communication took place, and that the value stored in `tensor[0]` is 1.0.
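To build intuition for this contract, here is a toy analogy in plain Python threads. This is *not* the `torch.distributed` API: `FakeRequest`, `fake_isend`, and `fake_irecv` are made-up names, but the handle's `wait()` mirrors the real guarantee that the transfer has completed and the receive buffer is safe to read.

```python
import queue
import threading

# Toy analogy (not the torch.distributed API): a "request" handle whose
# wait() blocks until the transfer has actually happened, mirroring the
# contract of isend()/irecv() followed by req.wait().
class FakeRequest:
    def __init__(self, done_event):
        self._done = done_event

    def wait(self):
        self._done.wait()

def fake_isend(channel, value):
    done = threading.Event()
    def _send():
        channel.put(value)   # the transfer happens "later"
        done.set()
    threading.Thread(target=_send).start()
    return FakeRequest(done)

def fake_irecv(channel, box):
    done = threading.Event()
    def _recv():
        box[0] = channel.get()  # box is undefined until wait() returns
        done.set()
    threading.Thread(target=_recv).start()
    return FakeRequest(done)

channel = queue.Queue()
box = [None]
send_req = fake_isend(channel, 1.0)
recv_req = fake_irecv(channel, box)
send_req.wait()
recv_req.wait()
print(box[0])  # 1.0 -- only safe to read after wait()
```

Just as with the real immediates, reading `box[0]` before both `wait()` calls return would race with the background transfer.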
Point-to-point communication is useful when we want fine-grained control over the communication between our processes. It can be used to implement fancy algorithms, such as the one used in [Baiduās DeepSpeech](https://github.com/baidu-research/baidu-allreduce) or [Facebookās large-scale experiments](https://research.fb.com/publications/imagenet1kin1h/) (cf. [Section 4.1](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce)).
## Collective Communication

*(Figures: Scatter, Gather, Reduce, All-Reduce, Broadcast, All-Gather)*
As opposed to point-to-point communication, collectives allow for communication patterns across all processes in a **group**. A group is a subset of all our processes. To create a group, we can pass a list of ranks to `dist.new_group(group)`. By default, collectives are executed on all processes, also known as the **world**. For example, in order to obtain the sum of all tensors on all processes, we can use the `dist.all_reduce(tensor, op, group)` collective.
```
""" All-Reduce example."""

def run(rank, size):
    """ Simple collective communication. """
    group = dist.new_group([0, 1])
    tensor = torch.ones(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    print('Rank ', rank, ' has data ', tensor[0])
```
Since we want the sum of all tensors in the group, we use `dist.ReduceOp.SUM` as the reduce operator. Generally speaking, any commutative mathematical operation can be used as an operator. Out-of-the-box, PyTorch comes with many such operators, all working at the element-wise level:
- `dist.ReduceOp.SUM`,
- `dist.ReduceOp.PRODUCT`,
- `dist.ReduceOp.MAX`,
- `dist.ReduceOp.MIN`,
- `dist.ReduceOp.BAND`,
- `dist.ReduceOp.BOR`,
- `dist.ReduceOp.BXOR`,
- `dist.ReduceOp.PREMUL_SUM`.
The full list of supported operators is [here](https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp).
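To make the element-wise semantics concrete, here is a plain-Python sketch (an illustration, not the real collective): three "ranks" each contribute a length-3 tensor, and the reduction is applied column by column, mirroring what `dist.ReduceOp.SUM` and `dist.ReduceOp.MAX` would compute.

```python
from functools import reduce

# Each "rank" contributes one tensor (here, a plain list). After an
# all-reduce, every rank holds the same element-wise reduction.
contributions = [[1, 5, 2],   # rank 0
                 [4, 0, 2],   # rank 1
                 [3, 7, 1]]   # rank 2

def elementwise(op, tensors):
    # Reduce each column (one element position across all ranks).
    return [reduce(op, column) for column in zip(*tensors)]

summed = elementwise(lambda a, b: a + b, contributions)  # like ReduceOp.SUM
maxed = elementwise(max, contributions)                  # like ReduceOp.MAX
print(summed)  # [8, 12, 5]
print(maxed)   # [4, 7, 2]
```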
In addition to `dist.all_reduce(tensor, op, group)`, there are many additional collectives currently implemented in PyTorch. Here are a few supported collectives.
- `dist.broadcast(tensor, src, group)`: Copies `tensor` from `src` to all other processes.
- `dist.reduce(tensor, dst, op, group)`: Applies `op` to every `tensor` and stores the result on the `dst` process.
- `dist.all_reduce(tensor, op, group)`: Same as reduce, but the result is stored on all processes.
- `dist.scatter(tensor, scatter_list, src, group)`: Copies the i-th tensor `scatter_list[i]` to the i-th process.
- `dist.gather(tensor, gather_list, dst, group)`: Gathers `tensor` from all processes into `gather_list` on `dst`.
- `dist.all_gather(tensor_list, tensor, group)`: Copies `tensor` from all processes to `tensor_list`, on all processes.
- `dist.barrier(group)`: Blocks all processes in `group` until each one has entered this function.
- `dist.all_to_all(output_tensor_list, input_tensor_list, group)`: Scatters the list of input tensors to all processes in a group and returns the gathered list of tensors in the output list.
The full list of supported collectives can be found by looking at the latest documentation for PyTorch Distributed [(link)](https://pytorch.org/docs/stable/distributed.html).
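The indexing convention for scatter and gather is easy to mix up, so here is a plain-Python sketch of the data movement (an illustration only, not the `torch.distributed` API): scatter hands slice `i` of the source list to rank `i`, and gather is its inverse.

```python
# Simulate scatter/gather among 3 "ranks" with plain Python objects.
world_size = 3
scatter_list = [[0, 0], [1, 1], [2, 2]]  # lives on the src rank

# scatter: rank i receives scatter_list[i]
received = {rank: scatter_list[rank] for rank in range(world_size)}

# gather: the dst rank reassembles one tensor per rank, in rank order
gathered = [received[rank] for rank in range(world_size)]

print(received[1])  # [1, 1]
print(gathered)     # [[0, 0], [1, 1], [2, 2]]
```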
## Distributed Training
**Note:** You can find the example script of this section in [this GitHub repository](https://github.com/seba-1511/dist_tuto.pth/).
Now that we understand how the distributed module works, let us write something useful with it. Our goal will be to replicate the functionality of [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel). Of course, this will be a didactic example and in a real-world situation you should use the official, well-tested and well-optimized version linked above.
Quite simply we want to implement a distributed version of stochastic gradient descent. Our script will let all processes compute the gradients of their model on their batch of data and then average their gradients. In order to ensure similar convergence results when changing the number of processes, we will first have to partition our dataset. (You could also use [torch.utils.data.random\_split](https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split), instead of the snippet below.)
```
""" Dataset partitioning helper """
from random import Random

class Partition(object):
    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]

class DataPartitioner(object):
    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
        self.data = data
        self.partitions = []
        rng = Random()
        rng.seed(seed)
        data_len = len(data)
        indexes = [x for x in range(0, data_len)]
        rng.shuffle(indexes)
        for frac in sizes:
            part_len = int(frac * data_len)
            self.partitions.append(indexes[0:part_len])
            indexes = indexes[part_len:]

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])
```
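The detail worth emphasizing above is the fixed `seed`: every process builds the same shuffled index list, so rank `k` can take the `k`-th partition without any communication. A quick standalone check of that property, using a hypothetical `make_partitions` helper that mirrors `DataPartitioner`'s index logic:

```python
from random import Random

def make_partitions(n, sizes, seed=1234):
    """Hypothetical helper mirroring DataPartitioner's shuffle-and-split logic."""
    rng = Random()
    rng.seed(seed)
    indexes = list(range(n))
    rng.shuffle(indexes)
    parts = []
    for frac in sizes:
        part_len = int(frac * n)
        parts.append(indexes[:part_len])
        indexes = indexes[part_len:]
    return parts

# Two "processes" using the same seed compute identical partitions...
a = make_partitions(100, [0.5, 0.5])
b = make_partitions(100, [0.5, 0.5])
assert a == b
# ...and the partitions are disjoint and jointly cover the dataset.
assert sorted(a[0] + a[1]) == list(range(100))
```

This is what makes it safe for each rank to call `use(dist.get_rank())` independently.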
With the above snippet, we can now simply partition any dataset using the following few lines:
```
""" Partitioning MNIST """
def partition_dataset():
dataset = datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
size = dist.get_world_size()
bsz = 128 // size
partition_sizes = [1.0 / size for _ in range(size)]
partition = DataPartitioner(dataset, partition_sizes)
partition = partition.use(dist.get_rank())
train_set = torch.utils.data.DataLoader(partition,
batch_size=bsz,
shuffle=True)
return train_set, bsz
```
Assuming we have 2 replicas, each process will have a `train_set` of 60000 / 2 = 30000 samples. We also divide the batch size by the number of replicas in order to maintain the *overall* batch size of 128.
We can now write our usual forward-backward-optimize training code, and add a function call to average the gradients of our models. (The following is largely inspired by the official [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py).)
```
""" Distributed Synchronous SGD Example """
from math import ceil

def run(rank, size):
    torch.manual_seed(1234)
    train_set, bsz = partition_dataset()
    model = Net()
    optimizer = optim.SGD(model.parameters(),
                          lr=0.01, momentum=0.5)

    num_batches = ceil(len(train_set.dataset) / float(bsz))
    for epoch in range(10):
        epoch_loss = 0.0
        for data, target in train_set:
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        print('Rank ', dist.get_rank(), ', epoch ',
              epoch, ': ', epoch_loss / num_batches)
```
It remains to implement the `average_gradients(model)` function, which simply takes in a model and averages its gradients across the whole world.
```
""" Gradient averaging. """
def average_gradients(model):
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
```
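Conceptually, `all_reduce` with `SUM` followed by the division is just an element-wise mean across ranks. A plain-Python sketch of the arithmetic, with the ranks simulated in a single list rather than real `torch.distributed` processes:

```python
# Each inner list stands for one rank's gradient for a single parameter.
world_grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
world_size = len(world_grads)

# all_reduce(SUM) leaves the element-wise sum on every rank...
summed = [sum(vals) for vals in zip(*world_grads)]
# ...and each rank then divides by the world size to obtain the mean.
averaged = [s / world_size for s in summed]
print(averaged)  # [3.0, 4.0]
```

Since every rank performs the same reduction, all model replicas apply identical updates and stay in sync.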
*Et voilà!* We successfully implemented distributed synchronous SGD and could train any model on a large computer cluster.
**Note:** While the last sentence is *technically* true, there are [a lot more tricks](https://seba-1511.github.io/dist_blog) required to implement a production-level implementation of synchronous SGD. Again, use what [has been tested and optimized](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
### Our Own Ring-Allreduce[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce "Link to this heading")
As an additional challenge, imagine that we wanted to implement DeepSpeech's efficient ring allreduce. This is fairly easy to implement using point-to-point communication primitives.
```
""" Implementation of a ring-reduce with addition. """
def allreduce(send, recv):
rank = dist.get_rank()
size = dist.get_world_size()
send_buff = send.clone()
recv_buff = send.clone()
accum = send.clone()
left = ((rank - 1) + size) % size
right = (rank + 1) % size
for i in range(size - 1):
if i % 2 == 0:
# Send send_buff
send_req = dist.isend(send_buff, right)
dist.recv(recv_buff, left)
accum[:] += recv_buff[:]
else:
# Send recv_buff
send_req = dist.isend(recv_buff, right)
dist.recv(send_buff, left)
accum[:] += send_buff[:]
send_req.wait()
recv[:] = accum[:]
```
In the above script, the `allreduce(send, recv)` function has a slightly different signature than the ones in PyTorch. It takes a `recv` tensor and will store the sum of all `send` tensors in it. As an exercise left to the reader, there is still one difference between our version and the one in DeepSpeech: their implementation divides the gradient tensor into *chunks*, so as to optimally utilize the communication bandwidth. (Hint: [torch.chunk](https://pytorch.org/docs/stable/torch.html#torch.chunk))
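One way to convince yourself the ring works is a single-process simulation: the buffer each rank forwards is exactly the buffer it received in the previous step, so after `size - 1` steps every rank has accumulated every other rank's contribution. A sketch, with all ranks simulated in one list rather than real processes:

```python
def ring_allreduce_sim(values):
    """Simulate the pass-around ring: rank r sends to r + 1, receives from r - 1."""
    size = len(values)
    accum = list(values)   # each rank starts from its own tensor
    buf = list(values)     # the buffer travelling around the ring
    for _ in range(size - 1):
        # After one step, rank r holds the buffer previously at rank r - 1.
        buf = [buf[(r - 1) % size] for r in range(size)]
        accum = [a + b for a, b in zip(accum, buf)]
    return accum

print(ring_allreduce_sim([1, 2, 3, 4]))  # [10, 10, 10, 10]
```

Note this simulates only the communication pattern, not the alternating send/receive buffers the real implementation needs to avoid overwriting in-flight data.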
## Advanced Topics[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#advanced-topics "Link to this heading")
We are now ready to discover some of the more advanced functionalities of `torch.distributed`. Since there is a lot to cover, this section is divided into two subsections:
1. Communication Backends: where we learn how to use MPI and Gloo for GPU-GPU communication.
2. Initialization Methods: where we understand how to best set up the initial coordination phase in `dist.init_process_group()`.
### Communication Backends[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#communication-backends "Link to this heading")
One of the most elegant aspects of `torch.distributed` is its ability to abstract and build on top of different backends. As mentioned before, there are multiple backends implemented in PyTorch. These backends can be easily selected using the [Accelerator API](https://pytorch.org/docs/stable/torch.html#accelerators), which provides an interface for working with different accelerator types. Some of the most popular backends are Gloo, NCCL, and MPI. They each have different specifications and tradeoffs, depending on the desired use case. A comparative table of supported functions can be found [here](https://pytorch.org/docs/stable/distributed.html#module-torch.distributed).
**Gloo Backend**
So far we have made extensive use of the [Gloo backend](https://github.com/facebookincubator/gloo). It is quite handy as a development platform, as it is included in the pre-compiled PyTorch binaries and works on both Linux (since 0.2) and macOS (since 1.3). It supports all point-to-point and collective operations on CPU, and all collective operations on GPU. The implementation of the collective operations for CUDA tensors is not as optimized as the ones provided by the NCCL backend.
As you have surely noticed, our distributed SGD example does not work if you put `model` on the GPU. In order to use multiple GPUs, let us also make the following modifications:
1. Use the Accelerator API: `device_type = torch.accelerator.current_accelerator()`
2. Create the device: `device = torch.device(f"{device_type}:{rank}")`
3. Replace `model = Net()` with `model = Net().to(device)`
4. Move each batch to the device: `data, target = data.to(device), target.to(device)`
With these modifications, your model will now train across two GPUs. You can monitor GPU utilization using `watch nvidia-smi` if you are running on NVIDIA hardware.
**MPI Backend**
The Message Passing Interface (MPI) is a standardized tool from the field of high-performance computing. It supports point-to-point and collective communication and was the main inspiration for the API of `torch.distributed`. Several implementations of MPI exist (e.g. [Open-MPI](https://www.open-mpi.org/), [MVAPICH2](http://mvapich.cse.ohio-state.edu/), [Intel MPI](https://software.intel.com/en-us/intel-mpi-library)), each optimized for different purposes. The advantage of using the MPI backend lies in MPI's wide availability, and high level of optimization, on large computer clusters. [Some](https://developer.nvidia.com/mvapich) [recent](https://developer.nvidia.com/ibm-spectrum-mpi) [implementations](https://www.open-mpi.org/) are also able to take advantage of CUDA IPC and GPU Direct technologies in order to avoid memory copies through the CPU.
Unfortunately, PyTorch's binaries cannot include an MPI implementation and we'll have to recompile it by hand. Fortunately, this process is fairly simple given that, upon compilation, PyTorch will look *by itself* for an available MPI implementation. The following steps install the MPI backend by building PyTorch [from source](https://github.com/pytorch/pytorch#from-source).
1. Create and activate your Anaconda environment, install all the pre-requisites following [the guide](https://github.com/pytorch/pytorch#from-source), but do **not** run `python setup.py install` yet.
2. Choose and install your favorite MPI implementation. Note that enabling CUDA-aware MPI might require some additional steps. In our case, we'll stick to Open-MPI *without* GPU support: `conda install -c conda-forge openmpi`
3. Now, go to your cloned PyTorch repo and execute `python setup.py install`.
In order to test our newly installed backend, a few modifications are required.
1. Replace the content under `if __name__ == '__main__':` with `init_process(0, 0, run, backend='mpi')`.
2. Run `mpirun -n 4 python myscript.py`.
The reason for these changes is that MPI needs to create its own environment before spawning the processes. MPI will also spawn its own processes and perform the handshake described in [Initialization Methods](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods), making the `rank` and `size` arguments of `init_process_group` superfluous. This is actually quite powerful, as you can pass additional arguments to `mpirun` in order to tailor computational resources for each process (things like the number of cores per process, hand-assigning machines to specific ranks, and [some more](https://www.open-mpi.org/faq/?category=running#mpirun-hostfile)). Doing so, you should obtain the same familiar output as with the other communication backends.
**NCCL Backend**
The [NCCL backend](https://github.com/nvidia/nccl) provides an optimized implementation of collective operations against CUDA tensors. If you only use CUDA tensors for your collective operations, consider using this backend for best-in-class performance. The NCCL backend is included in the pre-built binaries with CUDA support.
**XCCL Backend**
The XCCL backend offers an optimized implementation of collective operations for XPU tensors. If your workload uses only XPU tensors for collective operations, this backend provides best-in-class performance. The XCCL backend is included in the pre-built binaries with XPU support.
### Initialization Methods[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods "Link to this heading")
To conclude this tutorial, let's examine the initial function we invoked: `dist.init_process_group(backend, init_method)`. Specifically, we will discuss the various initialization methods responsible for the preliminary coordination step between each process. These methods enable you to define how this coordination is accomplished.
The choice of initialization method depends on your hardware setup, and one method may be more suitable than others. In addition to the following sections, please refer to the [official documentation](https://pytorch.org/docs/stable/distributed.html#initialization) for further information.
**Environment Variable**
We have been using the environment variable initialization method throughout this tutorial. By setting the following four environment variables on all machines, all processes will be able to properly connect to the master, obtain information about the other processes, and finally handshake with them.
- `MASTER_PORT`: A free port on the machine that will host the process with rank 0.
- `MASTER_ADDR`: IP address of the machine that will host the process with rank 0.
- `WORLD_SIZE`: The total number of processes, so that the master knows how many workers to wait for.
- `RANK`: Rank of each process, so that each process knows whether it is the master or a worker.
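In the single-machine template from the Setup section, these are exactly the variables that `init_process` sets before calling `dist.init_process_group`. For a manual launch, each process would export something like the following (the values are placeholders for a 2-process job on one machine):

```python
import os

# Rendezvous information read by the environment-variable init method.
# Values below are placeholders for a 2-process job on one machine.
os.environ["MASTER_ADDR"] = "127.0.0.1"  # where the rank 0 process lives
os.environ["MASTER_PORT"] = "29500"      # a free port on that machine
os.environ["WORLD_SIZE"] = "2"           # total number of processes
os.environ["RANK"] = "0"                 # this process's rank

print(os.environ["MASTER_ADDR"])  # 127.0.0.1
```

Each process must set its own distinct `RANK`, while the other three values are identical everywhere.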
**Shared File System**
The shared file system method requires all processes to have access to a shared file system, and coordinates them through a shared file. Each process opens the file, writes its information, and waits until every other process has done so. After that, all required information is readily available to all processes. In order to avoid race conditions, the file system must support locking through [fcntl](http://man7.org/linux/man-pages/man2/fcntl.2.html).
```
dist.init_process_group(
init_method='file:///mnt/nfs/sharedfile',
rank=args.rank,
world_size=4)
```
**TCP**
Initializing via TCP can be achieved by providing the IP address of the process with rank 0 and a reachable port number. Here, all workers will be able to connect to the process with rank 0 and exchange information on how to reach each other.
```
dist.init_process_group(
init_method='tcp://10.1.1.20:23456',
rank=args.rank,
world_size=4)
```
**Acknowledgements**
I'd like to thank the PyTorch developers for doing such a good job on their implementation, documentation, and tests. When the code was unclear, I could always count on the [docs](https://pytorch.org/docs/stable/distributed.html) or the [tests](https://github.com/pytorch/pytorch/tree/master/test/distributed) to find an answer. In particular, I'd like to thank Soumith Chintala, Adam Paszke, and Natalia Gimelshein for providing insightful comments and answering questions on early drafts.
| Readable Markdown | Created On: Oct 06, 2017 \| Last Updated: Sep 05, 2025 \| Last Verified: Nov 05, 2024
**Author**: [SƩb Arnold](https://seba1511.com/)
Note
[](https://docs.pytorch.org/tutorials/_images/pencil-16.png) View and edit this tutorial in [github](https://github.com/pytorch/tutorials/blob/main/intermediate_source/dist_tuto.rst).
Prerequisites:
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
In this short tutorial, we will be going over the distributed package of PyTorch. Weāll see how to set up the distributed setting, use the different communication strategies, and go over some of the internals of the package.
## Setup[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#setup "Link to this heading")
The distributed package included in PyTorch (i.e., `torch.distributed`) enables researchers and practitioners to easily parallelize their computations across processes and clusters of machines. To do so, it leverages message passing semantics allowing each process to communicate data to any of the other processes. As opposed to the multiprocessing (`torch.multiprocessing`) package, processes can use different communication backends and are not restricted to being executed on the same machine.
In order to get started we need the ability to run multiple processes simultaneously. If you have access to compute cluster you should check with your local sysadmin or use your favorite coordination tool (e.g., [pdsh](https://linux.die.net/man/1/pdsh), [clustershell](https://cea-hpc.github.io/clustershell/), or [slurm](https://slurm.schedmd.com/)). For the purpose of this tutorial, we will use a single machine and spawn multiple processes using the following template.
```
"""run.py:"""
#!/usr/bin/env python
import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def run(rank, size):
""" Distributed function to be implemented later. """
pass
def init_process(rank, size, fn, backend='gloo'):
""" Initialize the distributed environment. """
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29500'
dist.init_process_group(backend, rank=rank, world_size=size)
fn(rank, size)
if __name__ == "__main__":
world_size = 2
processes = []
if "google.colab" in sys.modules:
print("Running in Google Colab")
mp.get_context("spawn")
else:
mp.set_start_method("spawn")
for rank in range(world_size):
p = mp.Process(target=init_process, args=(rank, world_size, run))
p.start()
processes.append(p)
for p in processes:
p.join()
```
The above script spawns two processes who will each setup the distributed environment, initialize the process group (`dist.init_process_group`), and finally execute the given `run` function.
Letās have a look at the `init_process` function. It ensures that every process will be able to coordinate through a master, using the same ip address and port. Note that we used the `gloo` backend but other backends are available. (c.f. [Section 5.1](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#communication-backends)) We will go over the magic happening in `dist.init_process_group` at the end of this tutorial, but it essentially allows processes to communicate with each other by sharing their locations.
## Point-to-Point Communication[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#point-to-point-communication "Link to this heading")
[](https://docs.pytorch.org/tutorials/_images/send_recv.png)
Send and Recv[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#id1 "Link to this image")
A transfer of data from one process to another is called a point-to-point communication. These are achieved through the `send` and `recv` functions or their *immediate* counter-parts, `isend` and `irecv`.
```
"""Blocking point-to-point communication."""
def run(rank, size):
tensor = torch.zeros(1)
if rank == 0:
tensor += 1
# Send the tensor to process 1
dist.send(tensor=tensor, dst=1)
else:
# Receive tensor from process 0
dist.recv(tensor=tensor, src=0)
print('Rank ', rank, ' has data ', tensor[0])
```
In the above example, both processes start with a zero tensor, then process 0 increments the tensor and sends it to process 1 so that they both end up with 1.0. Notice that process 1 needs to allocate memory in order to store the data it will receive.
Also notice that `send/recv` are **blocking**: both processes block until the communication is completed. On the other hand immediates are **non-blocking**; the script continues its execution and the methods return a `Work` object upon which we can choose to `wait()`.
```
"""Non-blocking point-to-point communication."""
def run(rank, size):
tensor = torch.zeros(1)
req = None
if rank == 0:
tensor += 1
# Send the tensor to process 1
req = dist.isend(tensor=tensor, dst=1)
print('Rank 0 started sending')
else:
# Receive tensor from process 0
req = dist.irecv(tensor=tensor, src=0)
print('Rank 1 started receiving')
req.wait()
print('Rank ', rank, ' has data ', tensor[0])
```
When using immediates we have to be careful about how we use the sent and received tensors. Since we do not know when the data will be communicated to the other process, we should not modify the sent tensor nor access the received tensor before `req.wait()` has completed. In other words,
- writing to `tensor` after `dist.isend()` will result in undefined behaviour.
- reading from `tensor` after `dist.irecv()` will result in undefined behaviour, until `req.wait()` has been executed.
However, after `req.wait()` has been executed we are guaranteed that the communication took place, and that the value stored in `tensor[0]` is 1.0.
Point-to-point communication is useful when we want more fine-grained control over the communication of our processes. They can be used to implement fancy algorithms, such as the one used in [Baiduās DeepSpeech](https://github.com/baidu-research/baidu-allreduce) or [Facebookās large-scale experiments](https://research.fb.com/publications/imagenet1kin1h/).(c.f. [Section 4.1](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce))
## Collective Communication[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#collective-communication "Link to this heading")
As opposed to point-to-point communcation, collectives allow for communication patterns across all processes in a **group**. A group is a subset of all our processes. To create a group, we can pass a list of ranks to `dist.new_group(group)`. By default, collectives are executed on all processes, also known as the **world**. For example, in order to obtain the sum of all tensors on all processes, we can use the `dist.all_reduce(tensor, op, group)` collective.
```
""" All-Reduce example."""
def run(rank, size):
""" Simple collective communication. """
group = dist.new_group([0, 1])
tensor = torch.ones(1)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
print('Rank ', rank, ' has data ', tensor[0])
```
Since we want the sum of all tensors in the group, we use `dist.ReduceOp.SUM` as the reduce operator. Generally speaking, any commutative mathematical operation can be used as an operator. Out-of-the-box, PyTorch comes with many such operators, all working at the element-wise level:
- `dist.ReduceOp.SUM`,
- `dist.ReduceOp.PRODUCT`,
- `dist.ReduceOp.MAX`,
- `dist.ReduceOp.MIN`,
- `dist.ReduceOp.BAND`,
- `dist.ReduceOp.BOR`,
- `dist.ReduceOp.BXOR`,
- `dist.ReduceOp.PREMUL_SUM`.
The full list of supported operators is [here](https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp).
In addition to `dist.all_reduce(tensor, op, group)`, there are many additional collectives currently implemented in PyTorch. Here are a few supported collectives.
- `dist.broadcast(tensor, src, group)`: Copies `tensor` from `src` to all other processes.
- `dist.reduce(tensor, dst, op, group)`: Applies `op` to every `tensor` and stores the result in `dst`.
- `dist.all_reduce(tensor, op, group)`: Same as reduce, but the result is stored in all processes.
- `dist.scatter(tensor, scatter_list, src, group)`: Copies the i th i^{\\text{th}} tensor `scatter_list[i]` to the i th i^{\\text{th}} process.
- `dist.gather(tensor, gather_list, dst, group)`: Copies `tensor` from all processes in `dst`.
- `dist.all_gather(tensor_list, tensor, group)`: Copies `tensor` from all processes to `tensor_list`, on all processes.
- `dist.barrier(group)`: Blocks all processes in group until each one has entered this function.
- `dist.all_to_all(output_tensor_list, input_tensor_list, group)`: Scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.
The full list of supported collectives can be found by looking at the latest documentation for PyTorch Distributed [(link)](https://pytorch.org/docs/stable/distributed.html).
## Distributed Training[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#distributed-training "Link to this heading")
**Note:** You can find the example script of this section in [this GitHub repository](https://github.com/seba-1511/dist_tuto.pth/).
Now that we understand how the distributed module works, let us write something useful with it. Our goal will be to replicate the functionality of [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel). Of course, this will be a didactic example and in a real-world situation you should use the official, well-tested and well-optimized version linked above.
Quite simply we want to implement a distributed version of stochastic gradient descent. Our script will let all processes compute the gradients of their model on their batch of data and then average their gradients. In order to ensure similar convergence results when changing the number of processes, we will first have to partition our dataset. (You could also use [torch.utils.data.random\_split](https://pytorch.org/docs/stable/data.html#torch.utils.data.random_split), instead of the snippet below.)
```
""" Dataset partitioning helper """
class Partition(object):
def __init__(self, data, index):
self.data = data
self.index = index
def __len__(self):
return len(self.index)
def __getitem__(self, index):
data_idx = self.index[index]
return self.data[data_idx]
class DataPartitioner(object):
def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
self.data = data
self.partitions = []
rng = Random() # from random import Random
rng.seed(seed)
data_len = len(data)
indexes = [x for x in range(0, data_len)]
rng.shuffle(indexes)
for frac in sizes:
part_len = int(frac * data_len)
self.partitions.append(indexes[0:part_len])
indexes = indexes[part_len:]
def use(self, partition):
return Partition(self.data, self.partitions[partition])
```
With the above snippet, we can now simply partition any dataset using the following few lines:
```
""" Partitioning MNIST """
def partition_dataset():
dataset = datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
size = dist.get_world_size()
bsz = 128 // size
partition_sizes = [1.0 / size for _ in range(size)]
partition = DataPartitioner(dataset, partition_sizes)
partition = partition.use(dist.get_rank())
train_set = torch.utils.data.DataLoader(partition,
batch_size=bsz,
shuffle=True)
return train_set, bsz
```
Assuming we have 2 replicas, then each process will have a `train_set` of 60000 / 2 = 30000 samples. We also divide the batch size by the number of replicas in order to maintain the *overall* batch size of 128.
We can now write our usual forward-backward-optimize training code, and add a function call to average the gradients of our models. (The following is largely inspired by the official [PyTorch MNIST example](https://github.com/pytorch/examples/blob/master/mnist/main.py).)
```
""" Distributed Synchronous SGD Example """
def run(rank, size):
torch.manual_seed(1234)
train_set, bsz = partition_dataset()
model = Net()
optimizer = optim.SGD(model.parameters(),
lr=0.01, momentum=0.5)
num_batches = ceil(len(train_set.dataset) / float(bsz))
for epoch in range(10):
epoch_loss = 0.0
for data, target in train_set:
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
epoch_loss += loss.item()
loss.backward()
average_gradients(model)
optimizer.step()
print('Rank ', dist.get_rank(), ', epoch ',
epoch, ': ', epoch_loss / num_batches)
```
It remains to implement the `average_gradients(model)` function, which simply takes in a model and averages its gradients across the whole world.
```
""" Gradient averaging. """
def average_gradients(model):
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
```
*Et voilĆ *! We successfully implemented distributed synchronous SGD and could train any model on a large computer cluster.
**Note:** While the last sentence is *technically* true, there are [a lot more tricks](https://seba-1511.github.io/dist_blog) required to implement a production-level implementation of synchronous SGD. Again, use what [has been tested and optimized](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
### Our Own Ring-Allreduce[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#our-own-ring-allreduce "Link to this heading")
As an additional challenge, imagine that we wanted to implement DeepSpeechās efficient ring allreduce. This is fairly easy to implement using point-to-point collectives.
```
""" Implementation of a ring-reduce with addition. """
def allreduce(send, recv):
rank = dist.get_rank()
size = dist.get_world_size()
send_buff = send.clone()
recv_buff = send.clone()
accum = send.clone()
left = ((rank - 1) + size) % size
right = (rank + 1) % size
for i in range(size - 1):
if i % 2 == 0:
# Send send_buff
send_req = dist.isend(send_buff, right)
dist.recv(recv_buff, left)
accum[:] += recv_buff[:]
else:
# Send recv_buff
send_req = dist.isend(recv_buff, right)
dist.recv(send_buff, left)
accum[:] += send_buff[:]
send_req.wait()
recv[:] = accum[:]
```
In the above script, the `allreduce(send, recv)` function has a slightly different signature than the ones in PyTorch. It takes a `recv` tensor and will store the sum of all `send` tensors in it. As an exercise left to the reader, there is still one difference between our version and the one in DeepSpeech: their implementation divides the gradient tensor into *chunks*, so as to optimally utilize the communication bandwidth. (Hint: [torch.chunk](https://pytorch.org/docs/stable/torch.html#torch.chunk))
## Advanced Topics[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#advanced-topics "Link to this heading")
We are now ready to discover some of the more advanced functionalities of `torch.distributed`. Since there is a lot to cover, this section is divided into two subsections:
1. Communication Backends: where we learn how to use MPI and Gloo for GPU-GPU communication.
2. Initialization Methods: where we understand how to best set up the initial coordination phase in `dist.init_process_group()`.
### Communication Backends[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#communication-backends "Link to this heading")
One of the most elegant aspects of `torch.distributed` is its ability to abstract and build on top of different backends. As mentioned before, there are multiple backends implemented in PyTorch. These backends can be easily selected using the [Accelerator API](https://pytorch.org/docs/stable/torch.html#accelerators), which provides an interface for working with different accelerator types. Some of the most popular backends are Gloo, NCCL, and MPI. They each have different specifications and tradeoffs, depending on the desired use case. A comparative table of supported functions can be found [here](https://pytorch.org/docs/stable/distributed.html#module-torch.distributed).
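Which backends are actually usable depends on how your PyTorch binaries were compiled. As a quick sketch, `torch.distributed` exposes availability checks you can query before calling `dist.init_process_group`:

```python
import torch.distributed as dist

# Query which backends this PyTorch build was compiled with;
# init_process_group will fail if asked for a backend absent from the build.
available = {
    "gloo": dist.is_gloo_available(),
    "nccl": dist.is_nccl_available(),
    "mpi": dist.is_mpi_available(),
}

# Pick the first compiled-in backend (None if, unusually, there is none).
backend = next((name for name, ok in available.items() if ok), None)
print(available, "->", backend)
```

On the standard pre-built binaries you should at least see Gloo reported as available; MPI typically requires the from-source build described below.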
**Gloo Backend**
So far we have made extensive usage of the [Gloo backend](https://github.com/facebookincubator/gloo). It is quite handy as a development platform, as it is included in the pre-compiled PyTorch binaries and works on both Linux (since 0.2) and macOS (since 1.3). It supports all point-to-point and collective operations on CPU, and all collective operations on GPU. The implementation of the collective operations for CUDA tensors is not as optimized as those provided by the NCCL backend.
As you have surely noticed, our distributed SGD example does not work if you put `model` on the GPU. In order to use multiple GPUs, let us also make the following modifications:
1. Use the Accelerator API: `device_type = torch.accelerator.current_accelerator()`
2. Use `device = torch.device(f"{device_type}:{rank}")`
3. `model = Net()` → `model = Net().to(device)`
4. Use `data, target = data.to(device), target.to(device)`
With these modifications, your model will now train across two GPUs. You can monitor GPU utilization using `watch nvidia-smi` if you are running on NVIDIA hardware.
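Taken together, the four modifications amount to something like the following sketch. Here `nn.Linear` stands in for the tutorial's `Net`, the rank is hard-coded to 0, and a `hasattr` guard falls back to CPU so the snippet also runs on machines without an accelerator (`torch.accelerator` only exists in recent PyTorch releases):

```python
import torch
import torch.nn as nn

# Stand-in for the tutorial's Net; any nn.Module behaves the same way here.
model = nn.Linear(4, 2)

# Step 1: pick the accelerator if this build has one, otherwise use CPU.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    device_type = torch.accelerator.current_accelerator().type
else:
    device_type = "cpu"

# Step 2: one device per rank; a real train loop would use dist.get_rank().
rank = 0
device = torch.device("cpu" if device_type == "cpu" else f"{device_type}:{rank}")

# Steps 3 and 4: move the model and each batch to the chosen device.
model = model.to(device)
data = torch.randn(8, 4).to(device)
output = model(data)
```

The important invariant is that the model and every batch live on the same device; mixing devices raises a runtime error in the forward pass.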
**MPI Backend**
The Message Passing Interface (MPI) is a standardized tool from the field of high-performance computing. It supports point-to-point and collective communication and was the main inspiration for the API of `torch.distributed`. Several implementations of MPI exist (e.g. [Open-MPI](https://www.open-mpi.org/), [MVAPICH2](http://mvapich.cse.ohio-state.edu/), [Intel MPI](https://software.intel.com/en-us/intel-mpi-library)), each optimized for different purposes. The advantage of using the MPI backend lies in MPI's wide availability - and high level of optimization - on large computer clusters. [Some](https://developer.nvidia.com/mvapich) [recent](https://developer.nvidia.com/ibm-spectrum-mpi) [implementations](https://www.open-mpi.org/) are also able to take advantage of CUDA IPC and GPU Direct technologies to avoid memory copies through the CPU.
Unfortunately, PyTorch's binaries cannot include an MPI implementation, so we'll have to recompile PyTorch by hand. Fortunately, this process is fairly simple: upon compilation, PyTorch will look *by itself* for an available MPI implementation. The following steps install the MPI backend by building PyTorch [from source](https://github.com/pytorch/pytorch#from-source).
1. Create and activate your Anaconda environment, install all the pre-requisites following [the guide](https://github.com/pytorch/pytorch#from-source), but do **not** run `python setup.py install` yet.
2. Choose and install your favorite MPI implementation. Note that enabling CUDA-aware MPI might require some additional steps. In our case, weāll stick to Open-MPI *without* GPU support: `conda install -c conda-forge openmpi`
3. Now, go to your cloned PyTorch repo and execute `python setup.py install`.
In order to test our newly installed backend, a few modifications are required.
1. Replace the content under `if __name__ == '__main__':` with `init_process(0, 0, run, backend='mpi')`.
2. Run `mpirun -n 4 python myscript.py`.
The reason for these changes is that MPI needs to create its own environment before spawning the processes. MPI will also spawn its own processes and perform the handshake described in [Initialization Methods](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods), making the `rank` and `size` arguments of `init_process_group` superfluous. This is actually quite powerful, as you can pass additional arguments to `mpirun` to tailor computational resources for each process (things like the number of cores per process, hand-assigning machines to specific ranks, and [some more](https://www.open-mpi.org/faq/?category=running#mpirun-hostfile)). Doing so, you should obtain the same familiar output as with the other communication backends.
**NCCL Backend**
The [NCCL backend](https://github.com/nvidia/nccl) provides an optimized implementation of collective operations against CUDA tensors. If you only use CUDA tensors for your collective operations, consider using this backend for the best in class performance. The NCCL backend is included in the pre-built binaries with CUDA support.
**XCCL Backend**
The XCCL backend offers an optimized implementation of collective operations for XPU tensors. If your workload uses only XPU tensors for collective operations, this backend provides best-in-class performance. The XCCL backend is included in the pre-built binaries with XPU support.
### Initialization Methods[\#](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html#initialization-methods "Link to this heading")
To conclude this tutorial, let's examine the initial function we invoked: `dist.init_process_group(backend, init_method)`. Specifically, we will discuss the various initialization methods responsible for the preliminary coordination step between each process. These methods enable you to define how this coordination is accomplished.
The choice of initialization method depends on your hardware setup, and one method may be more suitable than others. In addition to the following sections, please refer to the [official documentation](https://pytorch.org/docs/stable/distributed.html#initialization) for further information.
**Environment Variable**
We have been using the environment variable initialization method throughout this tutorial. By setting the following four environment variables on all machines, all processes will be able to properly connect to the master, obtain information about the other processes, and finally handshake with them.
- `MASTER_PORT`: A free port on the machine that will host the process with rank 0.
- `MASTER_ADDR`: IP address of the machine that will host the process with rank 0.
- `WORLD_SIZE`: The total number of processes, so that the master knows how many workers to wait for.
- `RANK`: Rank of each process, so that each process knows whether it is the master or a worker.
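The four variables above can be sanity-checked on a single machine by running one process with a world size of 1, in which case the handshake completes immediately. The address and port below are placeholders; in a real job, every machine would point `MASTER_ADDR`/`MASTER_PORT` at the rank-0 host:

```python
import os
import torch.distributed as dist

# Placeholder values for a single-machine, single-process sanity check.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"

# With a world size of 1 there is nobody to wait for, so this returns at once.
dist.init_process_group(backend="gloo", init_method="env://")
rank, world_size = dist.get_rank(), dist.get_world_size()
dist.destroy_process_group()
```

With more than one process, every process runs the same code with its own `RANK`, and `init_process_group` blocks until all `WORLD_SIZE` processes have connected to the master.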
**Shared File System**
The shared file system method requires all processes to have access to a shared file system, and coordinates them through a shared file. Each process opens the file, writes its information, and waits until every process has done so. After that, all required information is readily available to all processes. To avoid race conditions, the file system must support locking through [fcntl](http://man7.org/linux/man-pages/man2/fcntl.2.html).
```
dist.init_process_group(
init_method='file:///mnt/nfs/sharedfile',
rank=args.rank,
world_size=4)
```
**TCP**
Initializing via TCP can be achieved by providing the IP address of the process with rank 0 and a reachable port number. Here, all workers will be able to connect to the process with rank 0 and exchange information on how to reach each other.
```
dist.init_process_group(
init_method='tcp://10.1.1.20:23456',
rank=args.rank,
world_size=4)
```
**Acknowledgements**
I'd like to thank the PyTorch developers for doing such a good job on their implementation, documentation, and tests. When the code was unclear, I could always count on the [docs](https://pytorch.org/docs/stable/distributed.html) or the [tests](https://github.com/pytorch/pytorch/tree/master/test/distributed) to find an answer. In particular, I'd like to thank Soumith Chintala, Adam Paszke, and Natalia Gimelshein for providing insightful comments and answering questions on early drafts.