第 3 课:数据加载(上)
整个数据处理过程有两个核心——Dataset, DataLoader。
Dataset用于数据读取,是一个抽象基类,提供给用户定义自己的数据读取方式DataLoader用于数据加载,有打乱数据,均衡1:1采样,多进程数据加载,组装成Batch形式等功能。
本章将围绕着它们两个展开介绍pytorch的数据读取、预处理、加载等功能。
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
整个数据加载的流程如下,我们会逐块讲解
+-------------+ +--------------+ +-------------+ +---------+
| Dataset | ---> | Data Loader | ---> | Batch Data | ---> | Model |
+-------------+ +--------------+ +-------------+ +---------+
(一)Dataset
torch.utils.data.Dataset 是数据加载和预处理的核心模块,用于定义和管理数据集。
通过 Dataset,我们可以方便地加载自定义数据集或使用现有的标准数据集。`
Dataset 是一个抽象类,表示数据集的接口。
你需要创建自己的 数据集类 并继承自 torch.utils.data.Dataset,然后实现以下两个核心函数:
__len__():返回数据集的大小。__getitem__(index):通过索引返回数据样本和标签,并进行预处理(包括online的数据增强)
此后 Dataloader 会调用Dataset的getitem函数, 由 返回值 组合成一个样本(batch)。