【Pytorch教程】Pytorch tutorials 04-Training a classifer 中文翻譯

Training a classifier

本篇文章是本人對(duì)Pytorch官方教程的原創(chuàng)翻譯(原文鏈接)僅供學(xué)習(xí)交流使用,轉(zhuǎn)載請(qǐng)注明出處!

現(xiàn)在我們已經(jīng)掌握了如何去定義神經(jīng)網(wǎng)絡(luò)、計(jì)算誤差、更新權(quán)重。但在前面的章節(jié)中,我們用到的數(shù)據(jù)集都是自己構(gòu)造的虛擬數(shù)據(jù),那么如何真正地處理數(shù)據(jù)呢?

通常,我們處理圖像、文本、音頻、視頻等數(shù)據(jù)時(shí),可以使用一些Python的標(biāo)準(zhǔn)庫,將輸入導(dǎo)入為numpy格式,然后我們將導(dǎo)入的numpy數(shù)組轉(zhuǎn)化為tensor。

  • 處理圖像數(shù)據(jù),用PillowOpenCV
  • 處理音頻,用scipylibrosa
  • 處理文本,既可以使用Python/Cython的原生方法,也可以使用NLTKSpacy

pytorch為計(jì)算機(jī)視覺任務(wù)特別提供了一個(gè)torchvision包,內(nèi)含Imagenet、CIFAR10、MNIST等常用數(shù)據(jù)集,以及數(shù)據(jù)集的轉(zhuǎn)換器。他們分別包含在torchvision.datasetstorch.utils.data.DataLoader中。這就極大地避免了編寫大量重復(fù)的代碼。

本篇教程會(huì)使用CIFAR10數(shù)據(jù)集。它由10類圖片組成,每張圖片都是32x32,3通道像素。

Training an image classifier

創(chuàng)建一個(gè)圖像分類器共需5個(gè)步驟:

  1. torchvision加載CIFAR10數(shù)據(jù)集并標(biāo)準(zhǔn)化。
  2. 定義一個(gè)卷積神經(jīng)網(wǎng)絡(luò)
  3. 定義損失函數(shù)
  4. 用訓(xùn)練集訓(xùn)練網(wǎng)絡(luò)
  5. 用測(cè)試集測(cè)試網(wǎng)絡(luò)

步驟1 加載CIFAR10數(shù)據(jù)集并標(biāo)準(zhǔn)化。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision.datasets提供的圖像是PILImage,像素在[0, 1] 區(qū)間,我們需要將其標(biāo)準(zhǔn)化,得到的是[-1, 1]的數(shù)據(jù)。

transform = transforms.Compose([transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # (樣本-均值) / 標(biāo)準(zhǔn)差, 需要分別指定3個(gè)通道的均值和標(biāo)注差

'''
加載訓(xùn)練集
root:數(shù)據(jù)集根目錄
train:是否為訓(xùn)練集
download:是否需要下載
transform:transform對(duì)象,對(duì)數(shù)據(jù)集進(jìn)行轉(zhuǎn)換
'''
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# shuffle:是否打亂 num_workers: 多線程數(shù)量 如果在windows下報(bào)錯(cuò)請(qǐng)改為0
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

# 加載測(cè)試集,與上面同理
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # 原始數(shù)據(jù)是PILimage,BGR格式,plot只能顯示RGB格式,必須要轉(zhuǎn)置
    plt.show()

# 用迭代器來訪問數(shù)據(jù),一次訪問的數(shù)據(jù)量是一個(gè)batch
dataiter = iter(trainloader)
images, labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))  # make_grid用于給圖像加上邊框
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
horse   car   dog plane

步驟2 定義神經(jīng)網(wǎng)絡(luò)

前面的章節(jié)我們已經(jīng)定義過神經(jīng)網(wǎng)絡(luò)了,直接將代碼復(fù)用,修改為輸入3通道即可。

import torch
import torch.nn as nn
import torch.nn.functional as F  # nn.functional提供了各種激勵(lì)函數(shù)

class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        # 這里將輸入通道改為3
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

net = Net()

步驟3 誤差計(jì)算和參數(shù)更新

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # momentum表示動(dòng)量, 一般設(shè)為0.9,帶動(dòng)量的梯度下降法收斂更快

步驟4 訓(xùn)練神經(jīng)網(wǎng)絡(luò)

