🕷️ Crawler Inspector

URL Lookup

Direct Parameter Lookup

Raw Queries and Responses

1. Shard Calculation

Query:
Response:
Calculated Shard: 8 (from laksa070)

2. Crawled Status Check

Query:
Response:

3. Robots.txt Check

Query:
Response:

4. Spam/Ban Check

Query:
Response:

5. Seen Status Check

ℹ️ Skipped - page is already crawled

📄
INDEXABLE
CRAWLED
6 days ago
🤖
ROBOTS ALLOWED

Page Info Filters

FilterStatusConditionDetails
HTTP statusPASSdownload_http_code = 200HTTP 200
Age cutoffPASSdownload_stamp > now() - 6 MONTH0.2 months ago
History dropPASSisNull(history_drop_reason)No drop reason
Spam/banPASSfh_dont_index != 1 AND ml_spam_score = 0ml_spam_score=0
CanonicalPASSmeta_canonical IS NULL OR = '' OR = src_unparsedNot set

Page Details

PropertyValue
URLhttps://www.w3cschool.cn/pytorch/pytorch-it483bt6.html
Last Crawled2026-03-31 21:34:11 (6 days ago)
First Indexed2020-09-10 11:25:39 (5 years ago)
HTTP Status Code200
Meta TitlePyTorch 分布式训练师与 AWS 集成实战教程_w3cschool
Meta Description本教程深入讲解如何在 Amazon AWS 多 GPU 节点上配置和运行 PyTorch 1.0 分布式训练环境,帮助您轻松扩展训练代码,提升训练效率,快速掌握分布式训练实用技能。_来自PyTorch 中文教程,w3cschool编程狮。
Meta Canonicalnull
Boilerpipe Text
随着深度学习模型规模的不断扩大和数据量的持续增长,单机训练方式已难以满足高效训练的需求。分布式训练成为一种必然选择,它通过将计算任务分布在多个 GPU 或服务器上,显著提升了训练效率。AWS 作为全球领先的云计算平台,提供了强大的计算资源和灵活的服务架构,为分布式训练提供了理想的运行环境。本文将深入探讨如何在 AWS 上搭建和运行 PyTorch 分布式训练系统,通过实际案例助力您高效开展深度学习项目。 (一)创建实例 在 AWS 上创建两个多 GPU 节点,选择适合深度学习任务的实例类型,如 p2.8xlarge ,其配备 8 个 NVIDIA Tesla K80 GPU,为分布式训练提供强大的计算支持。 (二)配置安全组 确保实例之间的通信畅通无阻,是分布式训练成功的关键。创建一个新的安全组,并配置入站和出站规则,允许节点之间所有类型的数据流量。具体操作步骤如下: 登录 AWS 管理控制台,选择 “EC2” 服务。 在左侧导航栏中,选择 “安全组”。 点击 “创建安全组”,设置安全组名称和描述。 在 “入站规则” 栏中,添加规则允许来自新安全组的 “所有流量”。 在 “出站规则” 栏中,同样添加规则允许流向新安全组的 “所有流量”。 (三)获取节点 IP 地址 在 EC2 仪表板中找到正在运行的实例,记录每个节点的 IPv4 公网 IP 和私网 IP。公网 IP 用于 SSH 连接,私网 IP 用于节点间通信。这些 IP 地址在后续配置中将被频繁使用。 二、环境配置 (一)创建并激活 conda 环境 在每个节点上创建并激活一个新的 conda 环境,为 PyTorch 提供干净的运行环境: conda create -n pytorch_env python=3.8 conda activate pytorch_env (二)安装 PyTorch 和 torchvision 安装支持 CUDA 的 PyTorch 夜度构建版本以及从源代码构建的 torchvision: pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 cd ~ git clone https://github.com/pytorch/vision.git cd vision python setup.py install (三)设置 NCCL 网络接口 为了优化 GPU 之间的通信,设置 NCCL 套接字的网络接口名称。通过运行 ifconfig 命令确定网络接口名称,并设置环境变量: export NCCL_SOCKET_IFNAME=ens3 三、分布式训练代码实现 (一)导入必要的模块 import time import sys import torch import torch.nn as nn import torch.nn.parallel import torch.distributed as dist import torch.optim import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms import torchvision.datasets as datasets import torchvision.models as models (二)定义辅助函数和类 class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res (三)定义训练和验证函数 def train(train_loader, model, criterion, optimizer, epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.train() end = time.time() for i, (input, target) in enumerate(train_loader): data_time.update(time.time() - end) input = input.cuda(non_blocking=True) target = target.cuda(non_blocking=True) output = model(input) loss = criterion(output, target) prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) optimizer.zero_grad() loss.backward() optimizer.step() batch_time.update(time.time() - end) end = time.time() if i % 10 == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) def validate(val_loader, model, criterion): batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() with torch.no_grad(): end = time.time() for i, (input, target) in enumerate(val_loader): input = input.cuda(non_blocking=True) target = target.cuda(non_blocking=True) output = model(input) loss = criterion(output, target) prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) batch_time.update(time.time() - end) end = time.time() if i % 100 == 0: print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) return top1.avg (四)初始化进程组 def main(): batch_size = 32 workers = 2 num_epochs = 2 starting_lr = 0.1 world_size = 4 dist_backend = 'nccl' dist_url = "tcp://<node0-privateIP>:23456" # 替换为实际的节点私有 IP print("Initialize Process Group...") dist.init_process_group(backend=dist_backend, init_method=dist_url, rank=int(sys.argv[1]), world_size=world_size) local_rank = int(sys.argv[2]) dp_device_ids = [local_rank] torch.cuda.set_device(local_rank) print("Initialize Model...") model = models.resnet18(pretrained=False).cuda() model = torch.nn.parallel.DistributedDataParallel(model, device_ids=dp_device_ids) criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), starting_lr, momentum=0.9, weight_decay=1e-4) print("Initialize Dataloaders...") transform = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) trainset = datasets.STL10(root='./data', split='train', download=True, transform=transform) valset = datasets.STL10(root='./data', split='test', download=True, transform=transform) train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=workers, pin_memory=False, sampler=train_sampler) val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=False) best_prec1 = 0 for epoch in range(num_epochs): train_sampler.set_epoch(epoch) adjust_learning_rate(starting_lr, optimizer, epoch) train(train_loader, model, criterion, optimizer, epoch) prec1 = validate(val_loader, model, criterion) best_prec1 = max(prec1, best_prec1) print("Epoch Summary: ") print("\tEpoch Accuracy: {}".format(prec1)) print("\tBest Accuracy: {}".format(best_prec1)) if __name__ == "__main__": main() 四、运行训练 在每个节点上打开多个 SSH 终端,分别运行以下命令: 在 node0 的第一个终端上: python main.py 0 0 在 node0 的第二个终端上: python main.py 1 1 在 node1 的第一个终端上: python main.py 2 0 在 node1 的第二个终端上: python main.py 3 1 以上内容已同步发布至编程狮网站,欢迎访问 编程狮 PyTorch 教程 获取更多深度学习和 PyTorch 相关的优质教程。在学习过程中,如果您有任何疑问或需要进一步的技术支持,欢迎加入编程狮社区,与广大编程爱好者和专家进行交流和互动。
Markdown
[![w3cschool](https://7nsts.w3cschool.cn/images/logonew2.png)](http://www.w3cschool.cn/ "w3cschool") - [编程课程](http://www.w3cschool.cn/courses "编程课程") - [编程实战](http://www.w3cschool.cn/codecamp "编程实战") - [编程题库](http://www.w3cschool.cn/exam "编程题库") - [编程教程](http://www.w3cschool.cn/tutorial "编程教程") - [在线工具](https://123.w3cschool.cn/webtools "w3cschool在线工具集合") - [![](https://atts.w3cschool.cn/trae.png) 免费 AI IDE](https://www.trae.com.cn/?utm_source=advertising&utm_medium=w3cschool_ug_cpa&utm_term=hw_trae_w3cschool) - [VIP会员 *开学季*](http://www.w3cschool.cn/vip?fcode=headermenu "VIP会员") [App下载](http://www.w3cschool.cn/download "App下载") ![APP二维码](https://7nsts.w3cschool.cn/images/w3c/app-qrcode2.png) 扫码下载编程狮APP [注册](http://www.w3cschool.cn/register?refer=/)\|[登录](http://www.w3cschool.cn/login?refer=/) 注册成功 X ![](https://atts.w3cschool.cn/attachments/avatar2/avatar_0.jpg) W3Cschool 恭喜您成为首批注册用户 获得88经验值奖励 马上体验 - [入门教程](https://www.w3cschool.cn/tutorial "编程入门教程") - [编程课程](https://www.w3cschool.cn/learn "编程课程") - [VIP会员](https://www.w3cschool.cn/vip?fcode=m_indexmenu "VIP会员") [![PyTorch 中文教程](https://atts.w3cschool.cn/attachments/cover/cover_pytorch.png?t=1666251799?imageView2/1/w/48/h/48) PyTorch 中文教程](https://www.w3cschool.cn/pytorch "PyTorch 中文教程") - [赞]() - [收藏]() - [更多文章](https://www.w3cschool.cn/pytorch/list/ "更多文章") - [目录]("目录") - [搜索]("搜索") - [书签]("书签") 1. [PyTorch 入门](https://www.w3cschool.cn/pytorch/pytorch-o95j3bbp.html "PyTorch 入门") 1. [PyTorch 深度学习](https://www.w3cschool.cn/pytorch/pytorch-oraf3bbx.html "PyTorch 入门教程:60 分钟掌握深度学习基础") 1. [PyTorch是什么?](https://www.w3cschool.cn/pytorch/pytorch-5ubt3bby.html "PyTorch是什么?") 2. [PyTorch Autograd自动求导](https://www.w3cschool.cn/pytorch/pytorch-n63v3kt8.html "PyTorch Autograd自动求导") 3. [PyTorch 神经网络](https://www.w3cschool.cn/pytorch/pytorch-64gk3kt5.html "PyTorch 神经网络") 4. [PyTorch 训练分类器](https://www.w3cschool.cn/pytorch/pytorch-w18e3be1.html "PyTorch 训练分类器") 5. [PyTorch 可选: 数据并行处理](https://www.w3cschool.cn/pytorch/pytorch-wiyd3be2.html "PyTorch 可选: 数据并行处理") 2. [PyTorch 编写自定义数据集,数据加载器和转换](https://www.w3cschool.cn/pytorch/pytorch-typm3be3.html "PyTorch 编写自定义数据集,数据加载器和转换") 3. [PyTorch 使用 TensorBoard 可视化模型,数据和训练](https://www.w3cschool.cn/pytorch/pytorch-jafd3bz3.html "PyTorch 使用 TensorBoard 可视化模型,数据和训练") 2. [PyTorch 图片](https://www.w3cschool.cn/pytorch/pytorch-oy463bbr.html "PyTorch 图片") 1. [PyTorch TorchVision 对象检测微调教程](https://www.w3cschool.cn/pytorch/pytorch-5yko3be6.html "PyTorch TorchVision 对象检测微调教程") 2. [PyTorch 转移学习的计算机视觉教程](https://www.w3cschool.cn/pytorch/pytorch-5rw93be7.html "PyTorch 转移学习的计算机视觉教程") 3. [PyTorch 空间变压器网络教程](https://www.w3cschool.cn/pytorch/pytorch-jfe13bht.html "PyTorch 空间变压器网络教程") 4. [PyTorch 进行神经传递](https://www.w3cschool.cn/pytorch/pytorch-afb53bil.html "PyTorch 进行神经传递") 5. [PyTorch 对抗示例生成](https://www.w3cschool.cn/pytorch/pytorch-uvpm3bm3.html "PyTorch 对抗示例生成") 6. [PyTorch DCGAN 教程](https://www.w3cschool.cn/pytorch/pytorch-rqs93bn0.html "PyTorch DCGAN 教程") 3. [Pytorch 音频](https://www.w3cschool.cn/pytorch/pytorch-fljd3bbu.html "Pytorch 音频") 1. [PyTorch torchaudio教程](https://www.w3cschool.cn/pytorch/pytorch-52d93bnc.html "PyTorch torchaudio教程") 4. [Pytorch 文本](https://www.w3cschool.cn/pytorch/pytorch-up9q3bne.html "Pytorch 文本") 1. [PyTorch NLP From Scratch: 使用char-RNN对姓氏进行分类](https://www.w3cschool.cn/pytorch/pytorch-h7u13bnf.html "PyTorch NLP From Scratch: 使用char-RNN对姓氏进行分类") 2. [PyTorch NLP From Scratch: 生成名称与字符级RNN](https://www.w3cschool.cn/pytorch/pytorch-51wi3bz4.html "PyTorch NLP From Scratch: 生成名称与字符级RNN") 3. [PyTorch NLP From Scratch: 基于注意力机制的 seq2seq 神经网络翻译](https://www.w3cschool.cn/pytorch/pytorch-9dfn3bnh.html "PyTorch NLP From Scratch: 基于注意力机制的 seq2seq 神经网络翻译") 4. [PyTorch 使用 TorchText 进行文本分类](https://www.w3cschool.cn/pytorch/pytorch-5yam3bof.html "PyTorch 使用 TorchText 进行文本分类") 5. [PyTorch 使用 TorchText 进行语言翻译](https://www.w3cschool.cn/pytorch/pytorch-9aen3bpa.html "PyTorch 使用 TorchText 进行语言翻译") 6. [PyTorch 使用 nn.Transformer 和 TorchText 进行序列到序列建模](https://www.w3cschool.cn/pytorch/pytorch-r47x3bpb.html "PyTorch 使用 nn.Transformer 和 TorchText 进行序列到序列建模") 5. [PyTorch 命名为 Tensor(实验性)]("PyTorch 命名为 Tensor(实验性)") 1. [PyTorch 中的命名张量简介(实验性)](https://www.w3cschool.cn/pytorch/pytorch-skug3bpf.html "PyTorch 中的命名张量简介(实验性)") 6. [PyTorch 强化学习]("PyTorch 强化学习") 1. [PyTorch 强化学习(DQN)教程](https://www.w3cschool.cn/pytorch/pytorch-1zvj3bpn.html "PyTorch 强化学习(DQN)教程") 7. [PyTorch 在生产中部署 PyTorch 模型]("PyTorch 在生产中部署 PyTorch 模型") 1. [PyTorch 通过带有 Flask 的 REST API 在 Python 中部署 PyTorch](https://www.w3cschool.cn/pytorch/pytorch-x9jo3bpm.html "PyTorch 通过带有 Flask 的 REST API 在 Python 中部署 PyTorch") 2. [PyTorch TorchScript 简介](https://www.w3cschool.cn/pytorch/pytorch-ea8n3bsm.html "PyTorch TorchScript 简介") 3. [PyTorch 在 C ++中加载 TorchScript 模型](https://www.w3cschool.cn/pytorch/pytorch-nr8s3bsu.html "PyTorch 在 C ++中加载 TorchScript 模型") 4. [PyTorch (可选)将模型从 PyTorch 导出到 ONNX 并使用 ONNX Runtime 运行](https://www.w3cschool.cn/pytorch/pytorch-fs5q3bsv.html "PyTorch (可选)将模型从 PyTorch 导出到 ONNX 并使用 ONNX Runtime 运行") 8. [PyTorch 并行和分布式训练]("PyTorch 并行和分布式训练") 1. [PyTorch 单机模型并行最佳实践](https://www.w3cschool.cn/pytorch/pytorch-8ce63bz5.html "PyTorch 单机模型并行最佳实践") 2. [PyTorch 分布式数据并行入门](https://www.w3cschool.cn/pytorch/pytorch-hv9o3bsy.html "PyTorch 分布式数据并行入门") 3. [PyTorch 用 PyTorch 编写分布式应用程序](https://www.w3cschool.cn/pytorch/pytorch-pfac3bt2.html "PyTorch 用 PyTorch 编写分布式应用程序") 4. [PyTorch 分布式 RPC 框架入门](https://www.w3cschool.cn/pytorch/pytorch-t8g53bt3.html "PyTorch 分布式 RPC 框架入门") 5. [PyTorch 分布式训练师与 AWS 集成实战教程](https://www.w3cschool.cn/pytorch/pytorch-it483bt6.html "PyTorch 分布式训练师与 AWS 集成实战教程") 9. [PyTorch 扩展]("PyTorch 扩展") 1. [PyTorch 使用自定义 C ++运算符扩展 TorchScript](https://www.w3cschool.cn/pytorch/pytorch-ljs93bz6.html "PyTorch 使用自定义 C ++运算符扩展 TorchScript") 2. [PyTorch 使用自定义 C ++类扩展 TorchScript](https://www.w3cschool.cn/pytorch/pytorch-jsrc3bta.html "PyTorch 使用自定义 C ++类扩展 TorchScript") 3. [PyTorch 与 NumPy、SciPy 扩展应用入门教程](https://www.w3cschool.cn/pytorch/pytorch-96lx3btb.html "PyTorch 与 NumPy、SciPy 扩展应用入门教程") 4. [PyTorch 自定义 C ++和 CUDA 扩展](https://www.w3cschool.cn/pytorch/pytorch-r9fw3btc.html "PyTorch 自定义 C ++和 CUDA 扩展") 10. [PyTorch 模型优化](https://www.w3cschool.cn/pytorch/pytorch-hu473btd.html "PyTorch 模型优化") 1. [PyTorch LSTM Word 语言模型上的(实验)动态量化](https://www.w3cschool.cn/pytorch/pytorch-xk8d3bte.html "PyTorch LSTM Word 语言模型上的(实验)动态量化") 2. [PyTorch (实验性)在 PyTorch 中使用 Eager 模式进行静态量化](https://www.w3cschool.cn/pytorch/pytorch-sbhv3btf.html "PyTorch (实验性)在 PyTorch 中使用 Eager 模式进行静态量化") 3. [PyTorch (实验性)计算机视觉教程的量化转移学习](https://www.w3cschool.cn/pytorch/pytorch-wvl73btg.html "PyTorch (实验性)计算机视觉教程的量化转移学习") 4. [PyTorch (实验)BERT 上的动态量化](https://www.w3cschool.cn/pytorch/pytorch-wcjv3bz7.html "PyTorch (实验)BERT 上的动态量化") 5. [PyTorch 修剪教程](https://www.w3cschool.cn/pytorch/pytorch-rnmi3bti.html "PyTorch 修剪教程") 11. [PyTorch 用其他语言]("PyTorch 用其他语言") 1. [PyTorch 使用 PyTorch C ++前端](https://www.w3cschool.cn/pytorch/pytorch-8nyw3bz8.html "PyTorch 使用 PyTorch C ++前端") 12. [PyTorch 基础知识](https://www.w3cschool.cn/pytorch/pytorch-6g5l3btl.html "PyTorch 基础知识") 1. [PyTorch 通过示例学习 PyTorch](https://www.w3cschool.cn/pytorch/pytorch-mug23btm.html "PyTorch 通过示例学习 PyTorch") 2. [PyTorch torch.nn 到底是什么?](https://www.w3cschool.cn/pytorch/pytorch-mrni3btn.html "PyTorch torch.nn 到底是什么?") 13. [PyTorch 笔记](https://www.w3cschool.cn/pytorch/pytorch-eg293bto.html "PyTorch 笔记") 1. [PyTorch 自动求导机制](https://www.w3cschool.cn/pytorch/pytorch-mbex3bz9.html "PyTorch 自动求导机制") 2. [PyTorch 广播语义详解及应用实例](https://www.w3cschool.cn/pytorch/pytorch-2p973btq.html "PyTorch 广播语义详解及应用实例") 3. [PyTorch CPU 线程与 TorchScript 推断优化详解](https://www.w3cschool.cn/pytorch/pytorch-fq513btr.html "PyTorch CPU 线程与 TorchScript 推断优化详解") 4. [PyTorch CUDA 语义](https://www.w3cschool.cn/pytorch/pytorch-v9hz3bts.html "PyTorch CUDA 语义详解及应用优化") 5. [PyTorch 分布式 Autograd 设计](https://www.w3cschool.cn/pytorch/pytorch-sbgu3btt.html "PyTorch 分布式 Autograd 设计") 6. [PyTorch 扩展机制详解](https://www.w3cschool.cn/pytorch/pytorch-nowf3btu.html "PyTorch 扩展机制详解") 7. [PyTorch 经常问的问题(FAQ)详解](https://www.w3cschool.cn/pytorch/pytorch-kl1g3btv.html "PyTorch 经常问的问题(FAQ)详解") 8. [PyTorch 大规模部署的功能](https://www.w3cschool.cn/pytorch/pytorch-evow3btw.html "PyTorch 大规模部署的功能") 9. [PyTorch 并行处理最佳实践](https://www.w3cschool.cn/pytorch/pytorch-vaey3bud.html "PyTorch 并行处理最佳实践") 10. [PyTorch 重现性](https://www.w3cschool.cn/pytorch/pytorch-inrs3bue.html "PyTorch 重现性") 11. [PyTorch 远程参考协议](https://www.w3cschool.cn/pytorch/pytorch-cdva3buf.html "PyTorch 远程参考协议") 12. [PyTorch 序列化语义](https://www.w3cschool.cn/pytorch/pytorch-c9613buh.html "PyTorch 序列化语义") 13. [PyTorch Windows 常见问题](https://www.w3cschool.cn/pytorch/pytorch-mfdy3bza.html "PyTorch Windows 常见问题") 14. [PyTorch XLA 设备上的 PyTorch](https://www.w3cschool.cn/pytorch/pytorch-xy4u3buj.html "PyTorch XLA 设备上的 PyTorch") 14. [PyTorch 语言绑定](https://www.w3cschool.cn/pytorch/pytorch-9ry43buk.html "PyTorch 语言绑定") 1. [PyTorch C ++ API](https://www.w3cschool.cn/pytorch/pytorch-3gk43bul.html "PyTorch C ++ API") 15. [Python API](https://www.w3cschool.cn/pytorch/pytorch-juo53bun.html "Python API") 1. [PyTorch torch](https://www.w3cschool.cn/pytorch/pytorch-z7mj3buo.html "PyTorch torch") 2. [PyTorch torch.nn](https://www.w3cschool.cn/pytorch/pytorch-79mp3bvy.html "PyTorch torch.nn") 3. [PyTorch torch功能](https://www.w3cschool.cn/pytorch/pytorch-rbmj3bwf.html "PyTorch torch功能") 4. [PyTorch torch张量](https://www.w3cschool.cn/pytorch/pytorch-g4dh3bwg.html "PyTorch torch张量") 5. [Pytorch 张量属性](https://www.w3cschool.cn/pytorch/pytorch-8ixf3bwh.html "Pytorch 张量属性") 6. [PyTorch 自动差分包-Torch.Autograd](https://www.w3cschool.cn/pytorch/pytorch-bo3g3bwi.html "PyTorch 自动差分包-Torch.Autograd") 7. [PyTorch torch.cuda](https://www.w3cschool.cn/pytorch/pytorch-15zb3bwj.html "PyTorch torch.cuda") 8. [PyTorch 分布式通讯包-Torch.Distributed](https://www.w3cschool.cn/pytorch/pytorch-v6td3bwy.html "PyTorch 分布式通讯包-Torch.Distributed") 9. [PyTorch 概率分布-torch分布](https://www.w3cschool.cn/pytorch/pytorch-5blq3bwz.html "PyTorch 概率分布-torch分布") 10. [PyTorch torch.hub](https://www.w3cschool.cn/pytorch/pytorch-th213bx3.html "PyTorch torch.hub") 11. [PyTorch torch脚本](https://www.w3cschool.cn/pytorch/pytorch-x7js3bx4.html "PyTorch torch脚本") 12. [PyTorch torch.nn.init](https://www.w3cschool.cn/pytorch/pytorch-162i3bxb.html "PyTorch torch.nn.init") 13. [PyTorch torch.onnx](https://www.w3cschool.cn/pytorch/pytorch-ohnu3bxc.html "PyTorch torch.onnx") 14. [PyTorch torch.optim](https://www.w3cschool.cn/pytorch/pytorch-pb4o3bxd.html "PyTorch torch.optim") 15. [PyTorch 量化](https://www.w3cschool.cn/pytorch/pytorch-ildt3bxe.html "PyTorch 量化") 16. [PyTorch 分布式 RPC 框架](https://www.w3cschool.cn/pytorch/pytorch-me1q3bxf.html "PyTorch 分布式 RPC 框架") 17. [PyTorch torch随机](https://www.w3cschool.cn/pytorch/pytorch-cnxe3bxg.html "PyTorch torch随机") 18. [PyTorch torch稀疏](https://www.w3cschool.cn/pytorch/pytorch-k6vx3bxh.html "PyTorch torch稀疏") 19. [PyTorch torch存储](https://www.w3cschool.cn/pytorch/pytorch-5huz3bxi.html "PyTorch torch存储") 20. [PyTorch torch.utils.bottleneck](https://www.w3cschool.cn/pytorch/pytorch-yi6d3bxj.html "PyTorch torch.utils.bottleneck") 21. [PyTorch torch.utils.checkpoint](https://www.w3cschool.cn/pytorch/pytorch-wj1h3bxk.html "PyTorch torch.utils.checkpoint") 22. [PyTorch torch.utils.cpp\_extension](https://www.w3cschool.cn/pytorch/pytorch-dhfk3bxl.html "PyTorch torch.utils.cpp_extension") 23. [PyTorch torch.utils.data](https://www.w3cschool.cn/pytorch/pytorch-eznw3bxm.html "PyTorch torch.utils.data") 24. [PyTorch torch.utils.dlpack](https://www.w3cschool.cn/pytorch/pytorch-hcd33bxn.html "PyTorch torch.utils.dlpack") 25. [PyTorch torch.utils.model\_zoo](https://www.w3cschool.cn/pytorch/pytorch-jvfx3bxo.html "PyTorch torch.utils.model_zoo") 26. [PyTorch torch.utils.tensorboard](https://www.w3cschool.cn/pytorch/pytorch-trnd3bxp.html "PyTorch torch.utils.tensorboard") 27. [PyTorch 类型信息](https://www.w3cschool.cn/pytorch/pytorch-fiey3bxq.html "PyTorch 类型信息") 28. [PyTorch 命名张量](https://www.w3cschool.cn/pytorch/pytorch-w4zc3bxr.html "PyTorch 命名张量") 29. [PyTorch 命名为 Tensors 操作员范围](https://www.w3cschool.cn/pytorch/pytorch-de743bxs.html "PyTorch 命名为 Tensors 操作员范围") 16. [PyTorch torchvision参考]("PyTorch torchvision参考") 1. [PyTorch torchvision](https://www.w3cschool.cn/pytorch/pytorch-abcs3by7.html "PyTorch torchvision") 17. [PyTorch 音频参考]("PyTorch 音频参考") 1. [PyTorch torchaudio](https://www.w3cschool.cn/pytorch/pytorch-zr4g3by9.html "PyTorch torchaudio") 18. [PyTorch 社区]("PyTorch 社区") 1. [PyTorch 贡献指南](https://www.w3cschool.cn/pytorch/pytorch-5xzs3byv.html "PyTorch 贡献指南") 2. [PyTorch 治理](https://www.w3cschool.cn/pytorch/pytorch-d3ik3byw.html "PyTorch 治理") 3. [PyTorch 治理\| 感兴趣的人](https://www.w3cschool.cn/pytorch/pytorch-wfp23byx.html "PyTorch 治理| 感兴趣的人") 搜索 A A 默认 护眼 夜间 阅读(3.8k) [书签]() [赞(0)]() [分享]("分享") [我要纠错](https://www.w3cschool.cn/edit/pytorch/pytorch-it483bt6) # PyTorch 分布式训练师与 AWS 集成实战教程 2025-06-23 10:02 更新 随着深度学习模型规模的不断扩大和数据量的持续增长,单机训练方式已难以满足高效训练的需求。分布式训练成为一种必然选择,它通过将计算任务分布在多个 GPU 或服务器上,显著提升了训练效率。AWS 作为全球领先的云计算平台,提供了强大的计算资源和灵活的服务架构,为分布式训练提供了理想的运行环境。本文将深入探讨如何在 AWS 上搭建和运行 PyTorch 分布式训练系统,通过实际案例助力您高效开展深度学习项目。 ## 一、AWS 环境搭建 ### (一)创建实例 在 AWS 上创建两个多 GPU 节点,选择适合深度学习任务的实例类型,如 `p2.8xlarge`,其配备 8 个 NVIDIA Tesla K80 GPU,为分布式训练提供强大的计算支持。 ### (二)配置安全组 确保实例之间的通信畅通无阻,是分布式训练成功的关键。创建一个新的安全组,并配置入站和出站规则,允许节点之间所有类型的数据流量。具体操作步骤如下: 1. 登录 AWS 管理控制台,选择 “EC2” 服务。 2. 在左侧导航栏中,选择 “安全组”。 3. 点击 “创建安全组”,设置安全组名称和描述。 4. 在 “入站规则” 栏中,添加规则允许来自新安全组的 “所有流量”。 5. 在 “出站规则” 栏中,同样添加规则允许流向新安全组的 “所有流量”。 ### (三)获取节点 IP 地址 在 EC2 仪表板中找到正在运行的实例,记录每个节点的 IPv4 公网 IP 和私网 IP。公网 IP 用于 SSH 连接,私网 IP 用于节点间通信。这些 IP 地址在后续配置中将被频繁使用。 ## 二、环境配置 ### (一)创建并激活 conda 环境 在每个节点上创建并激活一个新的 conda 环境,为 PyTorch 提供干净的运行环境: ``` conda create -n pytorch_env python=3.8 conda activate pytorch_env ``` ### (二)安装 PyTorch 和 torchvision 安装支持 CUDA 的 PyTorch 夜度构建版本以及从源代码构建的 torchvision: ``` pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 cd ~ git clone https://github.com/pytorch/vision.git cd vision python setup.py install ``` ### (三)设置 NCCL 网络接口 为了优化 GPU 之间的通信,设置 NCCL 套接字的网络接口名称。通过运行 `ifconfig` 命令确定网络接口名称,并设置环境变量: ``` export NCCL_SOCKET_IFNAME=ens3 ``` ## 三、分布式训练代码实现 ### (一)导入必要的模块 ``` import time import sys import torch import torch.nn as nn import torch.nn.parallel import torch.distributed as dist import torch.optim import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms import torchvision.datasets as datasets import torchvision.models as models ``` ### (二)定义辅助函数和类 ``` class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res ``` ### (三)定义训练和验证函数 ``` def train(train_loader, model, criterion, optimizer, epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.train() end = time.time() for i, (input, target) in enumerate(train_loader): data_time.update(time.time() - end) input = input.cuda(non_blocking=True) target = target.cuda(non_blocking=True) output = model(input) loss = criterion(output, target) prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) optimizer.zero_grad() loss.backward() optimizer.step() batch_time.update(time.time() - end) end = time.time() if i % 10 == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) def validate(val_loader, model, criterion): batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() with torch.no_grad(): end = time.time() for i, (input, target) in enumerate(val_loader): input = input.cuda(non_blocking=True) target = target.cuda(non_blocking=True) output = model(input) loss = criterion(output, target) prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) batch_time.update(time.time() - end) end = time.time() if i % 100 == 0: print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) return top1.avg ``` ### (四)初始化进程组 ``` def main(): batch_size = 32 workers = 2 num_epochs = 2 starting_lr = 0.1 world_size = 4 dist_backend = 'nccl' dist_url = "tcp://<node0-privateIP>:23456" # 替换为实际的节点私有 IP print("Initialize Process Group...") dist.init_process_group(backend=dist_backend, init_method=dist_url, rank=int(sys.argv[1]), world_size=world_size) local_rank = int(sys.argv[2]) dp_device_ids = [local_rank] torch.cuda.set_device(local_rank) print("Initialize Model...") model = models.resnet18(pretrained=False).cuda() model = torch.nn.parallel.DistributedDataParallel(model, device_ids=dp_device_ids) criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), starting_lr, momentum=0.9, weight_decay=1e-4) print("Initialize Dataloaders...") transform = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) trainset = datasets.STL10(root='./data', split='train', download=True, transform=transform) valset = datasets.STL10(root='./data', split='test', download=True, transform=transform) train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=workers, pin_memory=False, sampler=train_sampler) val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=False) best_prec1 = 0 for epoch in range(num_epochs): train_sampler.set_epoch(epoch) adjust_learning_rate(starting_lr, optimizer, epoch) train(train_loader, model, criterion, optimizer, epoch) prec1 = validate(val_loader, model, criterion) best_prec1 = max(prec1, best_prec1) print("Epoch Summary: ") print("\tEpoch Accuracy: {}".format(prec1)) print("\tBest Accuracy: {}".format(best_prec1)) if __name__ == "__main__": main() ``` ## 四、运行训练 在每个节点上打开多个 SSH 终端,分别运行以下命令: - 在 node0 的第一个终端上:`python main.py 0 0` - 在 node0 的第二个终端上:`python main.py 1 1` - 在 node1 的第一个终端上:`python main.py 2 0` - 在 node1 的第二个终端上:`python main.py 3 1` 以上内容已同步发布至编程狮网站,欢迎访问[编程狮 PyTorch 教程](https://www.w3cschool.cn/pytorch/)获取更多深度学习和 PyTorch 相关的优质教程。在学习过程中,如果您有任何疑问或需要进一步的技术支持,欢迎加入编程狮社区,与广大编程爱好者和专家进行交流和互动。 以上内容是否对您有帮助: 在文档使用的过程中是否遇到以下问题: - 内容错误 - 更新不及时 - 链接错误 - 缺少代码/图片示列 - 太简单/步骤待完善 - 其他 更多建议: [提交建议]() ← [PyTorch 分布式 RPC 框架入门](https://www.w3cschool.cn/pytorch/pytorch-t8g53bt3.html "上一篇:PyTorch 分布式 RPC 框架入门") [PyTorch 使用自定义 C ++运算符扩展 TorchScript](https://www.w3cschool.cn/pytorch/pytorch-ljs93bz6.html "下一篇:PyTorch 使用自定义 C ++运算符扩展 TorchScript") → [写笔记]() [我要补充]() ## 推荐文章 - [W3Cschool 热门编程语言排行榜 2020年 10月 TOP10](https://www.w3cschool.cn/article/44760098.html "W3Cschool 热门编程语言排行榜 2020年 10月 TOP10") - [几种后端开发中常用的语言。](https://www.w3cschool.cn/article/92223460.html "几种后端开发中常用的语言。") - [2020年10月编程语言排行榜:Python 即将超越 Java](https://www.w3cschool.cn/article/cf0fe6dbee7042.html "2020年10月编程语言排行榜:Python 即将超越 Java") - [怎么查看python版本?有几种方法?](https://www.w3cschool.cn/article/82929779.html "怎么查看python版本?有几种方法?") - [python怎么保留两位小数?几种方法总结!](https://www.w3cschool.cn/article/89475330.html "python怎么保留两位小数?几种方法总结!") ## 推荐教程 - [Python Tornado 介绍](https://www.w3cschool.cn/python_tornado "Python Tornado 介绍") - [Django4 中文教程](https://www.w3cschool.cn/django4 "Django4 中文教程") - [笨办法学Python](https://www.w3cschool.cn/tzwdhj "笨办法学Python") - [零基础学python(第二版)](https://www.w3cschool.cn/uqmpir "零基础学python(第二版)") - [白话Python3](https://www.w3cschool.cn/py_practice "白话Python3") ## 推荐课程 - [Python os模块](https://www.w3cschool.cn/minicourse/play/py_os?fcode=tutorial-pytorch "Python os模块") - [Python GUI编程 PyQt6入门到实战](https://www.w3cschool.cn/minicourse/play/antpython10?fcode=tutorial-pytorch "Python GUI编程 PyQt6入门到实战") - [Python Requests权威指南](https://www.w3cschool.cn/minicourse/play/python_requests_detail?fcode=tutorial-pytorch "Python Requests权威指南") - [Python Scrapy爬虫入门到实战](https://www.w3cschool.cn/minicourse/play/scrapy_14day?fcode=tutorial-pytorch "Python Scrapy爬虫入门到实战") - [Python Django 框架入门课程](https://www.w3cschool.cn/minicourse/play/pythondjango?fcode=tutorial-pytorch "Python Django 框架入门课程") 精选笔记 Copyright©2021 [w3cschool](https://www.w3cschool.cn/ "w3cschool")编程狮\|[闽ICP备15016281号-3](https://beian.miit.gov.cn/)\|[闽公网安备35020302033924号](http://www.beian.gov.cn/portal/registerSystemInfo?recordcode=35020302033924) 违法和不良信息举报电话:173-0602-2364\|[举报邮箱:jubao@eeedong.com](mailto:jubao@eeedong.com) 在线笔记 App下载 ![App下载](https://7nsts.w3cschool.cn/images/w3c/app-qrcode2.png) 扫描二维码 下载编程狮App 公众号 ![微信公众号](https://7nsts.w3cschool.cn/images/w3c/mp-qrcode.png) 编程狮公众号 意见反馈 意见反馈 X - 意见反馈: 联系方式: 提交 [查看完整版笔记](https://www.w3cschool.cn/my/note) 保存 关闭 教程纠错 教程纠错 违规举报 X - 广告等垃圾信息 - 不友善内容 - 违反法律法规的内容 - 不宜公开讨论的政治内容 - 其他 提交 工具 推荐 [更多](https://123.w3cschool.cn/webtools) [![](https://atts.w3cschool.cn/trae.png) Trae CN](https://www.trae.com.cn/?utm_source=advertising&utm_medium=w3cschool_ug_cpa&utm_term=hw_trae_w3cschool) [![](https://atts.w3cschool.cn/Turtle.png) Turtle绘图](https://www.w3cschool.cn/tools/index?name=pythonturtle) [![](https://atts.w3cschool.cn/Markdown.png) Markdown编辑器](https://www.w3cschool.cn/tools/index?name=editormd) [![](https://atts.w3cschool.cn/shijianchuo.png) Unix时间戳](https://www.w3cschool.cn/tools/index?name=timestamptrans) [![](https://atts.w3cschool.cn/Mermaid.png) Mermaid编辑器](https://www.w3cschool.cn/tools/index?name=mermaid) [![](https://atts.w3cschool.cn/Python.png) Python在线编译器](https://www.w3cschool.cn/tryrun/runcode?lang=python3) [![](https://atts.w3cschool.cn/shiseqi.png) 在线拾色器](https://www.w3cschool.cn/tools/index?name=cpicker) [![](https://atts.w3cschool.cn/zhengze.png) 正则工具](https://www.w3cschool.cn/tools/index?name=decode_encode_tool) AI编程工具 [更多](https://123.w3cschool.cn/navaitools) [![](https://atts.w3cschool.cn/trae.png) Trae](https://www.trae.com.cn/?utm_source=advertising&utm_medium=w3cschool_ug_cpa&utm_term=hw_trae_w3cschool) [![](data:image/svg+xml,%3csvg%20id='raccoon_sm_light'%20data-name='raccoon'%20xmlns='http://www.w3.org/2000/svg'%20viewBox='50%2040%20200%20120'%3e%3cdefs%3e%3cstyle%3e%20.cls-1%20{%20fill:%20%23fff;%20}%20.cls-2%20{%20fill:%20%23192842;%20}%20%3c/style%3e%3c/defs%3e%3cg%3e%3cpath%20class='cls-1'%20d='M244.52,103.68l-22-22A5.42,5.42,0,0,1,221,77.13l3.92-28.95a8.31,8.31,0,0,0-8.23-9.42H195.4a12.37,12.37,0,0,0-12.14,9.58,3.92,3.92,0,0,1-3.79,3H120.53a3.92,3.92,0,0,1-3.79-3,12.37,12.37,0,0,0-12.14-9.58H83.33a8.31,8.31,0,0,0-8.23,9.42L79,77.13a5.42,5.42,0,0,1-1.53,4.54l-22,22a15.09,15.09,0,0,0,0,21.31l33.18,33.19A10.43,10.43,0,0,0,96,161.24H204a10.43,10.43,0,0,0,7.38-3.06L244.52,125A15.09,15.09,0,0,0,244.52,103.68Z'/%3e%3cg%3e%3cpath%20class='cls-2'%20d='M161.65,110.58a4.23,4.23,0,0,0-2.79-7.41H141.14a4.23,4.23,0,0,0-2.79,7.41l6.12,5.36a8.38,8.38,0,0,0,11.06,0Z'/%3e%3cpath%20class='cls-2'%20d='M211.33,70.49l2.52-18.59a3.49,3.49,0,0,0-3.46-4H197.21a3.49,3.49,0,0,0-2.47,6Z'/%3e%3cpath%20class='cls-2'%20d='M88.61,70.49,105.2,53.9a3.49,3.49,0,0,0-2.47-6H89.55a3.5,3.5,0,0,0-3.46,4Z'/%3e%3cpath%20class='cls-2'%20d='M234.77,108.45,192.36,66.12a14.93,14.93,0,0,0-10.56-4.37H164.55a13.25,13.25,0,0,0-8.13,2.79l-4.19,3.26a3.63,3.63,0,0,1-4.46,0l-4.19-3.26a13.25,13.25,0,0,0-8.13-2.79H118.2a14.93,14.93,0,0,0-10.56,4.37L65.23,108.45a8.32,8.32,0,0,0,0,11.75L94,148.94a6.61,6.61,0,0,0,4.68,1.94H121.8l-16.64-16.64a6.73,6.73,0,0,1,0-9.52l39.45-39.43a7.63,7.63,0,0,1,10.78,0l31.48,31.47-34.13,34.12h14.65l26.8-26.8.65.64a6.73,6.73,0,0,1,0,9.52L178.2,150.88h23.14a6.61,6.61,0,0,0,4.68-1.94l28.75-28.74A8.32,8.32,0,0,0,234.77,108.45ZM118.36,93.26a5.48,5.48,0,1,1,5.48-5.48A5.49,5.49,0,0,1,118.36,93.26Zm63.28,0a5.48,5.48,0,1,1,5.48-5.48A5.48,5.48,0,0,1,181.64,93.26Z'/%3e%3c/g%3e%3c/g%3e%3c/svg%3e) 代码小浣熊](https://www.xiaohuanxiong.com/login?utm_source=blmay41) [![](https://atts.w3cschool.cn/aitool-icon.png) 星辰Agent](https://agent.xfyun.cn/home?ch=xcagent-aitool16) [![](https://atts.w3cschool.cn/aitool-icon-1.png) 通义灵码](https://www.w3cschool.cn/tongyilingma/) [![](https://atts.w3cschool.cn/aitool-icon-5.png) 文心快码](https://comate.baidu.com/zh) [![](https://www.w3cschool.cn/attachments/webnav/202412/11734432755.png) CodeGeeX](https://www.w3cschool.cn/codegeex/) [![](https://atts.w3cschool.cn/aitool-icon-3.png) GitHub Copilot](https://github.com/features/copilot) [![](https://atts.w3cschool.cn/aitool-icon-4.png) Fitten Code](https://code.fittentech.com/) ![](https://atts.w3cschool.cn/aitool-m.png)
Readable Markdown
随着深度学习模型规模的不断扩大和数据量的持续增长,单机训练方式已难以满足高效训练的需求。分布式训练成为一种必然选择,它通过将计算任务分布在多个 GPU 或服务器上,显著提升了训练效率。AWS 作为全球领先的云计算平台,提供了强大的计算资源和灵活的服务架构,为分布式训练提供了理想的运行环境。本文将深入探讨如何在 AWS 上搭建和运行 PyTorch 分布式训练系统,通过实际案例助力您高效开展深度学习项目。 ### (一)创建实例 在 AWS 上创建两个多 GPU 节点,选择适合深度学习任务的实例类型,如 `p2.8xlarge`,其配备 8 个 NVIDIA Tesla K80 GPU,为分布式训练提供强大的计算支持。 ### (二)配置安全组 确保实例之间的通信畅通无阻,是分布式训练成功的关键。创建一个新的安全组,并配置入站和出站规则,允许节点之间所有类型的数据流量。具体操作步骤如下: 1. 登录 AWS 管理控制台,选择 “EC2” 服务。 2. 在左侧导航栏中,选择 “安全组”。 3. 点击 “创建安全组”,设置安全组名称和描述。 4. 在 “入站规则” 栏中,添加规则允许来自新安全组的 “所有流量”。 5. 在 “出站规则” 栏中,同样添加规则允许流向新安全组的 “所有流量”。 ### (三)获取节点 IP 地址 在 EC2 仪表板中找到正在运行的实例,记录每个节点的 IPv4 公网 IP 和私网 IP。公网 IP 用于 SSH 连接,私网 IP 用于节点间通信。这些 IP 地址在后续配置中将被频繁使用。 ## 二、环境配置 ### (一)创建并激活 conda 环境 在每个节点上创建并激活一个新的 conda 环境,为 PyTorch 提供干净的运行环境: ``` conda create -n pytorch_env python=3.8 conda activate pytorch_env ``` ### (二)安装 PyTorch 和 torchvision 安装支持 CUDA 的 PyTorch 夜度构建版本以及从源代码构建的 torchvision: ``` pip install torch --index-url https://download.pytorch.org/whl/nightly/cu118 cd ~ git clone https://github.com/pytorch/vision.git cd vision python setup.py install ``` ### (三)设置 NCCL 网络接口 为了优化 GPU 之间的通信,设置 NCCL 套接字的网络接口名称。通过运行 `ifconfig` 命令确定网络接口名称,并设置环境变量: ``` export NCCL_SOCKET_IFNAME=ens3 ``` ## 三、分布式训练代码实现 ### (一)导入必要的模块 ``` import time import sys import torch import torch.nn as nn import torch.nn.parallel import torch.distributed as dist import torch.optim import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms import torchvision.datasets as datasets import torchvision.models as models ``` ### (二)定义辅助函数和类 ``` class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res ``` ### (三)定义训练和验证函数 ``` def train(train_loader, model, criterion, optimizer, epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.train() end = time.time() for i, (input, target) in enumerate(train_loader): data_time.update(time.time() - end) input = input.cuda(non_blocking=True) target = target.cuda(non_blocking=True) output = model(input) loss = criterion(output, target) prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) optimizer.zero_grad() loss.backward() optimizer.step() batch_time.update(time.time() - end) end = time.time() if i % 10 == 0: print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5)) def validate(val_loader, model, criterion): batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() model.eval() with torch.no_grad(): end = time.time() for i, (input, target) in enumerate(val_loader): input = input.cuda(non_blocking=True) target = target.cuda(non_blocking=True) output = model(input) loss = criterion(output, target) prec1, prec5 = accuracy(output, target, topk=(1, 5)) losses.update(loss.item(), input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) batch_time.update(time.time() - end) end = time.time() if i % 100 == 0: print('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) return top1.avg ``` ### (四)初始化进程组 ``` def main(): batch_size = 32 workers = 2 num_epochs = 2 starting_lr = 0.1 world_size = 4 dist_backend = 'nccl' dist_url = "tcp://<node0-privateIP>:23456" # 替换为实际的节点私有 IP print("Initialize Process Group...") dist.init_process_group(backend=dist_backend, init_method=dist_url, rank=int(sys.argv[1]), world_size=world_size) local_rank = int(sys.argv[2]) dp_device_ids = [local_rank] torch.cuda.set_device(local_rank) print("Initialize Model...") model = models.resnet18(pretrained=False).cuda() model = torch.nn.parallel.DistributedDataParallel(model, device_ids=dp_device_ids) criterion = nn.CrossEntropyLoss().cuda() optimizer = torch.optim.SGD(model.parameters(), starting_lr, momentum=0.9, weight_decay=1e-4) print("Initialize Dataloaders...") transform = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) trainset = datasets.STL10(root='./data', split='train', download=True, transform=transform) valset = datasets.STL10(root='./data', split='test', download=True, transform=transform) train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), num_workers=workers, pin_memory=False, sampler=train_sampler) val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=workers, pin_memory=False) best_prec1 = 0 for epoch in range(num_epochs): train_sampler.set_epoch(epoch) adjust_learning_rate(starting_lr, optimizer, epoch) train(train_loader, model, criterion, optimizer, epoch) prec1 = validate(val_loader, model, criterion) best_prec1 = max(prec1, best_prec1) print("Epoch Summary: ") print("\tEpoch Accuracy: {}".format(prec1)) print("\tBest Accuracy: {}".format(best_prec1)) if __name__ == "__main__": main() ``` ## 四、运行训练 在每个节点上打开多个 SSH 终端,分别运行以下命令: - 在 node0 的第一个终端上:`python main.py 0 0` - 在 node0 的第二个终端上:`python main.py 1 1` - 在 node1 的第一个终端上:`python main.py 2 0` - 在 node1 的第二个终端上:`python main.py 3 1` 以上内容已同步发布至编程狮网站,欢迎访问[编程狮 PyTorch 教程](https://www.w3cschool.cn/pytorch/)获取更多深度学习和 PyTorch 相关的优质教程。在学习过程中,如果您有任何疑问或需要进一步的技术支持,欢迎加入编程狮社区,与广大编程爱好者和专家进行交流和互动。
Shard8 (laksa)
Root Hash17163751617681654208
Unparsed URLcn,w3cschool!www,/pytorch/pytorch-it483bt6.html s443