pytorch加载自己的图像数据集实例

1. Pytorch简介

Pytorch是Facebook开源的深度学习库,基于Torch开发,它主要提供了两个高级核心功能:支持GPU的张量计算和构建基于自动微分的计算图。需要注意的是,Pytorch和Tensorflow的编程风格是不同的。在Tensorflow中,我们首先需要定义计算图,然后在session中进行计算,而在Pytorch中,我们可以通过命令式编程的方式进行计算。

2.自定义数据集的准备

2.1 数据预处理

在使用Pytorch加载自己的图像数据集之前,我们需要进行数据预处理。数据预处理主要包括以下步骤:

读取图像文件

将图像转换成tensor类型

标准化处理

下面是一个实现数据预处理的代码:

from torchvision import transforms

from torch.utils.data import DataLoader, Dataset

import PIL.Image as Image

transform = transforms.Compose([transforms.Resize((224,224)), # 缩放到224x224大小

transforms.ToTensor(), # 转化为tensor类型

transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])]) # 标准化处理

2.2 数据集的准备

在数据预处理完成之后,我们需要将数据集分为训练集和验证集,并将它们存放在不同的文件夹中。例如,训练集可以存放在一个名为train的文件夹中,而验证集可以存放在一个名为val的文件夹中。文件夹中应该包含每个类别的图片,每个类别应该存放在一个单独的文件夹中。下面是示例代码:

import os

import shutil

from tqdm import tqdm

def create_dataset():

# 数据包含三个类别,分别是Cat,Dog和Panda

root = 'data'

classes = ('Cat', 'Dog', 'Panda')

if not os.path.exists(root):

os.mkdir(root)

# 复制训练图片到data文件夹中

for label in classes:

label_dir = os.path.join(root, label)

if not os.path.exists(label_dir):

os.mkdir(label_dir)

src_dir = os.path.join('train', label)

for file in tqdm(os.listdir(src_dir)):

src_file = os.path.join(src_dir, file)

dst_file = os.path.join(label_dir, file)

shutil.copy(src_file, dst_file)

3.自定义数据集的加载

3.1 数据集类的构建

在Pytorch中,我们可以通过继承Dataset类来构建自定义数据集。我们需要在子类的__init__方法中初始化数据,并实现__getitem__和__len__两个方法。下面是一个简单的数据集类:

class MyDataset(Dataset):

def __init__(self, root, transform=None):

self.root = root

self.transform = transform

self.classes, self.class_to_idx = self._find_classes()

self.samples = self._make_dataset()

def __getitem__(self, index):

path, label = self.samples[index]

img = Image.open(path).convert('RGB')

if self.transform is not None:

img = self.transform(img)

return img, label

def __len__(self):

return len(self.samples)

def _find_classes(self):

classes = [d.name for d in os.scandir(self.root) if d.is_dir()]

classes.sort()

class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}

return classes, class_to_idx

def _make_dataset(self):

samples = []

for target in self.class_to_idx.keys():

target_dir = os.path.join(self.root, target)

for root, _, fnames in sorted(os.walk(target_dir)):

for fname in sorted(fnames):

path = os.path.join(root, fname)

item = (path, self.class_to_idx[target])

samples.append(item)

return samples

3.2 数据集的加载

当我们完成自定义的数据集类之后,我们可以使用Dataloader类加载数据集。下面是一个简单的数据集加载的代码:

train_dataset = MyDataset(root='data/train', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = MyDataset(root='data/val', transform=transform)

val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

4.总结

在本文中,我们介绍了如何使用Pytorch加载自定义的图像数据集。我们通过对数据集进行预处理,将数据集分为训练集和验证集,并实现了一个自定义的数据集类。最后,我们使用Dataloader类加载数据集。通过这些步骤,我们可以快速地创建一个适合自己的数据集,方便我们进行深度学习训练。

免责声明:本文来自互联网,本站所有信息(包括但不限于文字、视频、音频、数据及图表),不保证该信息的准确性、真实性、完整性、有效性、及时性、原创性等,版权归属于原作者,如无意侵犯媒体或个人知识产权,请来电或致函告之,本站将在第一时间处理。猿码集站发布此文目的在于促进信息交流,此文观点与本站立场无关,不承担任何责任。

后端开发标签