for epoch in range(2):  # epoch表示在整個(gè)數(shù)據(jù)集上循環(huán)訓(xùn)練的次數(shù)
    
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):  #enumerate()將會(huì)給可迭代對(duì)象的元素標(biāo)上序號(hào),返回(序號(hào), 元素)
        # 這里的data是以batch為單位的
        inputs, labels = data  # data的特征和標(biāo)簽分開
        
        # 清空梯度
        optimizer.zero_grad()
        
        # 處理輸入、計(jì)算誤差、更新權(quán)重
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        # 做一些統(tǒng)計(jì)
        running_loss += loss.item()  # loss是 1x1的Tenor,可以用item直接訪問數(shù)據(jù)
        if i % 2000 == 1999:  # 每2000batch輸出一次
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training.')
[1,  2000] loss: 2.271
[1,  4000] loss: 1.946
[1,  6000] loss: 1.725
[1,  8000] loss: 1.598
[1, 10000] loss: 1.535
[1, 12000] loss: 1.477
[2,  2000] loss: 1.411
[2,  4000] loss: 1.389
[2,  6000] loss: 1.359
[2,  8000] loss: 1.340
[2, 10000] loss: 1.307
[2, 12000] loss: 1.280
Finished Training.

訓(xùn)練完成后,要記得保存訓(xùn)練好的模型:

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

步驟5 測(cè)試神經(jīng)網(wǎng)絡(luò)

我們已經(jīng)用數(shù)據(jù)集對(duì)神經(jīng)網(wǎng)絡(luò)訓(xùn)練了2遍,接下來要檢驗(yàn)一下神經(jīng)網(wǎng)絡(luò)是否學(xué)到了東西。

檢驗(yàn)的方法就是讓神經(jīng)網(wǎng)絡(luò)再產(chǎn)生一些輸出,并且和它們的標(biāo)簽做比對(duì)。

首先我們來看一組圖片的標(biāo)簽:

dataiter = iter(testloader)
images, labels = dataiter.next()

imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
GroundTruth:    cat  ship  ship plane

接下來我們導(dǎo)入保存好的模型,看看模型認(rèn)為這些圖片是什么。模型的輸出是圖片的“能量”,能量共有10個(gè)值,分別表示這場(chǎng)圖片屬于對(duì)應(yīng)類別的可能性,能量越大,代表我們的分類器認(rèn)為圖片越屬于一個(gè)類。

net = Net()
net.load_state_dict(torch.load(PATH))

outputs = net(images)
# torch.max不僅可以返回最大值,還可以返回最大值的索引(第二個(gè)返回值),我們不需要知道能量的具體值,只需要知道圖片歸屬哪一類即可,最大能量對(duì)應(yīng)的索引即是它被歸為的類
_, predicted = torch.max(outputs, 1)  

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
Predicted:    cat  ship  ship  ship

結(jié)果還算不錯(cuò),接下來我們把網(wǎng)絡(luò)應(yīng)用到完整數(shù)據(jù)集上試一試:

correct = 0 
total = 0
with torch.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()


print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))     
Accuracy of the network on the 10000 test images: 55 %

再按類別做一次統(tǒng)計(jì),看一看我們的網(wǎng)絡(luò)的優(yōu)勢(shì)和短板是什么:

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 70 %
Accuracy of   car : 67 %
Accuracy of  bird : 34 %
Accuracy of   cat : 43 %
Accuracy of  deer : 52 %
Accuracy of   dog : 52 %
Accuracy of  frog : 67 %
Accuracy of horse : 58 %
Accuracy of  ship : 60 %
Accuracy of truck : 51 %

Training on GPU

在GPU上進(jìn)行訓(xùn)練也非常簡(jiǎn)單,怎么把Tensor轉(zhuǎn)到GPU,就怎么把網(wǎng)絡(luò)轉(zhuǎn)到GPU:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(device)
cuda:0

接下來我們直接使用net.to(device)即可把網(wǎng)絡(luò)遷移到GPU上,程序會(huì)自動(dòng)識(shí)別所有的參數(shù),將他們轉(zhuǎn)化為CUDA Tensor。

需要注意的是,我們必須把輸入的數(shù)據(jù)和標(biāo)簽也都遷移至GPU:

inputs, labels = data[0].to(device), data[1].to(device)

至此,Pytorch tutorial篇已經(jīng)完結(jié),官方原版第5篇教程Optional: Data Parallelism為可選部分,不再另行翻譯。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 227,797評(píng)論 6 531
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 98,179評(píng)論 3 414
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 175,628評(píng)論 0 373
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我,道長(zhǎng),這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,642評(píng)論 1 309
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 71,444評(píng)論 6 405
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 54,948評(píng)論 1 321
  • 那天,我揣著相機(jī)與錄音,去河邊找鬼。 笑死,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,040評(píng)論 3 440
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 42,185評(píng)論 0 287
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 48,717評(píng)論 1 333
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 40,602評(píng)論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 42,794評(píng)論 1 369
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,316評(píng)論 5 358
  • 正文 年R本政府宣布,位于F島的核電站,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 44,045評(píng)論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,418評(píng)論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,671評(píng)論 1 281
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 51,414評(píng)論 3 390
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 47,750評(píng)論 2 370

推薦閱讀更多精彩內(nèi)容