深度學(xué)習(xí)第七篇---Pytorch 循環(huán)神經(jīng)網(wǎng)絡(luò)-實(shí)現(xiàn)情感極性判定

之前所學(xué)的全連接神經(jīng)網(wǎng)絡(luò)(DNN)和卷積神經(jīng)網(wǎng)絡(luò)(CNN),他們的前一個(gè)輸入和后一個(gè)輸入是沒有關(guān)系的(從輸入層到隱含層再到輸出層,層與層之間是全連接的,每層之間的節(jié)點(diǎn)是無連接的)。如下面這樣,輸出層X,經(jīng)過隱藏層,到輸出層Y,通過調(diào)節(jié)權(quán)重Win和Wout就可以實(shí)現(xiàn)學(xué)習(xí)的效果。

但是當(dāng)我們處理序列信息的時(shí)候,某些前面的輸入和后面的輸入是有關(guān)系的,比如:當(dāng)我們?cè)诶斫庖痪湓捯馑紩r(shí),孤立的理解這句話的每個(gè)詞是不夠的,我們需要處理這些詞連接起來的整個(gè)序列;這個(gè)時(shí)候我們就需要使用到循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network)。比如:手機(jī)壞了,我要買一個(gè)256g蘋果,結(jié)合前面的手機(jī)壞了,這個(gè)蘋果含義是一臺(tái)手機(jī),而不是不是吃的蘋果。

1 基本循環(huán)神經(jīng)網(wǎng)絡(luò)RNN

X是輸入向量,O是輸出,S是隱藏層,U是輸入到隱藏層的權(quán)重矩陣,V是隱藏層到輸出層的權(quán)重矩陣,循環(huán)神經(jīng)網(wǎng)絡(luò)的隱藏層的值s不僅僅取決于當(dāng)前這次的輸入x,還取決于上一次隱藏層的值s。權(quán)重矩陣w就是隱藏層上一次的值作為這一次的輸入的權(quán)重。上面的圖可以在時(shí)間維度進(jìn)行展開成一個(gè)鏈?zhǔn)降慕Y(jié)構(gòu)。

這個(gè)網(wǎng)絡(luò)在t時(shí)刻接收到輸入Xt之后,隱藏層的值是St,輸出值是Ot。關(guān)鍵一點(diǎn)是,St的值不僅僅取決于Xt,還取決于St-1。

隱藏層計(jì)算公式為:St=f(UXt + W * St-1),其中f為激活函數(shù)。
輸出層計(jì)算公式為:Ot = g(V
St),其中g(shù)為激活函數(shù),V是權(quán)重矩陣。

結(jié)合前面的例子,“手機(jī)壞了,我要買一個(gè)256g蘋果”,被分詞之后,成一組向量[X1,X2,...,X6]


循環(huán)神經(jīng)網(wǎng)絡(luò)從左到右閱讀這個(gè)句子,不斷調(diào)用相同的RNN CELL來處理。但是上面方法有一個(gè)明顯的缺陷,當(dāng)閱讀的句子很長(zhǎng)的時(shí)候,網(wǎng)絡(luò)會(huì)變得復(fù)雜甚至無效。當(dāng)前面的信息在傳遞到后面的同時(shí),信息的權(quán)重會(huì)下降(梯度爆炸和梯度消失),導(dǎo)致預(yù)測(cè)不準(zhǔn)。比如下面兩句話,was和were要根據(jù)前面的student的單復(fù)數(shù)來確定。句子過長(zhǎng)的情況下,就難以判定了,因此RNN這種網(wǎng)絡(luò)被稱為短時(shí)記憶網(wǎng)絡(luò)(Short Term Memory)。

he student,who got A+ in the exam,was excellent.
The students,who got A+ in the exam,were excellent.

2 LSTM循環(huán)神經(jīng)網(wǎng)絡(luò)

 為了解決上面記憶信息不足的問題,引入一種長(zhǎng)短時(shí)記憶網(wǎng)絡(luò)(Long Short Term Memory,LSTM)。原始RNN的隱藏層只有一個(gè)狀態(tài),即h,它對(duì)于短期的輸入非常敏感。那么如果我們?cè)僭黾右粋€(gè)門(gate)機(jī)制用于控制特征的流通和損失,即c,讓它來保存長(zhǎng)期的狀態(tài)。

