1 為什么要整這一出
神經(jīng)網(wǎng)絡(luò)需要數(shù)據(jù)傳入才能進(jìn)行訓(xùn)練等操作,那怎樣才能把圖片以及標(biāo)注信息整合成神經(jīng)網(wǎng)絡(luò)正規(guī)輸入的格式呢?
回答:pytorch 的數(shù)據(jù)加載到模型的操作順序是這樣的:
① 創(chuàng)建一個(gè) Dataset 對(duì)象
② 創(chuàng)建一個(gè) DataLoader 對(duì)象
③ 循環(huán)這個(gè) DataLoader 對(duì)象,將img, label加載到模型中進(jìn)行訓(xùn)練
整之前,先了解一些基礎(chǔ)知識(shí)。
2 基礎(chǔ)知識(shí)
代碼中經(jīng)常看到這兩行,那Dataset和DataLoader是什么玩意?
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
2.1 Dataset
Dataset是一個(gè)包裝類,用來將數(shù)據(jù)包裝為Dataset類,然后傳入DataLoader中。
當(dāng)用戶想要加載自定義的數(shù)據(jù)時(shí),只需要繼承這個(gè)類,并且覆寫其中的兩個(gè)方法即可:
-
__len__
:實(shí)現(xiàn)len(dataset),返回整個(gè)數(shù)據(jù)集的大小。 -
__getitem__
:用來獲取一些索引的數(shù)據(jù),使dataset[i]返回?cái)?shù)據(jù)集中第i個(gè)樣本。 - 不覆寫這兩個(gè)方法會(huì)直接返回錯(cuò)誤。
簡(jiǎn)單看一眼,有點(diǎn)感覺就行,繼續(xù)往下。
class YoloDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, train):
super(YoloDataset, self).__init__()
...
def __len__(self):
...
def __getitem__(self, index):
...
2.2 DataLoader
DataLoader將自定義的Dataset根據(jù)batch size大小、是否shuffle等封裝成一個(gè)Batch Size大小的Tensor,用于后面的訓(xùn)練。
- dataloader本質(zhì)上是一個(gè)可迭代對(duì)象,使用iter()訪問,不能使用next()訪問;
- 使用iter(dataloader)返回的是一個(gè)迭代器,然后可以使用next訪問;
- 一般使用
for inputs, labels in dataloaders
進(jìn)行可迭代對(duì)象的訪問;
DataLoader參數(shù)介紹:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None, # <function default_collate>
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
部分關(guān)鍵參數(shù)含義:
- batch_size:每個(gè)batch的大小
- shuffle:在每個(gè)epoch開始的時(shí)候,是否對(duì)數(shù)據(jù)進(jìn)行重新排序
- num_workers:加載數(shù)據(jù)的時(shí)候使用幾個(gè)子進(jìn)程,0意味著所有的數(shù)據(jù)都會(huì)被load進(jìn)主進(jìn)程。(默認(rèn)為0)
- collate_fn:如何取樣本,可以自己定義函數(shù)來準(zhǔn)確地實(shí)現(xiàn)想要的功能
- drop_last:告訴如何處理數(shù)據(jù)集長(zhǎng)度除以batch_size 余下的數(shù)據(jù)。True就拋棄,否則保留
3 Dataset與DataLoader綜合使用
最樸實(shí)的情況:
dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
for img, label in dataloader:
....
在YOLOv3中的操作示例:
train_dataset = YoloDataset(train_lines, input_shape, num_classes, train=True)
val_dataset = YoloDataset(val_lines, input_shape, num_classes, train=False)
# gen常寫為train_loader
gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
# gen_val常寫為val_loader
gen_val = DataLoader(val_dataset , shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
for iteration, batch in enumerate(gen):
images, targets = batch[0], batch[1]
...
那重寫的Dataset內(nèi)部是怎么操作的呢?它的輸入又是什么意思呢?
4 YoloDataset的實(shí)際使用
訓(xùn)練時(shí)會(huì)使用一些數(shù)據(jù)增強(qiáng)手段,包括:
1. 裁剪(需改變bbox)
2. 平移(需改變bbox)
3. 改變亮度
4. 加噪聲
5. 旋轉(zhuǎn)角度(需要改變bbox)
6. 鏡像(需要改變bbox)
7. cutout
整個(gè)學(xué)習(xí)過程中,存在兩個(gè)問題:
輸出GT box的[中心點(diǎn)x,中心點(diǎn)y,寬w,高h(yuǎn),cls_num],其中坐標(biāo)點(diǎn)位置以及box寬和高是歸一化的嗎?(0~1)
回答:看網(wǎng)絡(luò),YOLO需要?dú)w一化,SSD不需要?dú)w一化,原因是:網(wǎng)絡(luò)中使用的定位損失函數(shù)有區(qū)別。在網(wǎng)絡(luò)訓(xùn)練過程中,所謂的圖像縮放、扭曲、翻轉(zhuǎn),色域變換等數(shù)據(jù)增強(qiáng)技術(shù),都是在輸入圖像上變換嗎?有沒有增加訓(xùn)練數(shù)據(jù)量?
回答:數(shù)據(jù)增強(qiáng)不是數(shù)據(jù)擴(kuò)充。每一個(gè)epoch取出原數(shù)據(jù)后,樣本有一定概率使用數(shù)據(jù)增強(qiáng)技術(shù),這樣導(dǎo)致每一次訓(xùn)練的圖片其實(shí)有一些區(qū)別,并不完全相同。總結(jié),確實(shí)是在輸入圖像上變換的,沒有增加訓(xùn)練數(shù)據(jù)量。
直接看代碼:
import cv2
import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
#---------------------------------------------------------#
# 將圖像轉(zhuǎn)換成RGB圖像,防止灰度圖在預(yù)測(cè)時(shí)報(bào)錯(cuò)。
# 代碼僅僅支持RGB圖像的預(yù)測(cè),所有其它類型的圖像都會(huì)轉(zhuǎn)化成RGB
# .convert('RGB')的使用與理解,可見http://www.lxweimin.com/p/5b53af742ad5
#---------------------------------------------------------#
def cvtColor(image):
if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
return image
else:
image = image.convert('RGB')
return image
def preprocess_input(image):
image /= 255.0
return image
class YoloDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, train):
super(YoloDataset, self).__init__()
# annotation_lines[index]:圖片路徑 目標(biāo)1的xmin,ymin,xmax,ymax,class_num 目標(biāo)2的xmin,ymin,xmax,ymax,class_num ...
self.annotation_lines = annotation_lines
self.input_shape = input_shape # [416, 416] 【高,寬】
self.num_classes = num_classes # 20
self.length = len(self.annotation_lines) # self.annotation_lines是個(gè)list
self.train = train # self.train是bool型,用來確定是否進(jìn)行數(shù)據(jù)增強(qiáng),train時(shí)增強(qiáng),val時(shí)不增強(qiáng)
def __len__(self):
return self.length
def __getitem__(self, index):
index = index % self.length # 這一步保證index不超過length,不然self.annotation_lines[index]取不到值
# ---------------------------------------------------#
# 訓(xùn)練時(shí)進(jìn)行數(shù)據(jù)的隨機(jī)增強(qiáng)
# 驗(yàn)證時(shí)不進(jìn)行數(shù)據(jù)的隨機(jī)增強(qiáng)
# ---------------------------------------------------#
image, box = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random=self.train)
# ---------------------------------------------#
# 把圖片數(shù)據(jù)image轉(zhuǎn)成CHW格式,float32類型數(shù)據(jù),并歸一化
# ---------------------------------------------#
image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
box = np.array(box, dtype=np.float32)
if len(box) != 0:
# 左上點(diǎn)和右下點(diǎn)坐標(biāo)x 歸一化?
box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
# 左上點(diǎn)和右下點(diǎn)坐標(biāo)y 歸一化?
box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]
# box位置信息從[xmin,ymin,xmax,ymax,cls_num]到[xmin,ymin,寬w,高h(yuǎn),cls_num]
box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
# box位置信息從[xmin,ymin,寬w,高h(yuǎn),cls_num]到[中心點(diǎn)x,中心點(diǎn)y,寬w,高h(yuǎn),cls_num]
box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
return image, box
# 下面get_random_data函數(shù)中要用到這個(gè)函數(shù)
def rand(self, a=0, b=1):
# np.random.rand()返回一個(gè)或一組服從“0~1”均勻分布的隨機(jī)樣本值。
# 隨機(jī)樣本取值范圍是[0,1),不包括1
return np.random.rand() * (b - a) + a
def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
# ------------------------------#
# annotation_line是字符串,路徑、各標(biāo)簽信息之間 空格 隔開
# 進(jìn)過split(),line是list,每個(gè)元素是str
# ------------------------------#
line = annotation_line.split()
# ------------------------------#
# 讀取圖像并轉(zhuǎn)換成RGB圖像
# line[0]是路徑
# ------------------------------#
image = Image.open(line[0])
image = cvtColor(image)
# ------------------------------#
# 獲得圖像的高寬與目標(biāo)高寬
# ------------------------------#
iw, ih = image.size # 原圖的寬和高,Image讀取圖片,img.size返回圖片寬和高,詳見http://www.lxweimin.com/p/5b53af742ad5
h, w = input_shape # input_shape:[416, 416]
# ------------------------------#
# 獲得預(yù)測(cè)框
# 二維數(shù)組,里面每一維,一個(gè)bbox的標(biāo)簽
# 內(nèi)部操作:str->int 一個(gè)bbox的標(biāo)簽成list,再np轉(zhuǎn),再套個(gè)列表,再轉(zhuǎn)
# ------------------------------#
box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])
# ----------------------------------#
# 不進(jìn)行數(shù)據(jù)增強(qiáng),也就是測(cè)試的時(shí)候
# random為False
# ----------------------------------#
if not random:
# -------------------------------------------#
# 獲取縮放參數(shù)
# 可參考http://www.lxweimin.com/p/2ae3a497f5f4
# -------------------------------------------#
scale = min(w / iw, h / ih)
nw = int(iw * scale)
nh = int(ih * scale)
dx = (w - nw) // 2
dy = (h - nh) // 2
# ---------------------------------#
# 原image等比例縮放后,新建一個(gè)期待大小的灰度圖,如416x416,
# 把縮放后的image,貼在灰圖上,從(dx,dy)那兒貼,也就是左上頂點(diǎn)對(duì)齊(dx,dy)
# 就像給圖像加灰條的感覺
# ---------------------------------#
image = image.resize((nw, nh), Image.BICUBIC)
new_image = Image.new('RGB', (w, h), (128, 128, 128))
new_image.paste(image, (dx, dy))
image_data = np.array(new_image, np.float32)
# ---------------------------------#
# 對(duì)真實(shí)框進(jìn)行調(diào)整
# ---------------------------------#
if len(box) > 0:
np.random.shuffle(box) # 用來打亂真實(shí)框的順序
# -----------------------------------------------#
# box是二維數(shù)組,里面一個(gè)元素:[xmin,ymin,xmax,ymax,class_num]
# 若 b = array([[1, 2, 3], [4, 5, 6]])
# 則 b[:,[0,2]]: array([[1, 3], [4, 6]])
# b[:,0:2]: array([[1, 2], [4, 5]])
# b[:,0:2]<0: array([[False, False], [False, False]])
# b[:,0:2][b[:,0:2]<2]=0,則b=array([[0, 2, 3], [4, 5, 6]])
# b[:,1]-b[:,0]:array([2, 1]),array對(duì)應(yīng)位置相減,得到一個(gè)array
# b[np.array([True, False])]:array([[0, 2, 3]])
# -----------------------------------------------#
# 對(duì)標(biāo)簽的xmin和xmax進(jìn)行變換,到resize后圖片里的位置
box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
# 對(duì)標(biāo)簽的ymin和ymax進(jìn)行變換,到resize后圖片里的位置
box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
# 出界了就整到邊界上去
# xmin和ymin小于0,就置為0
box[:, 0:2][box[:, 0:2] < 0] = 0
# xmax和ymax大于w和h,就置為w和h
box[:, 2][box[:, 2] > w] = w
box[:, 3][box[:, 3] > h] = h
box_w = box[:, 2] - box[:, 0] # 得到框的寬
box_h = box[:, 3] - box[:, 1] # 得到框的高
# -------------------------------------------------#
# np.logical_and邏輯與,都是True,才為True。寬個(gè)高不大于1像素,就舍棄
# np.logical_and(box_w > 1, box_h > 1)得到一個(gè)array,
# 類似于array([False, False], dtype=bool)
# 初始:box[[GT框1信息], [GT框2信息], [GT框3信息]]
# 經(jīng)過:box[np.array([True, False, True])]
# 結(jié)果:box[[GT框1信息], [GT框3信息]]
# -------------------------------------------------#
box = box[np.logical_and(box_w > 1, box_h > 1)] # discard invalid box
return image_data, box # np.array的圖片數(shù)據(jù)、有效的np.array的標(biāo)簽數(shù)據(jù)
# ------------------------------------------#
# 下面都是 數(shù)據(jù)增強(qiáng)技術(shù)
# 所謂的圖像縮放、扭曲、翻轉(zhuǎn),色域變換等,都是在輸入圖像上變換嗎?有沒有增加訓(xùn)練數(shù)據(jù)量?
# ------------------------------------------#
# 對(duì)圖像進(jìn)行縮放并且進(jìn)行長(zhǎng)和寬的扭曲
# ------------------------------------------#
new_ar = w / h * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
scale = self.rand(.25, 2)
if new_ar < 1:
nh = int(scale * h)
nw = int(nh * new_ar)
else:
nw = int(scale * w)
nh = int(nw / new_ar)
image = image.resize((nw, nh), Image.BICUBIC)
# ------------------------------------------#
# 將圖像多余的部分加上灰條
# ------------------------------------------#
dx = int(self.rand(0, w - nw))
dy = int(self.rand(0, h - nh))
new_image = Image.new('RGB', (w, h), (128, 128, 128))
new_image.paste(image, (dx, dy))
image = new_image
# ------------------------------------------#
# 翻轉(zhuǎn)圖像
# ------------------------------------------#
flip = self.rand() < .5
if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
image_data = np.array(image, np.uint8)
#---------------------------------#
# 對(duì)圖像進(jìn)行色域變換
# 計(jì)算色域變換的參數(shù)
#---------------------------------#
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
#---------------------------------#
# 將圖像轉(zhuǎn)到HSV上
#---------------------------------#
hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
dtype = image_data.dtype
#---------------------------------#
# 應(yīng)用變換
#---------------------------------#
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
#---------------------------------#
# 對(duì)真實(shí)框進(jìn)行調(diào)整
#---------------------------------#
if len(box)>0:
np.random.shuffle(box)
box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
if flip: box[:, [0,2]] = w - box[:, [2,0]]
box[:, 0:2][box[:, 0:2]<0] = 0
box[:, 2][box[:, 2]>w] = w
box[:, 3][box[:, 3]>h] = h
box_w = box[:, 2] - box[:, 0]
box_h = box[:, 3] - box[:, 1]
box = box[np.logical_and(box_w>1, box_h>1)]
return image_data, box
# DataLoader中collate_fn使用
def yolo_dataset_collate(batch):
images = []
bboxes = []
for img, box in batch:
images.append(img)
bboxes.append(box)
images = np.array(images)
return images, bboxes
if __name__ == '__main__':
# ------------------------------------------------------#
# 數(shù)據(jù)集中類別個(gè)數(shù),以voc為例,20類
# ------------------------------------------------------#
num_classes = 20
# ------------------------------------------------------#
# 輸入的shape大小,一定要是32的倍數(shù)
# ------------------------------------------------------#
input_shape = [416, 416]
num_workers = 0
batch_size = 64
# ----------------------------------------------------#
# 獲得圖片路徑和標(biāo)簽
# 圖片路徑 目標(biāo)1的xmin,ymin,xmax,ymax,class_num 目標(biāo)2的xmin,ymin,xmax,ymax,class_num ...
# D:\VOCdevkit/VOC2007/JPEGImages/000005.jpg 263,211,324,339,8 165,264,253,372,8 241,194,295,299,8
# D:\VOCdevkit/VOC2007/JPEGImages/000007.jpg 141,50,500,330,6
# 2007_train.txt和2007_val.txt怎么得到的,之后再聊
# ----------------------------------------------------#
train_annotation_path = '2007_train.txt'
val_annotation_path = '2007_val.txt'
# ------------------------------------------------------------------#
# 讀取數(shù)據(jù)集對(duì)應(yīng)的txt
# train_lines是一個(gè)list,里面每個(gè)元素是一個(gè)str,每個(gè)str內(nèi)有圖片路徑和標(biāo)簽信息,以 空格 分開
# 每個(gè)元素的最后是 換行符\n
# ------------------------------------------------------------------#
with open(train_annotation_path) as f:
train_lines = f.readlines()
with open(val_annotation_path) as f:
val_lines = f.readlines()
train_dataset = YoloDataset(train_lines, input_shape, num_classes, train=True)
val_dataset = YoloDataset(val_lines, input_shape, num_classes, train=False)
# gen就是常規(guī)的train_loader
gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
# gen_val就是常規(guī)的val_loader
gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
drop_last=True, collate_fn=yolo_dataset_collate)
for iteration, batch in enumerate(gen):
images, targets = batch[0], batch[1]
調(diào)試時(shí)train_dataset和gen的結(jié)果: