ℹ️ Skipped - page is already crawled
| Filter | Status | Condition | Details |
|---|---|---|---|
| HTTP status | PASS | download_http_code = 200 | HTTP 200 |
| Age cutoff | PASS | download_stamp > now() - 6 MONTH | 0.1 months ago |
| History drop | PASS | isNull(history_drop_reason) | No drop reason |
| Spam/ban | PASS | fh_dont_index != 1 AND ml_spam_score = 0 | ml_spam_score=0 |
| Canonical | PASS | meta_canonical IS NULL OR = '' OR = src_unparsed | Not set |
| Property | Value |
|---|---|
| URL | https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html |
| Last Crawled | 2026-04-15 07:01:18 (2 days ago) |
| First Indexed | 2025-06-30 23:25:20 (9 months ago) |
| HTTP Status Code | 200 |
| Meta Title | Getting Started with Distributed RPC Framework — PyTorch Tutorials 2.11.0+cu130 documentation |
| Meta Description | null |
| Meta Canonical | null |
| Boilerpipe Text | (extracted text below) |

# Getting Started with Distributed RPC Framework

Created On: Jan 01, 2020 | Last Updated: Sep 03, 2025 | Last Verified: Nov 05, 2024

**Author**: [Shen Li](https://mrshenli.github.io/)

Note: View and edit this tutorial in [github](https://github.com/pytorch/tutorials/blob/main/intermediate_source/rpc_tutorial.rst).

Prerequisites:

- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [RPC API documents](https://pytorch.org/docs/master/rpc.html)
This tutorial uses two simple examples to demonstrate how to build distributed training with the [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) package, which was first introduced as an experimental feature in PyTorch v1.4. Source code of the two examples can be found in [PyTorch examples](https://github.com/pytorch/examples).

Previous tutorials, [Getting Started With Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [Writing Distributed Applications With PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html), described [DistributedDataParallel](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), which supports a specific training paradigm where the model is replicated across multiple processes and each process handles a split of the input data. Sometimes, you might run into scenarios that require different training paradigms. For example:

1. In reinforcement learning, it might be relatively expensive to acquire training data from environments, while the model itself can be quite small. In this case, it might be useful to spawn multiple observers running in parallel and share a single agent. The agent takes care of the training locally, but the application still needs libraries to send and receive data between observers and the trainer.
2. Your model might be too large to fit on the GPUs of a single machine, and hence would need a library to help split the model onto multiple machines. Or you might be implementing a [parameter server](https://www.cs.cmu.edu/~muli/file/parameter_server_osdi14.pdf) training framework, where model parameters and trainers live on different machines.
The [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) package can help with the above scenarios. In case 1, [RPC](https://pytorch.org/docs/stable/rpc.html#rpc) and [RRef](https://pytorch.org/docs/stable/rpc.html#rref) allow sending data from one worker to another while easily referencing remote data objects. In case 2, [distributed autograd](https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework) and [distributed optimizer](https://pytorch.org/docs/stable/rpc.html#module-torch.distributed.optim) make executing the backward pass and optimizer step feel like local training. In the next two sections, we will demonstrate the APIs of [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) using a reinforcement learning example and a language model example. Please note that this tutorial does not aim to build the most accurate or efficient models for the given problems; the main goal is to show how to use the `torch.distributed.rpc` package to build distributed training applications.
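Before diving into the full examples, the sketch below illustrates the three core primitives the examples build on. It is a minimal, hypothetical fragment: the worker name `"worker1"` and the toy `add` function are assumptions for illustration, and `rpc.init_rpc` must already have been called on the participating workers.

```python
# Minimal sketch of the core torch.distributed.rpc primitives. Assumes
# rpc.init_rpc() has already run on the workers; the name "worker1" and
# the add() function are illustrative, not part of the tutorial's examples.
import torch
import torch.distributed.rpc as rpc

def add(a, b):
    return a + b

# rpc_sync: blocking call, runs add() on worker1 and returns the result
ret = rpc.rpc_sync("worker1", add, args=(torch.ones(2), torch.ones(2)))

# rpc_async: non-blocking call, returns a future
fut = rpc.rpc_async("worker1", add, args=(torch.ones(2), 1))
ret = fut.wait()

# remote: the return value stays on worker1; the caller holds an RRef to it
rref = rpc.remote("worker1", add, args=(torch.ones(2), 3))
val = rref.to_here()  # fetch a copy of the remote value when needed
```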
## Distributed Reinforcement Learning using RPC and RRef
This section describes the steps to build a toy distributed reinforcement learning model using RPC to solve CartPole-v1 from [OpenAI Gym](https://www.gymlibrary.dev/environments/classic_control/cart_pole/). The policy code is mostly borrowed from the existing single-thread [example](https://github.com/pytorch/examples/blob/master/reinforcement_learning) as shown below. We will skip details of the `Policy` design and focus on RPC usages.
```python
import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.dropout = nn.Dropout(p=0.6)
        self.affine2 = nn.Linear(128, 2)

    def forward(self, x):
        x = self.affine1(x)
        x = self.dropout(x)
        x = F.relu(x)
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)
```
We are ready to present the observer. In this example, each observer creates its own environment and waits for the agent's command to run an episode. In each episode, one observer loops for at most `n_steps` iterations, and in each iteration it uses RPC to pass its environment state to the agent and gets an action back. Then it applies that action to its environment, and gets the reward and the next state from the environment. After that, the observer uses another RPC to report the reward to the agent. Again, please note that this is obviously not the most efficient observer implementation. For example, one simple optimization could be packing the current state and the last reward into one RPC to reduce the communication overhead. However, the goal is to demonstrate the RPC API instead of building the best solver for CartPole, so let's keep the logic simple and the two steps explicit in this example.
```python
import argparse

import gym
import torch.distributed.rpc as rpc

parser = argparse.ArgumentParser(
    description="RPC Reinforcement Learning Example",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=2, type=int, metavar='W',
                    help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
                    help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed for reproducibility')
args = parser.parse_args()


class Observer:
    def __init__(self):
        self.id = rpc.get_worker_info().id
        self.env = gym.make('CartPole-v1')
        self.env.seed(args.seed)

    def run_episode(self, agent_rref):
        state, ep_reward = self.env.reset(), 0
        for _ in range(10000):
            # send the state to the agent to get an action
            action = agent_rref.rpc_sync().select_action(self.id, state)

            # apply the action to the environment, and get the reward
            state, reward, done, _ = self.env.step(action)

            # report the reward to the agent for training purpose
            agent_rref.rpc_sync().report_reward(self.id, reward)

            # finishes after the number of self.env._max_episode_steps
            if done:
                break
```
The code for the agent is a little more complex, and we will break it into multiple pieces. In this example, the agent serves as both the trainer and the master: it sends commands to multiple distributed observers to run episodes, and it also records all actions and rewards locally, which will be used during the training phase after each episode. The code below shows the `Agent` constructor, where most lines are initializing various components. The loop at the end initializes observers remotely on other workers and holds `RRefs` to those observers locally. The agent will use those observer `RRefs` later to send commands. Applications don't need to worry about the lifetime of `RRefs`. The owner of each `RRef` maintains a reference counting map to track its lifetime, and guarantees that the remote data object will not be deleted as long as there is any live user of that `RRef`. Please refer to the `RRef` design doc for details.
```python
import gym
import numpy as np
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical


class Agent:
    def __init__(self, world_size):
        self.ob_rrefs = []
        self.agent_rref = RRef(self)
        self.rewards = {}
        self.saved_log_probs = {}
        self.policy = Policy()
        self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
        self.eps = np.finfo(np.float32).eps.item()
        self.running_reward = 0
        self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold
        for ob_rank in range(1, world_size):
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
            self.ob_rrefs.append(remote(ob_info, Observer))
            self.rewards[ob_info.id] = []
            self.saved_log_probs[ob_info.id] = []
```
Next, the agent exposes two APIs to observers for selecting actions and reporting rewards. Those functions only run locally on the agent, but will be triggered by observers through RPC.
```python
class Agent:
    ...
    def select_action(self, ob_id, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.policy(state)
        m = Categorical(probs)
        action = m.sample()
        self.saved_log_probs[ob_id].append(m.log_prob(action))
        return action.item()

    def report_reward(self, ob_id, reward):
        self.rewards[ob_id].append(reward)
```
Let's add a `run_episode` function on the agent, which tells all observers to execute an episode. In this function, it first creates a list to collect futures from asynchronous RPCs, and then loops over all observer `RRefs` to make asynchronous RPCs. In these RPCs, the agent also passes an `RRef` of itself to the observer, so that the observer can call functions on the agent as well. As shown above, each observer will make RPCs back to the agent, which are nested RPCs. After each episode, `saved_log_probs` and `rewards` will contain the recorded action probs and rewards.
```python
class Agent:
    ...
    def run_episode(self):
        futs = []
        for ob_rref in self.ob_rrefs:
            # make async RPC to kick off an episode on all observers
            futs.append(
                rpc_async(
                    ob_rref.owner(),
                    ob_rref.rpc_sync().run_episode,
                    args=(self.agent_rref,)
                )
            )

        # wait until all observers have finished this episode
        for fut in futs:
            fut.wait()
```
Finally, after one episode, the agent needs to train the model, which is implemented in the `finish_episode` function below. There are no RPCs in this function, and it is mostly borrowed from the single-thread [example](https://github.com/pytorch/examples/blob/master/reinforcement_learning). Hence, we skip describing its contents.
```python
class Agent:
    ...
    def finish_episode(self):
        # joins probs and rewards from different observers into lists
        R, probs, rewards = 0, [], []
        for ob_id in self.rewards:
            probs.extend(self.saved_log_probs[ob_id])
            rewards.extend(self.rewards[ob_id])

        # use the minimum observer reward to calculate the running reward
        min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
        self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward

        # clear saved probs and rewards
        for ob_id in self.rewards:
            self.rewards[ob_id] = []
            self.saved_log_probs[ob_id] = []

        policy_loss, returns = [], []
        for r in rewards[::-1]:
            R = r + args.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + self.eps)
        for log_prob, R in zip(probs, returns):
            policy_loss.append(-log_prob * R)
        self.optimizer.zero_grad()
        policy_loss = torch.cat(policy_loss).sum()
        policy_loss.backward()
        self.optimizer.step()
        return min_reward
```
With the `Policy`, `Observer`, and `Agent` classes, we are ready to launch multiple processes to perform the distributed training. In this example, all processes run the same `run_worker` function, and they use the rank to distinguish their role. Rank 0 is always the agent, and all other ranks are observers. The agent serves as the master by repeatedly calling `run_episode` and `finish_episode` until the running reward surpasses the reward threshold specified by the environment. All observers passively wait for commands from the agent. The code is wrapped by `rpc.init_rpc` and `rpc.shutdown`, which initialize and terminate RPC instances respectively. More details are available in the [API page](https://pytorch.org/docs/stable/rpc.html).
```python
import os
from itertools import count

import torch.multiprocessing as mp

AGENT_NAME = "agent"
OBSERVER_NAME = "obs{}"


def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 0:
        # rank0 is the agent
        rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)

        agent = Agent(world_size)
        print(
            f"This will run until reward threshold of {agent.reward_threshold}"
            " is reached. Ctrl+C to exit."
        )
        for i_episode in count(1):
            agent.run_episode()
            last_reward = agent.finish_episode()

            if i_episode % args.log_interval == 0:
                print(
                    f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
                    f"{agent.running_reward:.2f}"
                )
            if agent.running_reward > agent.reward_threshold:
                print(f"Solved! Running reward is now {agent.running_reward}!")
                break
    else:
        # other ranks are the observer
        rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
        # observers passively waiting for instructions from the agent

    # block until all rpcs finish, and shutdown the RPC instance
    rpc.shutdown()


mp.spawn(
    run_worker,
    args=(args.world_size, ),
    nprocs=args.world_size,
    join=True
)
```
Below are some sample outputs when training with `world_size=2`.

```
This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
Episode 10      Last reward: 26.00      Average reward: 10.01
Episode 20      Last reward: 16.00      Average reward: 11.27
Episode 30      Last reward: 49.00      Average reward: 18.62
Episode 40      Last reward: 45.00      Average reward: 26.09
Episode 50      Last reward: 44.00      Average reward: 30.03
Episode 60      Last reward: 111.00     Average reward: 42.23
Episode 70      Last reward: 131.00     Average reward: 70.11
Episode 80      Last reward: 87.00      Average reward: 76.51
Episode 90      Last reward: 86.00      Average reward: 95.93
Episode 100     Last reward: 13.00      Average reward: 123.93
Episode 110     Last reward: 33.00      Average reward: 91.39
Episode 120     Last reward: 73.00      Average reward: 76.38
Episode 130     Last reward: 137.00     Average reward: 88.08
Episode 140     Last reward: 89.00      Average reward: 104.96
Episode 150     Last reward: 97.00      Average reward: 98.74
Episode 160     Last reward: 150.00     Average reward: 100.87
Episode 170     Last reward: 126.00     Average reward: 104.38
Episode 180     Last reward: 500.00     Average reward: 213.74
Episode 190     Last reward: 322.00     Average reward: 300.22
Episode 200     Last reward: 165.00     Average reward: 272.71
Episode 210     Last reward: 168.00     Average reward: 233.11
Episode 220     Last reward: 184.00     Average reward: 195.02
Episode 230     Last reward: 284.00     Average reward: 208.32
Episode 240     Last reward: 395.00     Average reward: 247.37
Episode 250     Last reward: 500.00     Average reward: 335.42
Episode 260     Last reward: 500.00     Average reward: 386.30
Episode 270     Last reward: 500.00     Average reward: 405.29
Episode 280     Last reward: 500.00     Average reward: 443.29
Episode 290     Last reward: 500.00     Average reward: 464.65
Solved! Running reward is now 475.3163778435275!
```
In this example, we show how to use RPC as the communication vehicle to pass data across workers, and how to use RRef to reference remote objects. It is true that you could build the entire structure directly on top of `ProcessGroup` `send` and `recv` APIs or use other communication/RPC libraries. However, by using `torch.distributed.rpc`, you get native support and continuously optimized performance under the hood.
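To make that contrast concrete, the fragment below (hypothetical, not from the tutorial) sketches what one observer-to-agent exchange might look like on raw send/recv: only tensors of pre-agreed shapes can move, so both sides have to hand-roll the dispatch that RPC provides.

```python
# Hypothetical contrast (not from the tutorial): one state -> action exchange
# built on raw ProcessGroup send/recv. Assumes init_process_group has been
# called and that both sides hard-code tensor shapes and message order.
import torch
import torch.distributed as dist

def observer_step(state: torch.Tensor) -> int:
    dist.send(state, dst=0)                    # ship the state to the agent (rank 0)
    action = torch.zeros(1, dtype=torch.long)  # pre-agreed reply shape
    dist.recv(action, src=0)                   # block until the agent replies
    return action.item()
```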
Next, we will show how to combine RPC and RRef with distributed autograd and distributed optimizer to perform distributed model parallel training.
## Distributed RNN using Distributed Autograd and Distributed Optimizer
In this section, we use an RNN model to show how to build distributed model parallel training with the RPC API. The example RNN model is very small and can easily fit into a single GPU, but we still divide its layers onto two different workers to demonstrate the idea. Developers can apply similar techniques to distribute much larger models across multiple devices and machines.
The RNN model design is borrowed from the word language model in the PyTorch example repository, which contains three main components: an embedding table, an `LSTM` layer, and a decoder. The code below wraps the embedding table and the decoder into sub-modules, so that their constructors can be passed to the RPC API. In the `EmbeddingTable` sub-module, we intentionally put the `Embedding` layer on GPU to cover the use case. In v1.4, RPC always creates CPU tensor arguments or return values on the destination worker. If the function takes a GPU tensor, you need to move it to the proper device explicitly.
```python
class EmbeddingTable(nn.Module):
    r"""
    Encoding layers of the RNNModel
    """
    def __init__(self, ntoken, ninp, dropout):
        super(EmbeddingTable, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp).cuda()
        self.encoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        return self.drop(self.encoder(input.cuda()).cpu())


class Decoder(nn.Module):
    def __init__(self, ntoken, nhid, dropout):
        super(Decoder, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-0.1, 0.1)

    def forward(self, output):
        return self.decoder(self.drop(output))
```
With the above sub-modules, we can now piece them together using RPC to create an RNN model. In the code below, `ps` represents a parameter server, which hosts parameters of the embedding table and the decoder. The constructor uses the `remote` API to create an `EmbeddingTable` object and a `Decoder` object on the parameter server, and locally creates the `LSTM` sub-module. During the forward pass, the trainer uses the `EmbeddingTable` `RRef` to find the remote sub-module, passes the input data to the `EmbeddingTable` using RPC, and fetches the lookup results. Then, it runs the embedding through the local `LSTM` layer, and finally uses another RPC to send the output to the `Decoder` sub-module. In general, to implement distributed model parallel training, developers can divide the model into sub-modules, invoke RPC to create sub-module instances remotely, and use `RRefs` to find them when necessary. As you can see in the code below, it looks very similar to single-machine model parallel training. The main difference is replacing `Tensor.to(device)` with RPC functions.
```python
class RNNModel(nn.Module):
    def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()

        # setup embedding table remotely
        self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
        # setup LSTM locally
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        # setup decoder remotely
        self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))

    def forward(self, input, hidden):
        # pass input to the remote embedding table and fetch emb tensor back
        emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
        output, hidden = self.rnn(emb, hidden)
        # pass output to the remote decoder and get the decoded output back
        decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
        return decoded, hidden
```
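One detail worth noting: the `_remote_method` helper used in `forward` above is not defined in this excerpt. A minimal sketch consistent with how it is called here (invoke a method on the object behind an `RRef`, on the `RRef`'s owner, and block for the result) might look like the following; treat this implementation as an assumption rather than the tutorial's exact code.

```python
# Hypothetical sketch of the _remote_method helper used above; the exact
# implementation in the full tutorial source may differ.
def _call_method(method, rref, *args, **kwargs):
    # runs on the RRef's owner: grab the local object and invoke the method
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    # synchronous RPC to the RRef's owner, returning the method's result
    args = [method, rref] + list(args)
    return rpc.rpc_sync(rref.owner(), _call_method, args=args, kwargs=kwargs)
```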
Before introducing the distributed optimizer, let's add a helper function to generate a list of RRefs of model parameters, which will be consumed by the distributed optimizer. In local training, applications could call `Module.parameters()` to grab references to all parameter tensors and pass them to the local optimizer for subsequent updates. However, the same API does not work in distributed training scenarios, as some parameters live on remote machines. Therefore, instead of taking a list of parameter `Tensors`, the distributed optimizer takes a list of `RRefs`, one `RRef` per model parameter, covering both local and remote model parameters. The helper function is pretty simple: it just calls `Module.parameters()` and creates a local `RRef` on each of the parameters.
```python
def _parameter_rrefs(module):
    param_rrefs = []
    for param in module.parameters():
        param_rrefs.append(RRef(param))
    return param_rrefs
```
Then, as the `RNNModel` contains three sub-modules, we need to call `_parameter_rrefs` three times, and wrap that into another helper function.
```python
class RNNModel(nn.Module):
    ...
    def parameter_rrefs(self):
        remote_params = []
        # get RRefs of embedding table
        remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
        # create RRefs for local parameters
        remote_params.extend(_parameter_rrefs(self.rnn))
        # get RRefs of decoder
        remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
        return remote_params
```
Now, we are ready to implement the training loop. After initializing model arguments, we create the `RNNModel` and the `DistributedOptimizer`. The distributed optimizer will take a list of parameter `RRefs`, find all distinct owner workers, and create the given local optimizer (i.e., `SGD` in this case; you can use other local optimizers as well) on each of the owner workers using the given arguments (i.e., `lr=0.05`).

In the training loop, it first creates a distributed autograd context, which will help the distributed autograd engine to find gradients and involved RPC send/recv functions. The design details of the distributed autograd engine can be found in its design note. Then, it kicks off the forward pass as if it were a local model, and runs the distributed backward pass. For the distributed backward pass, you only need to specify a list of roots; in this case, it is the loss `Tensor`. The distributed autograd engine will traverse the distributed graph automatically and write gradients properly. Next, it runs the `step` function on the distributed optimizer, which will reach out to all involved local optimizers to update model parameters. Compared to local training, one minor difference is that you don't need to run `zero_grad()`, because each autograd context has dedicated space to store gradients, and as we create a context per iteration, gradients from different iterations will not accumulate into the same set of `Tensors`.
```python
def _run_trainer():
    batch = 5
    ntoken = 10
    ninp = 2

    nhid = 3
    nindices = 3
    nlayers = 4
    hidden = (
        torch.randn(nlayers, nindices, nhid),
        torch.randn(nlayers, nindices, nhid)
    )

    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)

    # setup distributed optimizer
    opt = DistributedOptimizer(
        optim.SGD,
        model.parameter_rrefs(),
        lr=0.05,
    )

    criterion = torch.nn.CrossEntropyLoss()

    def get_next_batch():
        for _ in range(5):
            data = torch.LongTensor(batch, nindices) % ntoken
            target = torch.LongTensor(batch, ntoken) % nindices
            yield data, target

    # train for 10 iterations
    for epoch in range(10):
        for data, target in get_next_batch():
            # create distributed autograd context
            with dist_autograd.context() as context_id:
                hidden[0].detach_()
                hidden[1].detach_()
                output, hidden = model(data, hidden)
                loss = criterion(output, target)
                # run distributed backward pass
                dist_autograd.backward(context_id, [loss])
                # run distributed optimizer
                opt.step(context_id)
                # not necessary to zero grads since they are
                # accumulated into the distributed autograd context
                # which is reset every iteration.
        print("Training epoch {}".format(epoch))
```
Finally, let's add some glue code to launch the parameter server and the trainer processes.
```python
def run_worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    if rank == 1:
        rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        _run_trainer()
    else:
        rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # parameter server does nothing
        pass

    # block until all rpcs finish
    rpc.shutdown()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
```
| Markdown | 
Help us understand how you use PyTorch! Take our quick survey. [Take Survey](https://docs.google.com/forms/d/e/1FAIpQLSfsGAWBcfutRcbO6kfrShBMOMmRuBezRjjOcXk0e9I9luBzvQ/viewform)
[Skip to main content](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html#main-content)
Back to top
[ ](https://docs.pytorch.org/tutorials/index.html)
[ ](https://docs.pytorch.org/tutorials/index.html)
[v2.11.0+cu130](https://docs.pytorch.org/tutorials/index.html)
- [Intro](https://docs.pytorch.org/tutorials/intro.html)
- [Learn the Basics](https://docs.pytorch.org/tutorials/beginner/basics/intro.html)
- [Introduction to PyTorch - YouTube Series](https://docs.pytorch.org/tutorials/beginner/introyt/introyt_index.html)
- [Deep Learning with PyTorch: A 60 Minute Blitz](https://docs.pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
- [Learning PyTorch with Examples](https://docs.pytorch.org/tutorials/beginner/pytorch_with_examples.html)
- [What is torch.nn really?](https://docs.pytorch.org/tutorials/beginner/nn_tutorial.html)
- [Understanding requires\_grad, retain\_grad, Leaf, and Non-leaf Tensors](https://docs.pytorch.org/tutorials/beginner/understanding_leaf_vs_nonleaf_tutorial.html)
- [NLP from Scratch](https://docs.pytorch.org/tutorials/intermediate/nlp_from_scratch_index.html)
- [Visualizing Models, Data, and Training with TensorBoard](https://docs.pytorch.org/tutorials/intermediate/tensorboard_tutorial.html)
- [A guide on good usage of non\_blocking and pin\_memory() in PyTorch](https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html)
- [Visualizing Gradients](https://docs.pytorch.org/tutorials/intermediate/visualizing_gradients_tutorial.html)
- [Compilers](https://docs.pytorch.org/tutorials/compilers_index.html)
- [Introduction to torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
- [torch.compile End-to-End Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_compile_full_example.html)
- [Compiled Autograd: Capturing a larger backward graph for torch.compile](https://docs.pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html)
- [Inductor CPU backend debugging and profiling](https://docs.pytorch.org/tutorials/intermediate/inductor_debug_cpu.html)
- [Dynamic Compilation Control with torch.compiler.set\_stance](https://docs.pytorch.org/tutorials/recipes/torch_compiler_set_stance_tutorial.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Using Variable Length Attention in PyTorch](https://docs.pytorch.org/tutorials/intermediate/variable_length_attention_tutorial.html)
- [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [torch.export Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html)
- [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [Introduction to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/intro_onnx.html)
- [Export a PyTorch model to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html)
- [Extending the ONNX Exporter Operator Support](https://docs.pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html)
- [Export a model with control flow to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html)
- [Building a Convolution/Batch Norm fuser with torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_conv_bn_fuser.html)
- [(beta) Building a Simple CPU Performance Profiler with FX](https://docs.pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html)
- [Domains](https://docs.pytorch.org/tutorials/domains.html)
- [TorchVision Object Detection Finetuning Tutorial](https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- [Transfer Learning for Computer Vision Tutorial](https://docs.pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)
- [Adversarial Example Generation](https://docs.pytorch.org/tutorials/beginner/fgsm_tutorial.html)
- [DCGAN Tutorial](https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
- [Spatial Transformer Networks Tutorial](https://docs.pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html)
- [Reinforcement Learning (DQN) Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)
- [Reinforcement Learning (PPO) with TorchRL Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_ppo.html)
- [Train a Mario-playing RL Agent](https://docs.pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)
- [Pendulum: Writing your environment and transforms with TorchRL](https://docs.pytorch.org/tutorials/advanced/pendulum.html)
- [Introduction to TorchRec](https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html)
- [Exploring TorchRec sharding](https://docs.pytorch.org/tutorials/advanced/sharding.html)
- [Distributed](https://docs.pytorch.org/tutorials/distributed.html)
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [Distributed Data Parallel in PyTorch - Video Tutorials](https://docs.pytorch.org/tutorials/beginner/ddp_series_intro.html)
- [Getting Started with Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Writing Distributed Applications with PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html)
- [Getting Started with Fully Sharded Data Parallel (FSDP2)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Introduction to Libuv TCPStore Backend](https://docs.pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html)
- [Large Scale Transformer model training with Tensor Parallel (TP)](https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html)
- [Introduction to Distributed Pipeline Parallelism](https://docs.pytorch.org/tutorials/intermediate/pipelining_tutorial.html)
- [Customize Process Group Backends Using Cpp Extensions](https://docs.pytorch.org/tutorials/intermediate/process_group_cpp_extension_tutorial.html)
- [Getting Started with Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html)
- [Implementing a Parameter Server Using Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html)
- [Implementing Batch RPC Processing Using Asynchronous Executions](https://docs.pytorch.org/tutorials/intermediate/rpc_async_execution.html)
- [Interactive Distributed Applications with Monarch](https://docs.pytorch.org/tutorials/intermediate/monarch_distributed_tutorial.html)
- [Combining Distributed DataParallel with Distributed RPC Framework](https://docs.pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html)
- [Distributed Training with Uneven Inputs Using the Join Context Manager](https://docs.pytorch.org/tutorials/advanced/generic_join.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Deep Dive](https://docs.pytorch.org/tutorials/deep-dive.html)
- [Profiling your PyTorch Module](https://docs.pytorch.org/tutorials/beginner/profiler.html)
- [Parametrizations Tutorial](https://docs.pytorch.org/tutorials/intermediate/parametrizations.html)
- [Pruning Tutorial](https://docs.pytorch.org/tutorials/intermediate/pruning_tutorial.html)
- [Inductor CPU backend debugging and profiling](https://docs.pytorch.org/tutorials/intermediate/inductor_debug_cpu.html)
- [(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
- [Knowledge Distillation Tutorial](https://docs.pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html)
- [Channels Last Memory Format in PyTorch](https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
- [Forward-mode Automatic Differentiation (Beta)](https://docs.pytorch.org/tutorials/intermediate/forward_ad_usage.html)
- [Jacobians, Hessians, hvp, vhp, and more: composing function transforms](https://docs.pytorch.org/tutorials/intermediate/jacobians_hessians.html)
- [Model ensembling](https://docs.pytorch.org/tutorials/intermediate/ensembling.html)
- [Per-sample-gradients](https://docs.pytorch.org/tutorials/intermediate/per_sample_grads.html)
- [Using the PyTorch C++ Frontend](https://docs.pytorch.org/tutorials/advanced/cpp_frontend.html)
- [Autograd in C++ Frontend](https://docs.pytorch.org/tutorials/advanced/cpp_autograd.html)
- [Extension](https://docs.pytorch.org/tutorials/extension.html)
- [PyTorch Custom Operators](https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
- [Custom Python Operators](https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html)
- [Custom C++ and CUDA Operators](https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html)
- [Double Backward with Custom Functions](https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html)
- [Fusing Convolution and Batch Norm using Custom Function](https://docs.pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html)
- [Registering a Dispatched Operator in C++](https://docs.pytorch.org/tutorials/advanced/dispatcher.html)
- [Extending dispatcher for a new backend in C++](https://docs.pytorch.org/tutorials/advanced/extend_dispatcher.html)
- [Facilitating New Backend Integration by PrivateUse1](https://docs.pytorch.org/tutorials/advanced/privateuseone.html)
- [Ecosystem](https://docs.pytorch.org/tutorials/ecosystem.html)
- [Hyperparameter tuning using Ray Tune](https://docs.pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html)
- [Serve PyTorch models at scale with Ray Serve](https://docs.pytorch.org/tutorials/beginner/serving_tutorial.html)
- [Multi-Objective NAS with Ax](https://docs.pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html)
- [PyTorch Profiler With TensorBoard](https://docs.pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html)
- [Real Time Inference on Raspberry Pi 4 and 5 (40 fps!)](https://docs.pytorch.org/tutorials/intermediate/realtime_rpi.html)
- [Mosaic: Memory Profiling for PyTorch](https://docs.pytorch.org/tutorials/beginner/mosaic_memory_profiling_tutorial.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- More
- [Recipes](https://docs.pytorch.org/tutorials/recipes_index.html)
- [Unstable](https://docs.pytorch.org/tutorials/unstable_index.html)
[Go to pytorch.org](https://pytorch.org/)
- [X](https://x.com/PyTorch)
- [GitHub](https://github.com/pytorch/tutorials)
- [Discourse](https://dev-discuss.pytorch.org/)
- [PyPi](https://pypi.org/project/torch/)
[v2.11.0+cu130](https://docs.pytorch.org/tutorials/index.html)
- [Intro](https://docs.pytorch.org/tutorials/intro.html)
- [Learn the Basics](https://docs.pytorch.org/tutorials/beginner/basics/intro.html)
- [Introduction to PyTorch - YouTube Series](https://docs.pytorch.org/tutorials/beginner/introyt/introyt_index.html)
- [Deep Learning with PyTorch: A 60 Minute Blitz](https://docs.pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
- [Learning PyTorch with Examples](https://docs.pytorch.org/tutorials/beginner/pytorch_with_examples.html)
- [What is torch.nn really?](https://docs.pytorch.org/tutorials/beginner/nn_tutorial.html)
- [Understanding requires\_grad, retain\_grad, Leaf, and Non-leaf Tensors](https://docs.pytorch.org/tutorials/beginner/understanding_leaf_vs_nonleaf_tutorial.html)
- [NLP from Scratch](https://docs.pytorch.org/tutorials/intermediate/nlp_from_scratch_index.html)
- [Visualizing Models, Data, and Training with TensorBoard](https://docs.pytorch.org/tutorials/intermediate/tensorboard_tutorial.html)
- [A guide on good usage of non\_blocking and pin\_memory() in PyTorch](https://docs.pytorch.org/tutorials/intermediate/pinmem_nonblock.html)
- [Visualizing Gradients](https://docs.pytorch.org/tutorials/intermediate/visualizing_gradients_tutorial.html)
- [Compilers](https://docs.pytorch.org/tutorials/compilers_index.html)
- [Introduction to torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
- [torch.compile End-to-End Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_compile_full_example.html)
- [Compiled Autograd: Capturing a larger backward graph for torch.compile](https://docs.pytorch.org/tutorials/intermediate/compiled_autograd_tutorial.html)
- [Inductor CPU backend debugging and profiling](https://docs.pytorch.org/tutorials/intermediate/inductor_debug_cpu.html)
- [Dynamic Compilation Control with torch.compiler.set\_stance](https://docs.pytorch.org/tutorials/recipes/torch_compiler_set_stance_tutorial.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Using Variable Length Attention in PyTorch](https://docs.pytorch.org/tutorials/intermediate/variable_length_attention_tutorial.html)
- [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [torch.export Tutorial](https://docs.pytorch.org/tutorials/intermediate/torch_export_tutorial.html)
- [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [Introduction to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/intro_onnx.html)
- [Export a PyTorch model to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/export_simple_model_to_onnx_tutorial.html)
- [Extending the ONNX Exporter Operator Support](https://docs.pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html)
- [Export a model with control flow to ONNX](https://docs.pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html)
- [Building a Convolution/Batch Norm fuser with torch.compile](https://docs.pytorch.org/tutorials/intermediate/torch_compile_conv_bn_fuser.html)
- [(beta) Building a Simple CPU Performance Profiler with FX](https://docs.pytorch.org/tutorials/intermediate/fx_profiling_tutorial.html)
- [Domains](https://docs.pytorch.org/tutorials/domains.html)
- [TorchVision Object Detection Finetuning Tutorial](https://docs.pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- [Transfer Learning for Computer Vision Tutorial](https://docs.pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)
- [Adversarial Example Generation](https://docs.pytorch.org/tutorials/beginner/fgsm_tutorial.html)
- [DCGAN Tutorial](https://docs.pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)
- [Spatial Transformer Networks Tutorial](https://docs.pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html)
- [Reinforcement Learning (DQN) Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)
- [Reinforcement Learning (PPO) with TorchRL Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_ppo.html)
- [Train a Mario-playing RL Agent](https://docs.pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)
- [Pendulum: Writing your environment and transforms with TorchRL](https://docs.pytorch.org/tutorials/advanced/pendulum.html)
- [Introduction to TorchRec](https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html)
- [Exploring TorchRec sharding](https://docs.pytorch.org/tutorials/advanced/sharding.html)
- [Distributed](https://docs.pytorch.org/tutorials/distributed.html)
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [Distributed Data Parallel in PyTorch - Video Tutorials](https://docs.pytorch.org/tutorials/beginner/ddp_series_intro.html)
- [Getting Started with Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Writing Distributed Applications with PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html)
- [Getting Started with Fully Sharded Data Parallel (FSDP2)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Introduction to Libuv TCPStore Backend](https://docs.pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html)
- [Large Scale Transformer model training with Tensor Parallel (TP)](https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html)
- [Introduction to Distributed Pipeline Parallelism](https://docs.pytorch.org/tutorials/intermediate/pipelining_tutorial.html)
- [Customize Process Group Backends Using Cpp Extensions](https://docs.pytorch.org/tutorials/intermediate/process_group_cpp_extension_tutorial.html)
- [Getting Started with Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html)
- [Implementing a Parameter Server Using Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html)
- [Implementing Batch RPC Processing Using Asynchronous Executions](https://docs.pytorch.org/tutorials/intermediate/rpc_async_execution.html)
- [Interactive Distributed Applications with Monarch](https://docs.pytorch.org/tutorials/intermediate/monarch_distributed_tutorial.html)
- [Combining Distributed DataParallel with Distributed RPC Framework](https://docs.pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html)
- [Distributed Training with Uneven Inputs Using the Join Context Manager](https://docs.pytorch.org/tutorials/advanced/generic_join.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Deep Dive](https://docs.pytorch.org/tutorials/deep-dive.html)
- [Profiling your PyTorch Module](https://docs.pytorch.org/tutorials/beginner/profiler.html)
- [Parametrizations Tutorial](https://docs.pytorch.org/tutorials/intermediate/parametrizations.html)
- [Pruning Tutorial](https://docs.pytorch.org/tutorials/intermediate/pruning_tutorial.html)
- [Inductor CPU backend debugging and profiling](https://docs.pytorch.org/tutorials/intermediate/inductor_debug_cpu.html)
- [(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)
- [Knowledge Distillation Tutorial](https://docs.pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html)
- [Channels Last Memory Format in PyTorch](https://docs.pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
- [Forward-mode Automatic Differentiation (Beta)](https://docs.pytorch.org/tutorials/intermediate/forward_ad_usage.html)
- [Jacobians, Hessians, hvp, vhp, and more: composing function transforms](https://docs.pytorch.org/tutorials/intermediate/jacobians_hessians.html)
- [Model ensembling](https://docs.pytorch.org/tutorials/intermediate/ensembling.html)
- [Per-sample-gradients](https://docs.pytorch.org/tutorials/intermediate/per_sample_grads.html)
- [Using the PyTorch C++ Frontend](https://docs.pytorch.org/tutorials/advanced/cpp_frontend.html)
- [Autograd in C++ Frontend](https://docs.pytorch.org/tutorials/advanced/cpp_autograd.html)
- [Extension](https://docs.pytorch.org/tutorials/extension.html)
- [PyTorch Custom Operators](https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html)
- [Custom Python Operators](https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html)
- [Custom C++ and CUDA Operators](https://docs.pytorch.org/tutorials/advanced/cpp_custom_ops.html)
- [Double Backward with Custom Functions](https://docs.pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html)
- [Fusing Convolution and Batch Norm using Custom Function](https://docs.pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html)
- [Registering a Dispatched Operator in C++](https://docs.pytorch.org/tutorials/advanced/dispatcher.html)
- [Extending dispatcher for a new backend in C++](https://docs.pytorch.org/tutorials/advanced/extend_dispatcher.html)
- [Facilitating New Backend Integration by PrivateUse1](https://docs.pytorch.org/tutorials/advanced/privateuseone.html)
- [Ecosystem](https://docs.pytorch.org/tutorials/ecosystem.html)
- [Hyperparameter tuning using Ray Tune](https://docs.pytorch.org/tutorials/beginner/hyperparameter_tuning_tutorial.html)
- [Serve PyTorch models at scale with Ray Serve](https://docs.pytorch.org/tutorials/beginner/serving_tutorial.html)
- [Multi-Objective NAS with Ax](https://docs.pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html)
- [PyTorch Profiler With TensorBoard](https://docs.pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html)
- [Real Time Inference on Raspberry Pi 4 and 5 (40 fps!)](https://docs.pytorch.org/tutorials/intermediate/realtime_rpi.html)
- [Mosaic: Memory Profiling for PyTorch](https://docs.pytorch.org/tutorials/beginner/mosaic_memory_profiling_tutorial.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Recipes](https://docs.pytorch.org/tutorials/recipes_index.html)
- [Defining a Neural Network in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/defining_a_neural_network.html)
- [(beta) Using TORCH\_LOGS python API with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_logs.html)
- [What is a state\_dict in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html)
- [Warmstarting model using parameters from a different model in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/warmstarting_model_using_parameters_from_a_different_model.html)
- [Zeroing out gradients in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/zeroing_out_gradients.html)
- [PyTorch Profiler](https://docs.pytorch.org/tutorials/recipes/recipes/profiler_recipe.html)
- [Model Interpretability using Captum](https://docs.pytorch.org/tutorials/recipes/recipes/Captum_Recipe.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [Automatic Mixed Precision](https://docs.pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Performance Tuning Guide](https://docs.pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [(beta) Compiling the optimizer with torch.compile](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer.html)
- [Timer quick start](https://docs.pytorch.org/tutorials/recipes/recipes/timer_quick_start.html)
- [Shard Optimizer States with ZeroRedundancyOptimizer](https://docs.pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html)
- [Getting Started with CommDebugMode](https://docs.pytorch.org/tutorials/recipes/distributed_comm_debug_mode.html)
- [Demonstration of torch.export flow, common challenges and the solutions to address them](https://docs.pytorch.org/tutorials/recipes/torch_export_challenges_solutions.html)
- [PyTorch Benchmark](https://docs.pytorch.org/tutorials/recipes/recipes/benchmark.html)
- [Tips for Loading an nn.Module from a Checkpoint](https://docs.pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html)
- [Reasoning about Shapes in PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/reasoning_about_shapes.html)
- [Extension points in nn.Module for load\_state\_dict and tensor subclasses](https://docs.pytorch.org/tutorials/recipes/recipes/swap_tensors.html)
- [torch.export AOTInductor Tutorial for Python runtime (Beta)](https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html)
- [How to use TensorBoard with PyTorch](https://docs.pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html)
- [(beta) Utilizing Torch Function modes with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_torch_function_modes.html)
- [(beta) Running the compiled optimizer with an LR Scheduler](https://docs.pytorch.org/tutorials/recipes/compiling_optimizer_lr_scheduler.html)
- [Explicit horizontal fusion with foreach\_map and torch.compile](https://docs.pytorch.org/tutorials/recipes/foreach_map.html)
- [Using User-Defined Triton Kernels with torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html)
- [Compile Time Caching in torch.compile](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html)
- [Compile Time Caching Configuration](https://docs.pytorch.org/tutorials/recipes/torch_compile_caching_configuration_tutorial.html)
- [Reducing torch.compile cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html)
- [Reducing AoT cold start compilation time with regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_aot.html)
- [Ease-of-use quantization for PyTorch with IntelĀ® Neural Compressor](https://docs.pytorch.org/tutorials/recipes/intel_neural_compressor_for_pytorch.html)
- [Getting Started with DeviceMesh](https://docs.pytorch.org/tutorials/recipes/distributed_device_mesh.html)
- [Getting Started with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html)
- [Asynchronous Saving with Distributed Checkpoint (DCP)](https://docs.pytorch.org/tutorials/recipes/distributed_async_checkpoint_recipe.html)
- [DebugMode: Recording Dispatched Operations and Numerical Debugging](https://docs.pytorch.org/tutorials/recipes/debug_mode_tutorial.html)
- [Unstable](https://docs.pytorch.org/tutorials/unstable_index.html)
- [Introduction to Context Parallel](https://docs.pytorch.org/tutorials/unstable/context_parallel.html)
- [Flight Recorder for Debugging Stuck Jobs](https://docs.pytorch.org/tutorials/unstable/flight_recorder_tutorial.html)
- [TorchInductor C++ Wrapper Tutorial](https://docs.pytorch.org/tutorials/unstable/inductor_cpp_wrapper_tutorial.html)
- [How to use torch.compile on Windows CPU/XPU](https://docs.pytorch.org/tutorials/unstable/inductor_windows.html)
- [torch.vmap](https://docs.pytorch.org/tutorials/unstable/vmap_recipe.html)
- [Getting Started with Nested Tensors](https://docs.pytorch.org/tutorials/unstable/nestedtensor.html)
- [MaskedTensor Overview](https://docs.pytorch.org/tutorials/unstable/maskedtensor_overview.html)
- [MaskedTensor Sparsity](https://docs.pytorch.org/tutorials/unstable/maskedtensor_sparsity.html)
- [MaskedTensor Advanced Semantics](https://docs.pytorch.org/tutorials/unstable/maskedtensor_advanced_semantics.html)
- [Efficiently writing āsparseā semantics for Adagrad with MaskedTensor](https://docs.pytorch.org/tutorials/unstable/maskedtensor_adagrad.html)
- [Autoloading Out-of-Tree Extension](https://docs.pytorch.org/tutorials/unstable/python_extension_autoload.html)
- [Using Max-Autotune Compilation on CPU for Better Performance](https://docs.pytorch.org/tutorials/unstable/max_autotune_on_CPU_tutorial.html)
[Go to pytorch.org](https://pytorch.org/)
- [X](https://x.com/PyTorch)
- [GitHub](https://github.com/pytorch/tutorials)
- [Discourse](https://dev-discuss.pytorch.org/)
- [PyPi](https://pypi.org/project/torch/)
Section Navigation
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [Distributed Data Parallel in PyTorch - Video Tutorials](https://docs.pytorch.org/tutorials/beginner/ddp_series_intro.html)
- [Getting Started with Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html)
- [Writing Distributed Applications with PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html)
- [Getting Started with Fully Sharded Data Parallel (FSDP2)](https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
- [Introduction to Libuv TCPStore Backend](https://docs.pytorch.org/tutorials/intermediate/TCPStore_libuv_backend.html)
- [Large Scale Transformer model training with Tensor Parallel (TP)](https://docs.pytorch.org/tutorials/intermediate/TP_tutorial.html)
- [Introduction to Distributed Pipeline Parallelism](https://docs.pytorch.org/tutorials/intermediate/pipelining_tutorial.html)
- [Customize Process Group Backends Using Cpp Extensions](https://docs.pytorch.org/tutorials/intermediate/process_group_cpp_extension_tutorial.html)
- [Getting Started with Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html)
- [Implementing a Parameter Server Using Distributed RPC Framework](https://docs.pytorch.org/tutorials/intermediate/rpc_param_server_tutorial.html)
- [Implementing Batch RPC Processing Using Asynchronous Executions](https://docs.pytorch.org/tutorials/intermediate/rpc_async_execution.html)
- [Interactive Distributed Applications with Monarch](https://docs.pytorch.org/tutorials/intermediate/monarch_distributed_tutorial.html)
- [Combining Distributed DataParallel with Distributed RPC Framework](https://docs.pytorch.org/tutorials/advanced/rpc_ddp_tutorial.html)
- [Distributed Training with Uneven Inputs Using the Join Context Manager](https://docs.pytorch.org/tutorials/advanced/generic_join.html)
- [Distributed training at scale with PyTorch and Ray Train](https://docs.pytorch.org/tutorials/beginner/distributed_training_with_ray_tutorial.html)
- [Distributed](https://docs.pytorch.org/tutorials/distributed.html)
- Getting...
Rate this Page
ā
ā
ā
ā
ā
intermediate/rpc\_tutorial
[ Run in Google Colab Colab]()
[ Download Notebook Notebook]()
[ View on GitHub GitHub]()
# Getting Started with Distributed RPC Framework[\#](https://docs.pytorch.org/tutorials/intermediate/rpc_tutorial.html#getting-started-with-distributed-rpc-framework "Link to this heading")
Created On: Jan 01, 2020 \| Last Updated: Sep 03, 2025 \| Last Verified: Nov 05, 2024
**Author**: [Shen Li](https://mrshenli.github.io/)
Note
[](https://docs.pytorch.org/tutorials/_images/pencil-16.png) View and edit this tutorial in [github](https://github.com/pytorch/tutorials/blob/main/intermediate_source/rpc_tutorial.rst).
Prerequisites:
- [PyTorch Distributed Overview](https://docs.pytorch.org/tutorials/beginner/dist_overview.html)
- [RPC API documents](https://pytorch.org/docs/master/rpc.html)
This tutorial uses two simple examples to demonstrate how to build distributed training with the [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) package which was first introduced as an experimental feature in PyTorch v1.4. Source code of the two examples can be found in [PyTorch examples](https://github.com/pytorch/examples).
Previous tutorials, [Getting Started With Distributed Data Parallel](https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [Writing Distributed Applications With PyTorch](https://docs.pytorch.org/tutorials/intermediate/dist_tuto.html), described [DistributedDataParallel](https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which supports a specific training paradigm where the model is replicated across multiple processes and each process handles a split of the input data. Sometimes, you might run into scenarios that require different training paradigms. For example:
1. In reinforcement learning, it might be relatively expensive to acquire training data from environments while the model itself can be quite small. In that case, it might be useful to spawn multiple observers running in parallel and share a single agent. The agent then takes care of the training locally, but the application still needs a library to send and receive data between the observers and the trainer.
2. Your model might be too large to fit in GPUs on a single machine, and hence would need a library to help split the model onto multiple machines. Or you might be implementing a [parameter server](https://www.cs.cmu.edu/~muli/file/parameter_server_osdi14.pdf) training framework, where model parameters and trainers live on different machines.
The [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) package can help with the above scenarios. In case 1, [RPC](https://pytorch.org/docs/stable/rpc.html#rpc) and [RRef](https://pytorch.org/docs/stable/rpc.html#rref) allow sending data from one worker to another while easily referencing remote data objects. In case 2, [distributed autograd](https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework) and [distributed optimizer](https://pytorch.org/docs/stable/rpc.html#module-torch.distributed.optim) make executing the backward pass and optimizer step feel like local training. In the next two sections, we will demonstrate the APIs of [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) using a reinforcement learning example and a language model example. Please note that this tutorial does not aim to build the most accurate or efficient model for the given problems; instead, the main goal is to show how to use the [torch.distributed.rpc](https://pytorch.org/docs/stable/rpc.html) package to build distributed training applications.
## Distributed Reinforcement Learning using RPC and RRef
This section describes the steps to build a toy distributed reinforcement learning model using RPC to solve CartPole-v1 from [OpenAI Gym](https://www.gymlibrary.dev/environments/classic_control/cart_pole/). The policy code is mostly borrowed from the existing single-thread [example](https://github.com/pytorch/examples/blob/master/reinforcement_learning), as shown below. We will skip the details of the `Policy` design and focus on RPC usage.
```
import torch.nn as nn
import torch.nn.functional as F
class Policy(nn.Module):
def __init__(self):
super(Policy, self).__init__()
self.affine1 = nn.Linear(4, 128)
self.dropout = nn.Dropout(p=0.6)
self.affine2 = nn.Linear(128, 2)
def forward(self, x):
x = self.affine1(x)
x = self.dropout(x)
x = F.relu(x)
action_scores = self.affine2(x)
return F.softmax(action_scores, dim=1)
```
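As a quick local sanity check (not part of the tutorial code), recall that CartPole-v1 observations are 4-dimensional and there are 2 discrete actions, matching the layer sizes above:

```
import torch

policy = Policy()
state = torch.randn(1, 4)   # a fake batch containing one observation
probs = policy(state)       # shape (1, 2); the softmax output sums to 1
print(probs.shape, float(probs.sum()))
```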
We are ready to present the observer. In this example, each observer creates its own environment and waits for the agent's command to run an episode. In each episode, one observer loops at most `n_steps` iterations (hard-coded as 10,000 in the code below), and in each iteration it uses RPC to pass its environment state to the agent and gets an action back. Then it applies that action to its environment, and gets the reward and the next state from the environment. After that, the observer uses another RPC to report the reward to the agent. Again, please note that this is obviously not the most efficient observer implementation. For example, one simple optimization could be packing the current state and the last reward into one RPC to reduce the communication overhead (a sketch of this variant follows the code below). However, the goal is to demonstrate the RPC API instead of building the best solver for CartPole, so let's keep the logic simple and the two steps explicit in this example.
```
import argparse
import gym
import torch.distributed.rpc as rpc
parser = argparse.ArgumentParser(
description="RPC Reinforcement Learning Example",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('--world_size', default=2, type=int, metavar='W',
help='number of workers')
parser.add_argument('--log_interval', type=int, default=10, metavar='N',
help='interval between training status logs')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
help='how much to value future rewards')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed for reproducibility')
args = parser.parse_args()
class Observer:
def __init__(self):
self.id = rpc.get_worker_info().id
self.env = gym.make('CartPole-v1')
self.env.seed(args.seed)
def run_episode(self, agent_rref):
state, ep_reward = self.env.reset(), 0
for _ in range(10000):
# send the state to the agent to get an action
action = agent_rref.rpc_sync().select_action(self.id, state)
# apply the action to the environment, and get the reward
state, reward, done, _ = self.env.step(action)
            # report the reward to the agent for training purposes
            agent_rref.rpc_sync().report_reward(self.id, reward)
            # the environment ends the episode after self.env._max_episode_steps steps
if done:
break
```
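As mentioned above, a simple optimization would be to fold the reward report into the action-selection RPC. A hypothetical sketch of that variant (the `select_action_and_record` agent method is not part of the tutorial; it would simply combine `select_action` and `report_reward`):

```
class PackedObserver(Observer):
    def run_episode(self, agent_rref):
        state, reward = self.env.reset(), 0
        for _ in range(10000):
            # hypothetical agent method combining select_action and
            # report_reward: one RPC per step instead of two
            action = agent_rref.rpc_sync().select_action_and_record(
                self.id, state, reward)
            state, reward, done, _ = self.env.step(action)
            if done:
                break
```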
The code for the agent is a little more complex, and we will break it into multiple pieces. In this example, the agent serves as both the trainer and the master, such that it sends commands to multiple distributed observers to run episodes, and it also records all actions and rewards locally, which will be used during the training phase after each episode. The code below shows the `Agent` constructor, where most lines are initializing various components. The loop at the end initializes observers remotely on other workers and holds local `RRefs` to those observers. The agent will use those observer `RRefs` later to send commands. Applications don't need to worry about the lifetime of `RRefs`. The owner of each `RRef` maintains a reference-counting map to track its lifetime and guarantees that the remote data object will not be deleted as long as there is any live user of that `RRef`. Please refer to the `RRef` [design doc](https://pytorch.org/docs/stable/rpc/rref.html) for details.
```
import gym
import numpy as np
import torch
import torch.distributed.rpc as rpc
import torch.optim as optim
from torch.distributed.rpc import RRef, rpc_async, remote
from torch.distributions import Categorical
class Agent:
def __init__(self, world_size):
self.ob_rrefs = []
self.agent_rref = RRef(self)
self.rewards = {}
self.saved_log_probs = {}
self.policy = Policy()
self.optimizer = optim.Adam(self.policy.parameters(), lr=1e-2)
self.eps = np.finfo(np.float32).eps.item()
self.running_reward = 0
self.reward_threshold = gym.make('CartPole-v1').spec.reward_threshold
        for ob_rank in range(1, world_size):
            # OBSERVER_NAME ("obs{}") is defined in the launch code further below
            ob_info = rpc.get_worker_info(OBSERVER_NAME.format(ob_rank))
self.ob_rrefs.append(remote(ob_info, Observer))
self.rewards[ob_info.id] = []
self.saved_log_probs[ob_info.id] = []
```
Next, the agent exposes two APIs to observers for selecting actions and reporting rewards. These functions run only locally on the agent, but will be triggered by observers through RPC.
```
class Agent:
...
def select_action(self, ob_id, state):
state = torch.from_numpy(state).float().unsqueeze(0)
probs = self.policy(state)
m = Categorical(probs)
action = m.sample()
self.saved_log_probs[ob_id].append(m.log_prob(action))
return action.item()
def report_reward(self, ob_id, reward):
self.rewards[ob_id].append(reward)
```
Let's add a `run_episode` function on the agent which tells all observers to execute an episode. This function first creates a list to collect futures from asynchronous RPCs, and then loops over all observer `RRefs` to make the asynchronous RPCs. In these RPCs, the agent also passes an `RRef` of itself to the observer, so that the observer can call functions on the agent as well. As shown above, each observer will make RPCs back to the agent, which are nested RPCs. After each episode, `saved_log_probs` and `rewards` will contain the recorded action log probabilities and rewards.
```
class Agent:
...
def run_episode(self):
futs = []
for ob_rref in self.ob_rrefs:
# make async RPC to kick off an episode on all observers
futs.append(
rpc_async(
ob_rref.owner(),
ob_rref.rpc_sync().run_episode,
args=(self.agent_rref,)
)
)
        # wait until all observers have finished this episode
for fut in futs:
fut.wait()
```
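The pattern above is the standard `rpc_async` idiom: fire off asynchronous calls, collect the returned futures, and wait on them. As a self-contained illustration (the worker name `"worker1"` is hypothetical and assumes an already-initialized RPC group):

```
fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), torch.ones(2)))
# ... do other work while the remote call runs ...
result = fut.wait()   # blocks until the remote worker returns tensor([2., 2.])
```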
Finally, after one episode, the agent needs to train the model, which is implemented in the `finish_episode` function below. There are no RPCs in this function, and it is mostly borrowed from the single-thread [example](https://github.com/pytorch/examples/blob/master/reinforcement_learning). Hence, we skip describing most of its contents, apart from a small worked example of the discounted-return computation after the code.
```
class Agent:
...
def finish_episode(self):
# joins probs and rewards from different observers into lists
R, probs, rewards = 0, [], []
for ob_id in self.rewards:
probs.extend(self.saved_log_probs[ob_id])
rewards.extend(self.rewards[ob_id])
# use the minimum observer reward to calculate the running reward
min_reward = min([sum(self.rewards[ob_id]) for ob_id in self.rewards])
self.running_reward = 0.05 * min_reward + (1 - 0.05) * self.running_reward
# clear saved probs and rewards
for ob_id in self.rewards:
self.rewards[ob_id] = []
self.saved_log_probs[ob_id] = []
policy_loss, returns = [], []
for r in rewards[::-1]:
R = r + args.gamma * R
returns.insert(0, R)
returns = torch.tensor(returns)
returns = (returns - returns.mean()) / (returns.std() + self.eps)
for log_prob, R in zip(probs, returns):
policy_loss.append(-log_prob * R)
self.optimizer.zero_grad()
policy_loss = torch.cat(policy_loss).sum()
policy_loss.backward()
self.optimizer.step()
return min_reward
```
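For intuition, the reversed loop above turns a reward sequence into discounted returns R_t = r_t + gamma * R_{t+1}. A tiny worked example with gamma = 0.99 and three unit rewards:

```
gamma, R, returns = 0.99, 0, []
for r in [1.0, 1.0, 1.0][::-1]:
    R = r + gamma * R
    returns.insert(0, R)
print(returns)   # [2.9701, 1.99, 1.0]
```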
With the `Policy`, `Observer`, and `Agent` classes, we are ready to launch multiple processes to perform the distributed training. In this example, all processes run the same `run_worker` function, and they use the rank to distinguish their roles. Rank 0 is always the agent, and all other ranks are observers. The agent serves as the master by repeatedly calling `run_episode` and `finish_episode` until the running reward surpasses the reward threshold specified by the environment. All observers passively wait for commands from the agent. The code is wrapped by [rpc.init\_rpc](https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.init_rpc) and [rpc.shutdown](https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.shutdown), which initialize and terminate RPC instances respectively. More details are available in the [API page](https://pytorch.org/docs/stable/rpc.html).
```
import os
from itertools import count
import torch.multiprocessing as mp
AGENT_NAME = "agent"
OBSERVER_NAME = "obs{}"
def run_worker(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
if rank == 0:
# rank0 is the agent
rpc.init_rpc(AGENT_NAME, rank=rank, world_size=world_size)
agent = Agent(world_size)
print(f"This will run until reward threshold of {agent.reward_threshold}"
" is reached. Ctrl+C to exit.")
for i_episode in count(1):
agent.run_episode()
last_reward = agent.finish_episode()
if i_episode % args.log_interval == 0:
print(f"Episode {i_episode}\tLast reward: {last_reward:.2f}\tAverage reward: "
f"{agent.running_reward:.2f}")
if agent.running_reward > agent.reward_threshold:
print(f"Solved! Running reward is now {agent.running_reward}!")
break
else:
# other ranks are the observer
rpc.init_rpc(OBSERVER_NAME.format(rank), rank=rank, world_size=world_size)
# observers passively waiting for instructions from the agent
# block until all rpcs finish, and shutdown the RPC instance
rpc.shutdown()
mp.spawn(
run_worker,
args=(args.world_size, ),
nprocs=args.world_size,
join=True
)
```
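Assuming the snippets above are assembled into a single script (the file name `rl_rpc.py` here is hypothetical), running, e.g., `python rl_rpc.py --world_size 3` starts one agent process (rank 0) and two observer processes.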
Below is some sample output when training with `world_size=2`.
```
This will run until reward threshold of 475.0 is reached. Ctrl+C to exit.
Episode 10 Last reward: 26.00 Average reward: 10.01
Episode 20 Last reward: 16.00 Average reward: 11.27
Episode 30 Last reward: 49.00 Average reward: 18.62
Episode 40 Last reward: 45.00 Average reward: 26.09
Episode 50 Last reward: 44.00 Average reward: 30.03
Episode 60 Last reward: 111.00 Average reward: 42.23
Episode 70 Last reward: 131.00 Average reward: 70.11
Episode 80 Last reward: 87.00 Average reward: 76.51
Episode 90 Last reward: 86.00 Average reward: 95.93
Episode 100 Last reward: 13.00 Average reward: 123.93
Episode 110 Last reward: 33.00 Average reward: 91.39
Episode 120 Last reward: 73.00 Average reward: 76.38
Episode 130 Last reward: 137.00 Average reward: 88.08
Episode 140 Last reward: 89.00 Average reward: 104.96
Episode 150 Last reward: 97.00 Average reward: 98.74
Episode 160 Last reward: 150.00 Average reward: 100.87
Episode 170 Last reward: 126.00 Average reward: 104.38
Episode 180 Last reward: 500.00 Average reward: 213.74
Episode 190 Last reward: 322.00 Average reward: 300.22
Episode 200 Last reward: 165.00 Average reward: 272.71
Episode 210 Last reward: 168.00 Average reward: 233.11
Episode 220 Last reward: 184.00 Average reward: 195.02
Episode 230 Last reward: 284.00 Average reward: 208.32
Episode 240 Last reward: 395.00 Average reward: 247.37
Episode 250 Last reward: 500.00 Average reward: 335.42
Episode 260 Last reward: 500.00 Average reward: 386.30
Episode 270 Last reward: 500.00 Average reward: 405.29
Episode 280 Last reward: 500.00 Average reward: 443.29
Episode 290 Last reward: 500.00 Average reward: 464.65
Solved! Running reward is now 475.3163778435275!
```
In this example, we showed how to use RPC as the communication vehicle to pass data across workers, and how to use RRef to reference remote objects. It is true that you could build the entire structure directly on top of the `ProcessGroup` `send` and `recv` APIs or other communication/RPC libraries; however, by using `torch.distributed.rpc`, you get native support and continuously optimized performance under the hood.
Next, we will show how to combine RPC and RRef with distributed autograd and distributed optimizer to perform distributed model parallel training.
## Distributed RNN using Distributed Autograd and Distributed Optimizer
In this section, we use an RNN model to show how to build distributed model parallel training with the RPC API. The example RNN model is very small and can easily fit into a single GPU, but we still divide its layers onto two different workers to demonstrate the idea. Developers can apply similar techniques to distribute much larger models across multiple devices and machines.
The RNN model design is borrowed from the word language model in the PyTorch [example](https://github.com/pytorch/examples/tree/master/word_language_model) repository, which contains three main components: an embedding table, an `LSTM` layer, and a decoder. The code below wraps the embedding table and the decoder into sub-modules, so that their constructors can be passed to the RPC API. In the `EmbeddingTable` sub-module, we intentionally put the `Embedding` layer on a GPU to cover that use case. As of v1.4, RPC always creates CPU tensors for arguments and return values on the destination worker, so if a function takes a GPU tensor, you need to move it to the proper device explicitly.
```
class EmbeddingTable(nn.Module):
r"""
Encoding layers of the RNNModel
"""
def __init__(self, ntoken, ninp, dropout):
super(EmbeddingTable, self).__init__()
self.drop = nn.Dropout(dropout)
self.encoder = nn.Embedding(ntoken, ninp).cuda()
self.encoder.weight.data.uniform_(-0.1, 0.1)
def forward(self, input):
        return self.drop(self.encoder(input.cuda())).cpu()
class Decoder(nn.Module):
def __init__(self, ntoken, nhid, dropout):
super(Decoder, self).__init__()
self.drop = nn.Dropout(dropout)
self.decoder = nn.Linear(nhid, ntoken)
self.decoder.bias.data.zero_()
self.decoder.weight.data.uniform_(-0.1, 0.1)
def forward(self, output):
return self.decoder(self.drop(output))
```
With the above sub-modules, we can now piece them together using RPC to create an RNN model. In the code below, `ps` represents a parameter server, which hosts the parameters of the embedding table and the decoder. The constructor uses the [remote](https://pytorch.org/docs/stable/rpc.html#torch.distributed.rpc.remote) API to create an `EmbeddingTable` object and a `Decoder` object on the parameter server, and locally creates the `LSTM` sub-module. During the forward pass, the trainer uses the `EmbeddingTable` `RRef` to find the remote sub-module, passes the input data to it using RPC, and fetches the lookup results back. It then runs the embedding through the local `LSTM` layer, and finally uses another RPC to send the output to the `Decoder` sub-module. In general, to implement distributed model parallel training, developers can divide the model into sub-modules, invoke RPC to create sub-module instances remotely, and use `RRef`s to find them when necessary. As you can see in the code below, it looks very similar to single-machine model parallel training; the main difference is replacing `Tensor.to(device)` with RPC functions. (The `_remote_method` helper used below is not defined in this tutorial; a sketch follows the code.)
```
class RNNModel(nn.Module):
def __init__(self, ps, ntoken, ninp, nhid, nlayers, dropout=0.5):
super(RNNModel, self).__init__()
# setup embedding table remotely
self.emb_table_rref = rpc.remote(ps, EmbeddingTable, args=(ntoken, ninp, dropout))
# setup LSTM locally
self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
# setup decoder remotely
self.decoder_rref = rpc.remote(ps, Decoder, args=(ntoken, nhid, dropout))
def forward(self, input, hidden):
# pass input to the remote embedding table and fetch emb tensor back
emb = _remote_method(EmbeddingTable.forward, self.emb_table_rref, input)
output, hidden = self.rnn(emb, hidden)
        # pass output to the remote decoder and get the decoded output back
decoded = _remote_method(Decoder.forward, self.decoder_rref, output)
return decoded, hidden
```
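As noted above, the forward pass calls a `_remote_method` helper that this excerpt never defines. A minimal sketch of such a helper, assuming the pattern used in the companion PyTorch examples code (run a method on the `RRef`'s owner and fetch the result synchronously):

```
def _call_method(method, rref, *args, **kwargs):
    # executes on the owner of rref: unwrap the RRef and call the method locally
    return method(rref.local_value(), *args, **kwargs)

def _remote_method(method, rref, *args, **kwargs):
    # run method(rref.local_value(), ...) on rref's owner and block for the result
    return rpc.rpc_sync(rref.owner(), _call_method,
                        args=tuple([method, rref] + list(args)), kwargs=kwargs)
```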
Before introducing the distributed optimizer, let's add a helper function to generate a list of RRefs of model parameters, which will be consumed by the distributed optimizer. In local training, applications could call `Module.parameters()` to grab references to all parameter tensors and pass them to the local optimizer for subsequent updates. However, the same API does not work in distributed training scenarios, as some parameters live on remote machines. Therefore, instead of taking a list of parameter `Tensors`, the distributed optimizer takes a list of `RRefs`, one `RRef` per model parameter, covering both local and remote model parameters. The helper function is pretty simple: it just calls `Module.parameters()` and creates a local `RRef` for each of the parameters.
```
def _parameter_rrefs(module):
param_rrefs = []
for param in module.parameters():
param_rrefs.append(RRef(param))
return param_rrefs
```
Then, as the `RNNModel` contains three sub-modules, we need to call `_parameter_rrefs` three times, and we wrap that into another helper function.
```
class RNNModel(nn.Module):
...
def parameter_rrefs(self):
remote_params = []
# get RRefs of embedding table
remote_params.extend(_remote_method(_parameter_rrefs, self.emb_table_rref))
# create RRefs for local parameters
remote_params.extend(_parameter_rrefs(self.rnn))
# get RRefs of decoder
remote_params.extend(_remote_method(_parameter_rrefs, self.decoder_rref))
return remote_params
```
Now, we are ready to implement the training loop. After initializing model arguments, we create the `RNNModel` and the `DistributedOptimizer`. The distributed optimizer will take a list of parameter `RRefs`, find all distinct owner workers, and create the given local optimizer (`SGD` in this case; other local optimizers work as well) on each of the owner workers using the given arguments (here, `lr=0.05`).
The training loop first creates a distributed autograd context, which helps the distributed autograd engine find gradients and the involved RPC send/recv functions. The design details of the distributed autograd engine can be found in its [design note](https://pytorch.org/docs/stable/rpc/distributed_autograd.html). Then, it kicks off the forward pass as if it were a local model, and runs the distributed backward pass. For the distributed backward pass, you only need to specify a list of roots; in this case, it is the loss `Tensor`. The distributed autograd engine will traverse the distributed graph automatically and write gradients properly. Next, it runs the `step` function on the distributed optimizer, which reaches out to all involved local optimizers to update model parameters. Compared to local training, one minor difference is that you don't need to run `zero_grad()`, because each autograd context has dedicated space to store gradients, and as we create a context per iteration, gradients from different iterations will not accumulate into the same set of `Tensors`.
```
import torch.distributed.autograd as dist_autograd
from torch.distributed.optim import DistributedOptimizer

def run_trainer():
batch = 5
ntoken = 10
ninp = 2
nhid = 3
nindices = 3
nlayers = 4
hidden = (
torch.randn(nlayers, nindices, nhid),
torch.randn(nlayers, nindices, nhid)
)
    # RNNModel is the class defined above; the companion example code keeps it
    # in a separate rnn.py module, hence the rnn. prefix
    model = rnn.RNNModel('ps', ntoken, ninp, nhid, nlayers)
# setup distributed optimizer
opt = DistributedOptimizer(
optim.SGD,
model.parameter_rrefs(),
lr=0.05,
)
criterion = torch.nn.CrossEntropyLoss()
def get_next_batch():
for _ in range(5):
data = torch.LongTensor(batch, nindices) % ntoken
target = torch.LongTensor(batch, ntoken) % nindices
yield data, target
# train for 10 iterations
for epoch in range(10):
for data, target in get_next_batch():
# create distributed autograd context
with dist_autograd.context() as context_id:
hidden[0].detach_()
hidden[1].detach_()
output, hidden = model(data, hidden)
loss = criterion(output, target)
# run distributed backward pass
dist_autograd.backward(context_id, [loss])
# run distributed optimizer
opt.step(context_id)
# not necessary to zero grads since they are
# accumulated into the distributed autograd context
# which is reset every iteration.
print("Training epoch {}".format(epoch))
```
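As an aside, the gradients accumulated in a distributed autograd context can be inspected with `dist_autograd.get_gradients`, which can help when debugging. A minimal sketch, reusing the names from `run_trainer` above:

```
with dist_autograd.context() as context_id:
    output, _ = model(data, hidden)
    dist_autograd.backward(context_id, [criterion(output, target)])
    # returns a dict mapping each local parameter Tensor to its gradient
    # for this particular context
    grads = dist_autograd.get_gradients(context_id)
    print(f"{len(grads)} local tensors received gradients")
```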
Finally, let's add some glue code to launch the parameter server and the trainer processes.
```
def run_worker(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
if rank == 1:
rpc.init_rpc("trainer", rank=rank, world_size=world_size)
        run_trainer()
else:
rpc.init_rpc("ps", rank=rank, world_size=world_size)
        # the parameter server does nothing but wait for RPCs from the trainer
        pass
# block until all rpcs finish
rpc.shutdown()
if __name__ == "__main__":
world_size = 2
mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
```
|
| Shard | 114 (laksa) |
| Root Hash | 14416670112284949514 |
| Unparsed URL | org,pytorch!docs,/tutorials/intermediate/rpc_tutorial.html s443 |