左邊是不同時(shí)刻的X,中間黃色球是隱藏層H,右邊綠色是輸出Y,和基本RNN相比,除了黃色的鏈條(Short Term Memory),LSTM增加了一個(gè)新的紅色鏈條,用C來表示,叫LongTerm Memory,且兩個(gè)鏈條相互作用,相關(guān)更新,將這兩條線“拍平”后如下:


我們?cè)趯W(xué)習(xí)的時(shí)候,會(huì)經(jīng)常看到下面這幅圖,比較難以理解,主要是二維的圖難以想象成三維的結(jié)構(gòu)。理解上面三維結(jié)構(gòu)的兩條線,就知道下面的二維圖兩條線是怎么進(jìn)行數(shù)據(jù)更新的。

現(xiàn)在正式介紹LSTM中三個(gè)重要的門結(jié)構(gòu)。

2.1 遺忘門

函數(shù)f1是sigmoid函數(shù),可以把矩陣的值壓縮到0-1之間,矩陣元素相乘的時(shí)候,因?yàn)槿魏螖?shù)乘以 0 都得 0,這部分信息就會(huì)剔除掉。同樣的,任何數(shù)乘以 1 都得到它本身,這部分信息就會(huì)完美地保存下來。這樣網(wǎng)絡(luò)就能了解哪些數(shù)據(jù)是需要遺忘,哪些數(shù)據(jù)是需要保存。


數(shù)據(jù)更新公式:


與基本RNN的內(nèi)部結(jié)構(gòu)計(jì)算非常相似,首先將當(dāng)前時(shí)間步輸入x(t)與上一個(gè)時(shí)間步隱含狀態(tài)h(t—1)拼接,得到[x(t),h(t—1)],然后通過一個(gè)全連接層做變換,最后通過sigmoid函數(shù)進(jìn)行激活得到f1(t),我們可以將f1(t)看作是門值,好比一扇門開合的大小程度,門值都將作用在通過該扇門的張量,遺忘門門值將作用的上一層的細(xì)胞狀態(tài)上,代表遺忘過去的多少信息,又因?yàn)檫z忘門門值是由x(t),h(t—1)計(jì)算得來的,因此整個(gè)公式意味著根據(jù)當(dāng)前時(shí)間步輸入和上一個(gè)時(shí)間步隱含狀態(tài)h(t—1)來決定遺忘多少上一層的細(xì)胞狀態(tài)所攜帶的過往信息。

2.2 輸入門

我們看到輸入門的計(jì)算公式有兩個(gè),第一個(gè)就是產(chǎn)生輸入門門值的公式,它和遺忘門公式幾乎相同,區(qū)別只是在于它們之后要作用的目標(biāo)上,這個(gè)公式意味著輸入信息有多少需要進(jìn)行過濾。輸入門的第二個(gè)公式是與傳統(tǒng)RNN的內(nèi)部結(jié)構(gòu)計(jì)算相同。對(duì)于LSTM來講,它得到的是當(dāng)前的細(xì)胞狀態(tài),而不是像經(jīng)典RNN一樣得到的是隱含狀態(tài)。最后,第一個(gè)公式f1與上一次的Ct-1做全連接然后,加上f2之后的結(jié)果,更新給當(dāng)前的Ct,這個(gè)過程被稱為L(zhǎng)STM的細(xì)胞狀態(tài)更新。

2.3 輸出門


輸出門部分的公式也是兩個(gè),第一個(gè)即是計(jì)算輸出門的門值,它和遺忘門,輸入門計(jì)算方式相同。第二個(gè)即是使用這個(gè)門值產(chǎn)生隱含狀態(tài)h(t),他將作用在更新后的細(xì)胞狀態(tài)C(t)上,并做tanh激活,最終得到h(t)作為下一時(shí)間步輸入的一部分.整個(gè)輸出門的過程,就是為了產(chǎn)生隱含狀態(tài)h(t)。

上面的過程用代碼表示就是

