1. Pytorch简介
Pytorch是Facebook开源的深度学习库,基于Torch开发,它主要提供了两个高级核心功能:支持GPU的张量计算和构建基于自动微分的计算图。需要注意的是,Pytorch和Tensorflow的编程风格是不同的。在Tensorflow中,我们首先需要定义计算图,然后在session中进行计算,而在Pytorch中,我们可以通过命令式编程的方式进行计算。
2.自定义数据集的准备
2.1 数据预处理
在使用Pytorch加载自己的图像数据集之前,我们需要进行数据预处理。数据预处理主要包括以下步骤:
读取图像文件
将图像转换成tensor类型
标准化处理
下面是一个实现数据预处理的代码:
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import PIL.Image as Image
transform = transforms.Compose([transforms.Resize((224,224)), # 缩放到224x224大小
transforms.ToTensor(), # 转化为tensor类型
transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])]) # 标准化处理
2.2 数据集的准备
在数据预处理完成之后,我们需要将数据集分为训练集和验证集,并将它们存放在不同的文件夹中。例如,训练集可以存放在一个名为train的文件夹中,而验证集可以存放在一个名为val的文件夹中。文件夹中应该包含每个类别的图片,每个类别应该存放在一个单独的文件夹中。下面是示例代码:
import os
import shutil
from tqdm import tqdm
def create_dataset():
# 数据包含三个类别,分别是Cat,Dog和Panda
root = 'data'
classes = ('Cat', 'Dog', 'Panda')
if not os.path.exists(root):
os.mkdir(root)
# 复制训练图片到data文件夹中
for label in classes:
label_dir = os.path.join(root, label)
if not os.path.exists(label_dir):
os.mkdir(label_dir)
src_dir = os.path.join('train', label)
for file in tqdm(os.listdir(src_dir)):
src_file = os.path.join(src_dir, file)
dst_file = os.path.join(label_dir, file)
shutil.copy(src_file, dst_file)
3.自定义数据集的加载
3.1 数据集类的构建
在Pytorch中,我们可以通过继承Dataset类来构建自定义数据集。我们需要在子类的__init__方法中初始化数据,并实现__getitem__和__len__两个方法。下面是一个简单的数据集类:
class MyDataset(Dataset):
def __init__(self, root, transform=None):
self.root = root
self.transform = transform
self.classes, self.class_to_idx = self._find_classes()
self.samples = self._make_dataset()
def __getitem__(self, index):
path, label = self.samples[index]
img = Image.open(path).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.samples)
def _find_classes(self):
classes = [d.name for d in os.scandir(self.root) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def _make_dataset(self):
samples = []
for target in self.class_to_idx.keys():
target_dir = os.path.join(self.root, target)
for root, _, fnames in sorted(os.walk(target_dir)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
item = (path, self.class_to_idx[target])
samples.append(item)
return samples
3.2 数据集的加载
当我们完成自定义的数据集类之后,我们可以使用Dataloader类加载数据集。下面是一个简单的数据集加载的代码:
train_dataset = MyDataset(root='data/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataset = MyDataset(root='data/val', transform=transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
4.总结
在本文中,我们介绍了如何使用Pytorch加载自定义的图像数据集。我们通过对数据集进行预处理,将数据集分为训练集和验证集,并实现了一个自定义的数据集类。最后,我们使用Dataloader类加载数据集。通过这些步骤,我们可以快速地创建一个适合自己的数据集,方便我们进行深度学习训练。