在前面讲模型加载和保存的时候,在多GPU情况下,实际上是挖了坑的,比如在多GPU加载时,GPU的利用率是不均衡的,而当时没详细探讨这个问题,今天来详细地讨论一下。
在训练的时候,如果GPU资源有限,而数据量和模型大小较大,那么在单GPU上运行就会极其慢的训练速度,此时就要使用多GPU进行模型训练了,在pytorch上实现多GPU训练实际上十分简单:
只需要将模型使用nn.DataParallel
进行装饰即可。
model = nn.DataParallel(model,device_ids=range(torch.cuda.device_count()))
但是问题在于这样直接处理后的模型的负载可能是不均衡的,因为在不同的GPU上进行运算,而最后的loss计算过程是要合并到主GPU上,这样主GPU的的占用率将比较高,而其余GPU的利用率则没有那么高。
class FullModel(nn.Module):
def __init__(self, model, loss):
super(FullModel, self).__init__()
self.model = model
self.loss = loss
def forward(self, targets, *inputs):
outputs = self.model(*inputs)
loss = self.loss(outputs, targets)
return torch.unsqueeze(loss,0),outputs
在上述的代码中,构建了另外一个包含model和loss的壳,在壳里计算loss的值,需要注意的是,在进行DataParallel时,也需要对这个并行,而到了收集loss的时候,则使用loss的和:
loss,_ = model(gt,input)
loss = loss.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
已经有人造了这个轮子,并开源了出来,可以参考:https://github.com/zhanghang1989/PyTorch-Encoding 代码库,整个写法依然没有太大的变化:
from utils.encoding import DataParallelModel, DataParallelCriterion
model = DataParallelModel(model)
criterion = DataParallelCriterion(criterion)
实际上官方考虑过负载不均衡的问题,在文档中也推荐使用distributedDataparallel(ddp)进行训练,尽管ddp是用来解决不同机器的分布式训练问题的。
ddp使用起来比DataParallel更快,数据也更均衡,但是缺点是配置起来相对要麻烦一些。
# 初始化使用的后端
torch.distributed.init_process_group(backend="nccl")
# 对数据进行划分
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=n_worker, pin_memory=True, sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=n_worker, pin_memory=True, sampler=test_sampler) # sampler和shuffle不能同时使用
model=torch.nn.parallel.DistributedDataParallel(model)
注意:需要注意的是,尽量设定pin_memory参数为true,该参数是锁存操作,使用会加快数据读取速度,但是此时要限定内存的大小是要使用显存的两倍
以上就配置好了,经过测试,使用ddp的训练时间比DataParallel快一倍。
在运行的时候,使用以下命令进行分布式训练: python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE yourscript.py
本文链接:http://task.lmcjl.com/news/12693.html