# 定義一個(gè)LSTM模型
class LSTM(nn.Module):
   def __init__(self, input_size, hidden_size, num_layers, output_size):
       super(LSTM, self).__init__()
       self.hidden_size = hidden_size
       self.num_layers = num_layers
       self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
       self.fc = nn.Linear(hidden_size, output_size)
       
   def forward(self, x):
       # 初始化隱藏狀態(tài)h0, c0為全0向量
       h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
       c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
       # 將輸入x和隱藏狀態(tài)(h0, c0)傳入LSTM網(wǎng)絡(luò)
       out, _ = self.lstm(x, (h0, c0))
       # 取最后一個(gè)時(shí)間步的輸出作為L(zhǎng)STM網(wǎng)絡(luò)的輸出
       out = self.fc(out[:, -1, :])
       return out

3 BiLSTM神經(jīng)網(wǎng)絡(luò)

雙向循環(huán)神經(jīng)網(wǎng)絡(luò)(Bi-Directional Long Short-Term Memory,BiLSTM)是一種特殊的循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)架構(gòu),它包含一個(gè)正向LSTM 層和一個(gè)反向LSTM層。這兩個(gè)LSTM層分別對(duì)序列中的元素進(jìn)行正向和反向傳遞,并在最后的隱藏層中進(jìn)行合并。這樣,BiLSTM可以同時(shí)考慮序列中的歷史信息和未來信息,使得它在處理序列數(shù)據(jù)任務(wù)中(如文本分類和序列標(biāo)注)有著良好的表現(xiàn)。

前向的LSTML依次輸入“我”,“愛”,“你”得到三個(gè)向量{h0,h1,h2}。后向的LSTMR依次輸入“你”,“愛”,“我”得到三個(gè)向量{h5,h4,h3}。最后將前向和后向的隱向量進(jìn)行拼接得到{[h0,h5],[h1,h4],[h2,h3]},即{A,B,C},對(duì)于情感分類任務(wù)來說,我們采用的句子表示往往是[h2,h5],因?yàn)檫@其中包含了前向和后向的所有信息。

4 電影評(píng)價(jià)的極性分析實(shí)踐

4.1 劃分訓(xùn)練集、測(cè)試集

import torch
import torch.nn.functional as F
from torchtext import data
from torchtext import datasets
from torchtext.legacy import data, datasets
import time
import random
torch.backends.cudnn.deterministic = True

# 定義超參數(shù)
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)

VOCABULARY_SIZE = 20000
LEARNING_RATE = 1e-3
BATCH_SIZE = 128
NUM_EPOCHS = 15
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


if __name__ == '__main__':
   # 注意:由于 RNN 只能處理序列中的非 padded 元素(即非0數(shù)據(jù))
   # 對(duì)于任何 padded 元素輸出都是 0 。所以在準(zhǔn)備數(shù)據(jù)的時(shí)候?qū)nclude_length設(shè)置為True
   # 以獲得句子的實(shí)際長(zhǎng)度。
   TEXT = data.Field(tokenize='spacy', include_lengths=True, tokenizer_language='en_core_web_sm')
   LABEL = data.LabelField(dtype=torch.float)
   train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
   datasets.IMDB.splits(TEXT,LABEL)

   # 從訓(xùn)練集中選取部分做驗(yàn)證集
   train_data, valid_data = train_data.split(random_state=random.seed(RANDOM_SEED), split_ratio=0.8)

   print(f'Num Train: {len(train_data)}')
   print(f'Num Valid: {len(valid_data)}')
   print(f'Num Test: {len(test_data)}')
   print("train_data[0:200]", test_data.examples[0].text[0:100])
Num Train: 20000
Num Valid: 5000
Num Test: 25000
train_data[0:200] ['Based', 'on', 'an', 'actual', 'story', ',', 'John', 'Boorman', 'shows', 'the', 'struggle', 'of', 'an', 'American', 'doctor', ',', 'whose', 'husband', 'and', 'son', 'were', 'murdered', 'and', 'she', 'was', 'continually', 'plagued', 'with', 'her', 'loss', '.', 'A', 'holiday', 'to', 'Burma', 'with', 'her', 'sister', 'seemed', 'like', 'a', 'good', 'idea', 'to', 'get', 'away', 'from', 'it', 'all', ',', 'but', 'when', 'her', 'passport', 'was', 'stolen', 'in', 'Rangoon', ',', 'she', 'could', 'not', 'leave', 'the', 'country', 'with', 'her', 'sister', ',', 'and', 'was', 'forced', 'to', 'stay', 'back', 'until', 'she', 'could', 'get', 'I.D.', 'papers', 'from', 'the', 'American', 'embassy', '.', 'To', 'fill', 'in', 'a', 'day', 'before', 'she', 'could', 'fly', 'out', ',', 'she', 'took', 'a']

