1 Dataset
2 定义自己的数据集
python
# Define the dataset
class MyDataset(data.Dataset):
    """Map-style dataset of randomly generated features and integer labels.

    Args:
        num_samples: Number of (features, label) pairs to generate.
        num_features: Length of each feature vector.
        num_classes: Labels are drawn from ``[0, num_classes)``.
    """

    def __init__(self, num_samples=100, num_features=10, num_classes=2):
        # Fake data: normally distributed features and random integer labels.
        self.data = torch.randn(num_samples, num_features)
        self.target = torch.randint(0, num_classes, (num_samples,))

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)
python
class CustomImageDataset(Dataset):
    """Image dataset described by a CSV annotations file.

    Each row of the annotations file describes one image: column 0 holds the
    image file name (relative to ``img_dir``) and column 1 holds its label.
    The file is parsed with ``pandas.read_csv`` defaults, so its first line
    is treated as a header row.

    Args:
        annotations_file: Path to the CSV file with file names and labels.
        img_dir: Root directory that contains the image files.
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each label.
    """

    def __init__(self, annotations_file, img_dir, transform: callable = None, target_transform=None):
        # Load the annotations as a DataFrame; __getitem__ relies on `.iloc`,
        # so a plain directory listing cannot be used here.
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir  # root directory of the images
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels)

    # The core of the dataset: map an integer ``idx`` to one sample.
    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for row ``idx`` of the annotations."""
        # Full path of the image file for this row.
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Read the image with the image-loading helper in scope
        # (alternatives mentioned by the author: opencv, pillow).
        image = read_image(img_path)
        # Label that belongs to this image.
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)  # image preprocessing
        if self.target_transform:
            label = self.target_transform(label)  # label preprocessing
        return image, label
3 torch 中的 dataloader
python
from torch.utils.data import DataLoader
# Wrap the dataset so that iterating yields shuffled batches of 32 samples.
train_dataloader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)

for batch_idx, batch in enumerate(train_dataloader):
    # batch_idx: index of the current iteration
    # data: the batch of model inputs
    # target: the matching labels
    data, target = batch
    print(data)
4 torchvision
4.1 torchvision 中的dataset
python
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset
import pandas as pd
import os
import PIL as PIL
from torchvision.io import read_image
def data_download():
    """Fetch the MNIST training split into ``data`` and print the item at index 10.

    Images are converted with ``ToTensor()`` when they are accessed.
    """
    mnist_train = datasets.MNIST(root="data", train=True, download=True, transform=ToTensor())
    # Indexing the dataset returns one sample; show it for inspection.
    print(mnist_train[10])
# from PIL import Image
# image = Image.open('image.jpg') # 打开图像
# image.save('new_image.jpg') # 保存图像
4.2 torchvision 中的 transforms
python
######################################################################
# Transforms
# ----------
#
# One issue we can see from the above is that the samples are not of the
# same size. Most neural networks expect the images of a fixed size.
# Therefore, we will need to write some preprocessing code.
# Let's create three transforms:
#
# - ``Rescale``: to scale the image
# - ``RandomCrop``: to crop from image randomly. This is data
# augmentation.
# - ``ToTensor``: to convert the numpy images to torch images (we need to
# swap axes).
#
# We will write them as callable classes instead of simple functions so
# that parameters of the transform need not be passed every time it's
# called. For this, we just need to implement ``__call__`` method and
# if required, ``__init__`` method. We can then use a transform like this:
#
# ::
#
# tsfm = Transform(params)
# transformed_sample = tsfm(sample)
#
# Observe below how these transforms had to be applied both on the image and
# landmarks.
#
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, tuple):
            # Fail fast (as RandomCrop does): __call__ unpacks exactly
            # (new_h, new_w) from a tuple output_size.
            assert len(output_size) == 2
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the image resized and landmarks scaled.

        Args:
            sample: Dict with an ``'image'`` array (first two axes are height
                and width) and a ``'landmarks'`` array of (x, y) coordinates.
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            # Match the shorter edge to output_size and scale the longer one
            # by the same factor to keep the aspect ratio.
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        # NOTE(review): `transform` is presumably `skimage.transform`; it is
        # not imported in the visible source - confirm before running.
        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}
class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with a random crop and shifted landmarks.

        Args:
            sample: Dict with an ``'image'`` array (first two axes are height
                and width) and a ``'landmarks'`` array of (x, y) coordinates.

        Raises:
            ValueError: If the requested crop is larger than the image.
        """
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size
        if new_h > h or new_w > w:
            raise ValueError(
                f"crop size {(new_h, new_w)} is larger than image size {(h, w)}"
            )

        # randint's upper bound is exclusive, so add 1: this allows a crop the
        # same size as the image and makes the last valid offset reachable.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # Landmarks are (x, y), so subtract the (left, top) crop origin.
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        """Return a new sample dict whose image and landmarks are torch tensors.

        Args:
            sample: Dict with an ``'image'`` ndarray laid out as H x W x C
                (or H x W for a single-channel image) and a ``'landmarks'``
                ndarray.
        """
        image, landmarks = sample['image'], sample['landmarks']

        if image.ndim == 2:
            # Single-channel H x W image: add a trailing channel axis so the
            # axis swap below yields 1 x H x W instead of failing.
            image = image[:, :, None]

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
######################################################################
# .. note::
# In the example above, `RandomCrop` uses an external library's random number generator
# (in this case, Numpy's `np.random.randint`). This can result in unexpected behavior with `DataLoader`
# (see https://pytorch.org/docs/stable/notes/faq.html#my-data-loader-workers-return-identical-random-numbers).
# In practice, it is safer to stick to PyTorch's random number generator, e.g. by using `torch.randint` instead.
######################################################################
# Compose transforms
# ~~~~~~~~~~~~~~~~~~
#
# Now, we apply the transforms on a sample.
#
# Let's say we want to rescale the shorter side of the image to 256 and
# then randomly crop a square of size 224 from it. i.e, we want to compose
# ``Rescale`` and ``RandomCrop`` transforms.
# ``torchvision.transforms.Compose`` is a simple callable class which allows us
# to do this.
#
# Individual transform instances and a two-step composition of them.
scale = Rescale(256)
crop = RandomCrop(128)
composed = transforms.Compose([
    Rescale(256),
    RandomCrop(224),
])

# Dataset whose samples are rescaled, randomly cropped, then converted to tensors.
transformed_dataset = FaceLandmarksDataset(
    csv_file='data/faces/face_landmarks.csv',
    root_dir='data/faces/',
    transform=transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()]),
)

# Print the tensor sizes of the first four samples (fewer if the dataset is smaller).
for i in range(min(4, len(transformed_dataset))):
    sample = transformed_dataset[i]
    print(i, sample['image'].size(), sample['landmarks'].size())