很多前人曾說過,深度學習好比煉丹,框架就是丹爐,網絡結構及算法就是單方,而數據集則是原材料,為了能夠煉好丹,首先需要一個使用稱手的丹爐,同時也要有好的單方和原材料,最后就需要煉丹師們有著足夠的經驗和技巧掌握火候和時機,這樣方能煉出絕世好丹。
對于剛剛進入煉丹行業的煉丹師,網上都有一些前人總結的煉丹技巧,同時也有很多煉丹師的心路歷程以及丹師對整個煉丹過程的記錄,有了這些,無疑能夠非常快速知道如何煉丹。但是現在市面上的入門煉丹手冊往往都是將原材料幫你放到了丹爐中,你只需要將丹爐開啟,然后進行簡單的調試,便能出丹。這樣做無疑減少了大家入門的難度,但是往往到了自己真正煉丹的時候便會手足無措,不知道如何將原材料放入丹爐。
本篇煉丹入門指導便是使用PyTorch這個丹爐,教你如何將原材料放入丹爐,雖然這一步并不涉及太多算法,但是卻是煉丹開始非常重要的一步。
PyTorch數據讀入函數介紹
ImageFolder
在PyTorch中有一個現成實現的數據讀取方法,是torchvision.datasets.ImageFolder
,這個api是仿照keras寫的,主要是做分類問題,將每一類數據放到同一個文件夾中,比如有10個類別,那么就在一個大的文件夾下面建立10個子文件夾,每個子文件夾里面放的是同一類的數據。
通過這個函數能夠很簡單的建立一個數據I/O,但是問題來了,如果我要處理的數據不是這樣一個簡單的分類問題,比如我要做機器翻譯,那么我的輸入和輸出都是一個句子,這樣該怎么進行數據讀入呢?
這個問題非常容易解決,我們可以看看ImageFolder的實現,可以發現其是torch.utils.data.Dataset
的子類,所以下面我們介紹一下torch.utils.data.Dataset
這個類。
Dataset
我們可以發現Dataset的定義是下面這樣
這里注釋是說這是一個代表著數據集的抽象類,所有關于數據集的類都可以定義為其子類,只需要重寫__getitem__
和__len__
就可以了。我們再回去看看ImageFolder的實現,確實是這樣的,那么現在問題就變得很簡單,對于機器翻譯問題,我們只需要定義整個數據集的長度,同時定義取出其中一個索引的元素即可。
那么定義好了數據集我們不可能將所有的數據集都放到內存,這樣內存肯定就爆了,我們需要定義一個迭代器,每一步產生一個batch,這里PyTorch已經為我們實現好了,就是下面的torch.utils.data.DataLoader
。
DataLoader
DataLoader能夠為我們自動生成一個多線程的迭代器,只要傳入幾個參數進行就可以了,第一個參數就是上面定義的數據集,后面幾個參數就是batch size的大小,是否打亂數據,讀取數據的線程數目等等,這樣一來,我們就建立了一個多線程的I/O。
讀到這里,你可能覺得PyTorch真的太方便了,這個丹爐真的好用,然后便迫不及待的嘗試了一下,然后有可能性就報錯了,而且你也是一步一步按著實現來的,怎么就報錯了呢?不用著急,下面就來講一下為什么會報錯,以及這一塊pyhon實現的解讀,這樣你就能夠真正知道如何進行自定義的數據讀入。
問題來源
通過上面的實現,可能會遇到各種不同的問題,Dataset非常簡單,一般都不會有錯,只要Dataset實現正確,那么問題的來源只有一個,那就是torch.utils.data.DataLoader
中的一個參數collate_fn
,這里我們需要找到DataLoader的源碼進行查看這個參數到底是什么。
可以看到collate_fn
默認是等于default_collate
,那么這個函數的定義如下。
是不是看著有點頭大,沒有關系,我們先搞清楚他的輸入是什么。這里可以看到他的輸入被命名為batch,但是我們還是不知道到底是什么,可以猜測應該是一個batch size的數據。我們繼續往后找,可以找到這個地方。
我們可以從這里看到collate_fn
在這里進行了調用,那么他的輸入我們就找到了,從這里看這就是一個list,list中的每個元素就是self.data[i]
,如果你在往上看,可以看到這個self.data
就是我們需要預先定義的Dataset,那么這里self.data[i]
就等價于我們在Dataset里面定義的__getitem__
這個函數。
所以我們知道了collate_fn
這個函數的輸入就是一個list,list的長度是一個batch size,list中的每個元素都是__getitem__
得到的結果。
這時我們再去看看collate_fn
這個函數,其實可以看到非常簡單,就是通過對一些情況的排除,然后最后輸出結果,比如第一個if,如果我們的輸入是一個tensor,那么最后會將一個batch size的tensor重新stack在一起,比如輸入的tensor是一張圖片,3x30x30,如果batch size是32,那么按第一維stack之后的結果就是32x3x30x30,這里stack和concat有一點區別就是會增加一維。
所以通過上面的源碼解讀我們知道了數據讀入具體是如何操作的,那么我們就能夠實現自定義的數據讀入了,我們需要自己按需要重新定義collate_fn
這個函數,下面舉個例子。
自定義數據讀入的舉例實現
下面我們來舉一個麻煩的例子,比如做文本識別,需要將一張圖片上的字符識別出來,比如下面這些圖片
那么這個問題的輸入就是一張一張的圖片,他的label就是一串字符,但是由于長度是變化的,所以這個問題比較麻煩。
下面我們就來簡單實現一下。
我們有一個train.txt的文件,上面有圖片的名稱和對應的label,首先我們需要定義一個Dataset。
class custom_dset(Dataset):
def __init__(self,
img_path,
txt_path,
img_transform=None,
loader=default_loader):
with open(txt_path, 'r') as f:
lines = f.readlines()
self.img_list = [
os.path.join(img_path, i.split()[0]) for i in lines
]
self.label_list = [i.split()[1] for i in lines]
self.img_transform = img_transform
self.loader = loader
def __getitem__(self, index):
img_path = self.img_list[index]
label = self.label_list[index]
# img = self.loader(img_path)
img = img_path
if self.img_transform is not None:
img = self.img_transform(img)
return img, label
def __len__(self):
return len(self.label_list)
這里非常簡單,就是將txt文件打開,然后分別讀取圖片名和label,由于存放圖片的文件夾我并沒有放上去,因為數據太大,所以讀取圖片以及對圖片做一些變換的操作就不進行了。
接著我們自定義一個collate_fn
,這里可以使用任何名字,只要在DataLoader里面傳入就可以了。
def collate_fn(batch):
batch.sort(key=lambda x: len(x[1]), reverse=True)
img, label = zip(*batch)
pad_label = []
lens = []
max_len = len(label[0])
for i in range(len(label)):
temp_label = [0] * max_len
temp_label[:len(label[i])] = label[i]
pad_label.append(temp_label)
lens.append(len(label[i]))
return img, pad_label, lens
代碼的細節就不詳細說了,總體來講就是先按label長度進行排序,然后進行長度的pad,最后輸出圖片,label以及每個label的長度的list。
下面我們可以驗證一下,得到如下的結果。
具體的操作大家可以去玩一下,改一改,能夠實現任何你想要的輸出,比如圖片輸出為一個32x3x30x30的tensor,將label中的字母轉化為數字標示,然后也可以輸出為tensor,任何你想要的操作都可以在上面顯示的程序中執行。
以上就是本文所有的內容,后面的例子不是很完整,講得也不是很詳細,因為圖片數據太大,不好傳到github上,當然通過看代碼能夠更快的學習。通過本文的閱讀,大家應該都能夠掌握任何需要的數據讀入,如果有問題歡迎評論留言。
歡迎關注我的知乎專欄深度煉丹
歡迎訪問我的博客