4.2 創(chuàng)建詞向量

TEXT.build_vocab(train_data, max_size=VOCABULARY_SIZE)
LABEL.build_vocab(train_data)

print(f'Vocabulary size: {len(TEXT.vocab)}')
print(f'Number of classes: {len(LABEL.vocab)}')

輸出:

Vocabulary size: 20002
Number of classes: 2

TEXT.build_vocab表示從預(yù)訓(xùn)練的詞向量中,將當(dāng)前訓(xùn)練數(shù)據(jù)中的詞匯的詞向量抽取出來,構(gòu)成當(dāng)前訓(xùn)練集的 Vocab(詞匯表)。對(duì)于當(dāng)前詞向量語料庫中沒有出現(xiàn)的單詞(記為UNK)。

4.3 創(chuàng)建數(shù)據(jù)迭代器

BATCH_SIZE = 64

# 根據(jù)當(dāng)前環(huán)境選擇是否調(diào)用GPU進(jìn)行訓(xùn)練
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 創(chuàng)建數(shù)據(jù)迭代器
train_loader, valid_loader, test_loader = data.BucketIterator.splits(
   (train_data, valid_data, test_data),
   batch_size=BATCH_SIZE,
   sort_within_batch=True,  # 為了 packed_padded_sequence
   device=device)

4.4 定義RNN模型

class RNN(nn.Module):
   def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
       super().__init__()

       self.embedding = nn.Embedding(input_dim, embedding_dim)
       self.rnn = nn.RNN(embedding_dim, hidden_dim)
       self.fc = nn.Linear(hidden_dim, output_dim)

   def forward(self, text, text_length):
       embedded = self.embedding(text)
       # pack_padded_sequence 技術(shù)的應(yīng)用
       packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, text_length)
       output, hidden = self.rnn(packed)
       #squeeze(0)的作用是將張量中維度大小為1的維度進(jìn)行壓縮,減少張量的維度數(shù)量和大小。
       #如果張量中沒有維度大小為1的維度,那么squeeze(0)函數(shù)不會(huì)對(duì)張量進(jìn)行任何修改,它會(huì)返回與原始張量相同的張量
       # view(-1)將張量重塑為一維形狀
       return self.fc(hidden.squeeze(0)).view(-1)

關(guān)于pack_padded_sequence(處理Pad問題)的解釋:
"Pad問題"是指填充操作中的一個(gè)常見問題,即如何處理填充元素(通常用特殊的占位符,如<pad>)對(duì)模型訓(xùn)練和推理的影響。我們需要對(duì)電影評(píng)論進(jìn)行情感分類,這些評(píng)論往往具有不同長(zhǎng)度的單詞數(shù)量。當(dāng)我們將這些評(píng)論句子作為輸入傳遞給循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN)進(jìn)行處理時(shí),由于RNN的輸入需要是固定長(zhǎng)度的張量,我們需要對(duì)序列進(jìn)行填充(padding)操作,使得每個(gè)評(píng)論都具有相同的長(zhǎng)度。

假設(shè)我們有三個(gè)電影評(píng)論,分別是"這是一部很好看的電影","這個(gè)電影一般般"和"我不喜歡這部電影"。我們可以將這些評(píng)論編碼為以下張量:

評(píng)論1: [這, 是, 一部, 很, 好看, 的, 電影]
評(píng)論2: [這個(gè), 電影, 一般般]
評(píng)論3: [我, 不喜歡, 這部, 電影]

在這個(gè)例子中,我們有3個(gè)電影評(píng)論。它們的長(zhǎng)度分別是7、3和4。我們需要將它們填充到相同的長(zhǎng)度,以便能夠?qū)⑺鼈冏鳛橐粋€(gè)批次輸入到模型中。填充后的序列是:

評(píng)論1: [這, 是, 一部, 很, 好看, 的, 電影]
評(píng)論2: [這個(gè), 電影, 一般般, <pad>, <pad>, <pad>, <pad>]
評(píng)論3: [我, 不喜歡, 這部, 電影, <pad>, <pad>, <pad>]

4.5 RNN模型訓(xùn)練

INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
OUTPUT_DIM = 1

