Deep Learning (Single-Machine Multi-GPU Training)

If a machine has multiple GPUs, you can train across all of them.

Training speed generally improves noticeably when the dataset and model are large; when both are small, performance may actually get worse because of the overhead of inter-GPU data transfer.

Below is a simple example:

import time
import torch
import torchvision.models
from torchvision.transforms import transforms
from torch import nn, optim
from torchvision.datasets import CIFAR10

if __name__ == "__main__":

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    dataTransforms = transforms.Compose([
            transforms.ToTensor()
            , transforms.RandomCrop(32, padding=4)
            , transforms.RandomHorizontalFlip(p=0.5)
            # ImageNet normalization statistics
            , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    trainset = CIFAR10(root='./data', train=True, download=True, transform=dataTransforms)
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
 
    model = torchvision.models.resnet18(pretrained=False)
    # Adapt ResNet-18 to 32x32 CIFAR-10 inputs: smaller first conv, no initial downsampling
    model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=True)
    model.maxpool = nn.MaxPool2d(1, 1, 0)  # effectively disables the initial max-pool
    model.fc = nn.Linear(model.fc.in_features, 10)  # 10 CIFAR-10 classes
 
    model.to(device)

    # Wrap the model in DataParallel so each forward pass splits the batch across GPUs
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    cross = nn.CrossEntropyLoss()
    cross.to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

    start = time.time()
    for epoch in range(10):
   
        model.train()  

        correctSum = 0.0
        lossSum = 0.0
        dataLen = 0

        for inputs, labels in trainLoader:
            inputs = inputs.to(device)
            labels = labels.to(device)
 
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = cross(outputs, labels)
 
            _, preds = torch.max(outputs, 1)  
 
            loss.backward() 
            optimizer.step()  
 
            correct = (preds == labels).sum().item()
            correctSum += correct
            lossSum += loss.item() * inputs.size(0)  # loss is a batch mean, so weight by batch size
            dataLen += inputs.size(0)

        print('loss: {:.4f} acc: {:.4f}'.format(lossSum / dataLen, correctSum / dataLen))

    timeElapsed = time.time() - start
    print('Elapsed {:.0f}m {:.0f}s'.format(timeElapsed // 60, timeElapsed % 60))
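One gotcha worth knowing: `nn.DataParallel` prefixes every parameter name with `module.`, so a checkpoint saved from the wrapped model cannot be loaded directly into an unwrapped one. A minimal sketch (using a tiny `nn.Linear` stand-in rather than the ResNet above) shows the difference and the usual fix of saving via `.module`:

```python
import torch
from torch import nn

# Tiny stand-in model (hypothetical; the article's example uses ResNet-18)
model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

# The wrapper adds a "module." prefix to every state_dict key
print(list(wrapped.state_dict().keys()))         # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict().keys()))  # ['weight', 'bias']

# Save the underlying model so the checkpoint loads without the wrapper
torch.save(wrapped.module.state_dict(), "model.pt")
```

Saving `wrapped.module.state_dict()` keeps the checkpoint portable between single-GPU and multi-GPU setups.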