torch.manual_seed(RANDOM_SEED)
model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM)
model = model.to(DEVICE)
#選擇Adam優(yōu)化器
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

start_time = time.time()

for epoch in range(NUM_EPOCHS):
   model.train()
   for batch_idx, batch_data in enumerate(train_loader):
       text, text_lengths = batch_data.text
       logits = model(text, text_lengths)
       cost = F.binary_cross_entropy_with_logits(logits, batch_data.label)
       optimizer.zero_grad()
       cost.backward()
       optimizer.step()

       if not batch_idx % 50:
           print(f'Epoch: {epoch + 1:03d}/{NUM_EPOCHS:03d} | '
                 f'Batch {batch_idx:03d}/{len(train_loader):03d} | '
                 f'Cost: {cost:.4f}')

4.6 RNN模型評(píng)估

def compute_binary_accuracy(model, data_loader, device):
   model.eval()
   correct_pred, num_examples = 0, 0
   with torch.no_grad():
       for batch_idx, batch_data in enumerate(data_loader):
           text, text_lengths = batch_data.text
           logits = model(text, text_lengths)
           predicted_labels = (torch.sigmoid(logits) > 0.5).long()
           num_examples += batch_data.label.size(0)
           correct_pred += (predicted_labels == batch_data.label.long()).sum()
       return correct_pred.float() / num_examples * 100


def predict_sentiment(model, sentence):
   model.eval()
   tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
   indexed = [TEXT.vocab.stoi[t] for t in tokenized]
   length = [len(indexed)]
   tensor = torch.LongTensor(indexed).to(DEVICE)
   tensor = tensor.unsqueeze(1)
   length_tensor = torch.LongTensor(length)
   prediction = torch.sigmoid(model(tensor, length_tensor))
   return prediction.item()
with torch.set_grad_enabled(False):
    print(f'training accuracy: '
                 f'{compute_binary_accuracy(model, train_loader, DEVICE):.2f}%'
                 f'\nvalid accuracy: '
                 f'{compute_binary_accuracy(model, valid_loader, DEVICE):.2f}%')

print(f'Time elapsed: {(time.time() - start_time) / 60:.2f} min')
print(f'Total Training Time: {(time.time() - start_time) / 60:.2f} min')
print(f'Test accuracy: {compute_binary_accuracy(model, test_loader, DEVICE):.2f}%')

nlp = spacy.load('en_core_web_sm')
ret = predict_sentiment(model, "I really love this movie. This movie is so great!")
print("ret=", ret)

輸出:

Num Train: 20000
Num Valid: 5000
Num Test: 25000
train_data[0:200] ['Based', 'on', 'an', 'actual', 'story', ',', 'John', 'Boorman', 'shows', 'the', 'struggle', 'of', 'an', 'American', 'doctor', ',', 'whose', 'husband', 'and', 'son', 'were', 'murdered', 'and', 'she', 'was', 'continually', 'plagued', 'with', 'her', 'loss', '.', 'A', 'holiday', 'to', 'Burma', 'with', 'her', 'sister', 'seemed', 'like', 'a', 'good', 'idea', 'to', 'get', 'away', 'from', 'it', 'all', ',', 'but', 'when', 'her', 'passport', 'was', 'stolen', 'in', 'Rangoon', ',', 'she', 'could', 'not', 'leave', 'the', 'country', 'with', 'her', 'sister', ',', 'and', 'was', 'forced', 'to', 'stay', 'back', 'until', 'she', 'could', 'get', 'I.D.', 'papers', 'from', 'the', 'American', 'embassy', '.', 'To', 'fill', 'in', 'a', 'day', 'before', 'she', 'could', 'fly', 'out', ',', 'she', 'took', 'a']
Vocabulary size: 20002
Number of classes: 2
Epoch: 001/004 | Batch 000/313 | Cost: 0.7078
Epoch: 001/004 | Batch 050/313 | Cost: 0.6911
Epoch: 001/004 | Batch 100/313 | Cost: 0.6901
Epoch: 001/004 | Batch 150/313 | Cost: 0.6965
Epoch: 001/004 | Batch 200/313 | Cost: 0.6274
Epoch: 001/004 | Batch 250/313 | Cost: 0.6855
Epoch: 001/004 | Batch 300/313 | Cost: 0.6413
training accuracy: 66.27%
valid accuracy: 65.26%
Time elapsed: 6.34 min
Epoch: 002/004 | Batch 000/313 | Cost: 0.6546
Epoch: 002/004 | Batch 050/313 | Cost: 0.6024
Epoch: 002/004 | Batch 100/313 | Cost: 0.6676
Epoch: 002/004 | Batch 150/313 | Cost: 0.6437
Epoch: 002/004 | Batch 200/313 | Cost: 0.6236
Epoch: 002/004 | Batch 250/313 | Cost: 0.6862
Epoch: 002/004 | Batch 300/313 | Cost: 0.5634
training accuracy: 54.29%
valid accuracy: 52.32%
Time elapsed: 12.72 min
Epoch: 003/004 | Batch 000/313 | Cost: 0.6892
Epoch: 003/004 | Batch 050/313 | Cost: 0.6420
Epoch: 003/004 | Batch 100/313 | Cost: 0.6250
Epoch: 003/004 | Batch 150/313 | Cost: 0.6815
Epoch: 003/004 | Batch 200/313 | Cost: 0.5970
Epoch: 003/004 | Batch 250/313 | Cost: 0.6502
Epoch: 003/004 | Batch 300/313 | Cost: 0.5945
training accuracy: 68.32%
valid accuracy: 61.98%
Time elapsed: 19.41 min
Epoch: 004/004 | Batch 000/313 | Cost: 0.5901
Epoch: 004/004 | Batch 050/313 | Cost: 0.3887
Epoch: 004/004 | Batch 100/313 | Cost: 0.6483
Epoch: 004/004 | Batch 150/313 | Cost: 0.5912
Epoch: 004/004 | Batch 200/313 | Cost: 0.5973
Epoch: 004/004 | Batch 250/313 | Cost: 0.4288
Epoch: 004/004 | Batch 300/313 | Cost: 0.4574
training accuracy: 75.49%
valid accuracy: 65.98%
Time elapsed: 25.52 min
Total Training Time: 25.52 min
Test accuracy: 65.84%
ret= 0.9203115105628967

0.9大于0.5 代表是積極的觀點(diǎn),但是測(cè)試集上準(zhǔn)確率只有65.84%。

4.7 定義LSTM模型

class LSTM(nn.Module):
   def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
       super().__init__()
       self.embedding = nn.Embedding(input_dim, embedding_dim)
       self.lstm = nn.LSTM(embedding_dim, hidden_dim)
       self.fc = nn.Linear(hidden_dim, output_dim)

   def forward(self, text, text_length):
       embedded = self.embedding(text)
       packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, text_length)
       packed_output, (hidden, cell) = self.lstm(packed)

       return self.fc(hidden.unsqueeze(0)).view(-1)

4.8 LSTM 模型訓(xùn)練和評(píng)估


model2 = LSTM(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM).to(DEVICE)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=LEARNING_RATE)
start_time = time.time()
for epoch in range(NUM_EPOCHS):
   model2.train()
   for batch_idx, batch_data in enumerate(train_loader):
       text, text_lengths = batch_data.text
       logits = model2(text, text_lengths)
       cost2 = F.binary_cross_entropy_with_logits(logits, batch_data.label)
       optimizer2.zero_grad()
       cost2.backward()
       optimizer2.step()
       
       if not batch_idx % 50:
           print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                  f'Batch {batch_idx:03d}/{len(train_loader):03d} | '
                  f'Cost: {cost2:.4f}')

   with torch.set_grad_enabled(False):
       print(f'training accuracy: '
             f'{compute_binary_accuracy(model2, train_loader, DEVICE):.2f}%'
             f'\nvalid accuracy: '
             f'{compute_binary_accuracy(model2, valid_loader, DEVICE):.2f}%')
       
   print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')
   
print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')
print(f'Test accuracy: {compute_binary_accuracy(model2, test_loader, DEVICE):.2f}%')

輸出:

Num Train: 20000
Num Valid: 5000
Num Test: 25000
train_data[0:200] ['Based', 'on', 'an', 'actual', 'story', ',', 'John', 'Boorman', 'shows', 'the', 'struggle', 'of', 'an', 'American', 'doctor', ',', 'whose', 'husband', 'and', 'son', 'were', 'murdered', 'and', 'she', 'was', 'continually', 'plagued', 'with', 'her', 'loss', '.', 'A', 'holiday', 'to', 'Burma', 'with', 'her', 'sister', 'seemed', 'like', 'a', 'good', 'idea', 'to', 'get', 'away', 'from', 'it', 'all', ',', 'but', 'when', 'her', 'passport', 'was', 'stolen', 'in', 'Rangoon', ',', 'she', 'could', 'not', 'leave', 'the', 'country', 'with', 'her', 'sister', ',', 'and', 'was', 'forced', 'to', 'stay', 'back', 'until', 'she', 'could', 'get', 'I.D.', 'papers', 'from', 'the', 'American', 'embassy', '.', 'To', 'fill', 'in', 'a', 'day', 'before', 'she', 'could', 'fly', 'out', ',', 'she', 'took', 'a']
Vocabulary size: 20002
Number of classes: 2
Epoch: 001/010 | Batch 000/313 | Cost: 0.6930
Epoch: 001/010 | Batch 050/313 | Cost: 0.6436
Epoch: 001/010 | Batch 100/313 | Cost: 0.6402
Epoch: 001/010 | Batch 150/313 | Cost: 0.5405
Epoch: 001/010 | Batch 200/313 | Cost: 0.6803
Epoch: 001/010 | Batch 250/313 | Cost: 0.6905
Epoch: 001/010 | Batch 300/313 | Cost: 0.6695
training accuracy: 56.28%
valid accuracy: 56.62%
Time elapsed: 73.09 min
Epoch: 002/010 | Batch 000/313 | Cost: 0.6772
Epoch: 002/010 | Batch 050/313 | Cost: 0.6866
Epoch: 002/010 | Batch 100/313 | Cost: 0.6674
Epoch: 002/010 | Batch 150/313 | Cost: 0.6037
Epoch: 002/010 | Batch 200/313 | Cost: 0.6808
Epoch: 002/010 | Batch 250/313 | Cost: 0.6685
Epoch: 002/010 | Batch 300/313 | Cost: 0.6927
training accuracy: 50.33%
valid accuracy: 50.40%
Time elapsed: 148.10 min
Epoch: 003/010 | Batch 000/313 | Cost: 0.7443
Epoch: 003/010 | Batch 050/313 | Cost: 0.6509
Epoch: 003/010 | Batch 100/313 | Cost: 0.6160
Epoch: 003/010 | Batch 150/313 | Cost: 0.6501
Epoch: 003/010 | Batch 200/313 | Cost: 0.5341
Epoch: 003/010 | Batch 250/313 | Cost: 0.4378
Epoch: 003/010 | Batch 300/313 | Cost: 0.4366
training accuracy: 84.29%
valid accuracy: 81.76%
Time elapsed: 218.03 min
Epoch: 004/010 | Batch 000/313 | Cost: 0.3864
Epoch: 004/010 | Batch 050/313 | Cost: 0.2678
Epoch: 004/010 | Batch 100/313 | Cost: 0.2225
Epoch: 004/010 | Batch 150/313 | Cost: 0.3614
Epoch: 004/010 | Batch 200/313 | Cost: 0.2415
Epoch: 004/010 | Batch 250/313 | Cost: 0.1816
Epoch: 004/010 | Batch 300/313 | Cost: 0.2577
training accuracy: 91.74%
valid accuracy: 86.48%
Time elapsed: 285.13 min
Test accuracy: 90.72%
ret= 0.9601113557815552 
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 227,156評(píng)論 6 529
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 97,866評(píng)論 3 413
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 174,880評(píng)論 0 373
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我,道長(zhǎng),這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 62,398評(píng)論 1 308
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 71,202評(píng)論 6 405
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 54,743評(píng)論 1 320
  • 那天,我揣著相機(jī)與錄音,去河邊找鬼。 笑死,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 42,822評(píng)論 3 438
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 41,962評(píng)論 0 285
  • 序言:老撾萬榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 48,476評(píng)論 1 331
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 40,444評(píng)論 3 354
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 42,579評(píng)論 1 365
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,129評(píng)論 5 355
  • 正文 年R本政府宣布,位于F島的核電站,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 43,840評(píng)論 3 344
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,231評(píng)論 0 25
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,487評(píng)論 1 281
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 51,177評(píng)論 3 388
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 47,568評(píng)論 2 370

推薦閱讀更多精彩內(nèi)容