Lesson 5: Data Loading (Part 2)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from Covid_DataSet import COVID19Dataset
root_dir = "./data/cov_19_demo"  # root directory of the data
img_dir = os.path.join(root_dir, "imgs")
path_txt_train = os.path.join(root_dir, "labels", "train.txt")
# Build the dataset
transforms_train = transforms.Compose([
    transforms.Resize((4, 4)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_data = COVID19Dataset(root_dir=img_dir, txt_path=path_txt_train, transform=transforms_train)
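A quick sanity check that the dataset was built correctly (a minimal sketch; COVID19Dataset itself was written in the previous lesson):
print(len(train_data))   # number of training samples
img, label = train_data[0]
print(img.shape, label)  # tensor shape after the transforms, plus the integer label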
This lesson focuses on the Sampler mechanism.
Everything discussed so far happened inside the Dataset; the Sampler mechanism, by contrast, operates inside the DataLoader.
In the previous lesson we saw that DataLoader exposes a sampler argument.
The Sampler controls how the DataLoader draws sample indices from the dataset.
Why do we need a Sampler?
By default, a DataLoader loads samples from the dataset in order. Some scenarios, however, call for a specific sampling strategy, for example:
- Class imbalance: some classes have too few samples, so the model may fail to learn them adequately; minority classes need to be sampled with higher weight.
- Randomness: the order in which data is loaded can affect training (e.g., the effect of shuffling).
- Custom needs: sampling certain specific samples more heavily.
The most common Samplers are the following:
1. Drawing samples in order
Data is loaded sequentially starting from index 0 (suitable for validation and test sets).
from torch.utils.data import DataLoader, SequentialSampler
sampler = SequentialSampler(train_data)
dataloader = DataLoader(train_data, batch_size=4, sampler=sampler)
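Note that sampler and shuffle are mutually exclusive: passing both a sampler and shuffle=True to DataLoader raises a ValueError. With neither argument, DataLoader falls back to a SequentialSampler internally, so the line below behaves the same as the example above:
dataloader = DataLoader(train_data, batch_size=4)  # sequential by default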
2. Drawing samples at random
Shuffles the loading order; suitable for training sets.
from torch.utils.data import DataLoader, RandomSampler
sampler = RandomSampler(train_data)
dataloader = DataLoader(train_data, batch_size=4, sampler=sampler)
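shuffle=True is shorthand for the same thing: the DataLoader constructs a RandomSampler internally, so the following line is equivalent:
dataloader = DataLoader(train_data, batch_size=4, shuffle=True)  # uses a RandomSampler internally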
3. Randomly drawing a subset of the dataset
Suitable when only part of the dataset should be loaded.
from torch.utils.data import DataLoader, SubsetRandomSampler
indices = [0, 1, 2, 3, 4]  # indices of the samples to draw from
sampler = SubsetRandomSampler(indices)
dataloader = DataLoader(train_data, batch_size=2, sampler=sampler)
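A common use of SubsetRandomSampler is splitting a single dataset into training and validation subsets without copying any data. A minimal sketch, assuming an 80/20 split of train_data:
n = len(train_data)
perm = torch.randperm(n).tolist()  # shuffled indices 0..n-1
split = int(0.8 * n)               # 80% train, 20% validation
train_sampler = SubsetRandomSampler(perm[:split])
val_sampler = SubsetRandomSampler(perm[split:])
train_loader = DataLoader(train_data, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(train_data, batch_size=2, sampler=val_sampler)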
4. Sampling according to per-sample weights
Commonly used to handle class imbalance. Each sample is assigned a weight; the higher the weight, the more likely the sample is to be drawn. Parameters:
- weights: a list or tensor of weights, one per sample (its length should match the number of samples in the dataset).
- num_samples: the number of samples to draw.
- replacement: whether sampling is with replacement (True means a sample can be drawn repeatedly).
from torch.utils.data import DataLoader, WeightedRandomSampler
# One weight per sample
weights = [0.1, 0.2, 0.3, 0.4]  # weight list (assumes the dataset has 4 samples)
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)  # draw 10 samples
dataloader = DataLoader(train_data, batch_size=2, sampler=sampler)
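The sampler itself is iterable, which makes it easy to see the effect of the weights before wiring it into a DataLoader (a sketch; the exact indices vary from run to run):
print(list(WeightedRandomSampler(weights, num_samples=10, replacement=True)))
# e.g. [3, 1, 3, 2, 3, 0, 3, 3, 1, 2]: index 3 (weight 0.4) appears most often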
5. Custom Sampler
Besides the built-in Samplers, you can define your own sampling logic by subclassing torch.utils.data.Sampler.
from torch.utils.data import Sampler
import random
class CustomSampler(Sampler):
    def __init__(self, data_source, indices=None):
        self.data_source = data_source
        # Default to all indices if none are given
        self.indices = indices if indices is not None else list(range(len(data_source)))

    def __iter__(self):
        # Return an iterator over indices; here we shuffle them randomly
        return iter(random.sample(self.indices, len(self.indices)))

    def __len__(self):
        return len(self.indices)
# Use the custom sampler
custom_sampler = CustomSampler(train_data)
dataloader = DataLoader(train_data, batch_size=4, sampler=custom_sampler)
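The optional indices argument of CustomSampler restricts sampling to a chosen subset, e.g. shuffling over only the first three samples (hypothetical usage):
subset_sampler = CustomSampler(train_data, indices=[0, 1, 2])
dataloader = DataLoader(train_data, batch_size=2, sampler=subset_sampler)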
6. Using WeightedRandomSampler to address class imbalance
Suppose the classes are distributed as follows:
- Class 0: 100 samples.
- Class 1: 10 samples.
WeightedRandomSampler can be used to balance the sampling.
In weighted random sampling, the probability of selecting a sample is proportional to its weight.
This means that when some classes have few samples, we can raise their sampling probability to balance the training data.
from torch.utils.data import WeightedRandomSampler
# Per-class sampling weights
weights = torch.tensor([1, 5], dtype=torch.float)
# A class-0 (no-finding) sample has relative weight 1; a class-1 (covid-19) sample is 5 times as likely to be drawn.
# Generate a sampling weight for each individual sample
train_targets = [sample[1] for sample in train_data.img_info]  # extract each sample's label
samples_weights = weights[train_targets]  # look up each sample's weight by its label
# Instantiate the WeightedRandomSampler
sampler_w = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),  # number of draws per epoch: here, the dataset size
    replacement=True)                  # drawn samples go back into the pool, so duplicates are allowed
# Configure the dataloader
train_loader = DataLoader(dataset=train_data, batch_size=2, sampler=sampler_w)
# Simulate the training loop
for epoch in range(10):  # 10 epochs
    for i, (inputs, target) in enumerate(train_loader):
        print(target.shape, target)
# Since sampling is with replacement and class 1 is 5 times as likely to be drawn as class 0, batches of [1, 1] appear often:
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([0, 1])
torch.Size([2]) tensor([0, 1])
torch.Size([2]) tensor([0, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
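Note that num_samples=len(samples_weights) keeps each epoch the same length as the dataset, but because sampling is with replacement, some samples are drawn several times per epoch while others may not be drawn at all.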
7. A full example: fixing class imbalance
The images are organized as follows:
├── 0
│   └── 10 images
├── 1
│   └── 20 images
├── 2
│   └── 30 images
├── 3
│   └── 40 images
├── 4
│   └── 50 images
├── 5
│   └── 60 images
├── 6
│   └── 70 images
├── 7
│   └── 80 images
├── 8
│   └── 90 images
└── 9
    └── 100 images
import os
import collections
import torch
import random
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from torchvision.transforms import transforms
class CifarDataset(Dataset):
    names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    cls_num = len(names)

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_info = []  # list of (image path, label) pairs
        self._get_img_info()

    def __getitem__(self, index):
        path_img, label = self.img_info[index]
        img = Image.open(path_img).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        if len(self.img_info) == 0:
            raise Exception("\ndata_dir: {} is an empty dir! Please check your path to the images!".format(
                self.root_dir))  # a friendly error message makes debugging easier
        return len(self.img_info)

    def _get_img_info(self):
        for root, dirs, _ in os.walk(self.root_dir):
            # iterate over the class sub-directories
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.png'), img_names))
                # iterate over the images in this class
                for img_name in img_names:
                    path_img = os.path.abspath(os.path.join(root, sub_dir, img_name))
                    label = int(sub_dir)
                    self.img_info.append((path_img, label))
        random.shuffle(self.img_info)  # shuffle the sample order
root_dir = r"./data/cifar-unbalance"
transforms_train = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
])
train_data = CifarDataset(root_dir=root_dir, transform=transforms_train)
# Step 1: compute a sampling weight for each class
# Count the number of samples in each class
train_targets = [sample[1] for sample in train_data.img_info]
label_counter = collections.Counter(train_targets)
class_sample_counts = [label_counter[k] for k in sorted(label_counter)]  # careful: this list must be ordered by class label!
# The weight of each class is simply the reciprocal of its sample count
weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)
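# e.g. with class counts [10, 20, ..., 100] the weights become
# [0.100, 0.050, ..., 0.010]: the rarer the class, the larger its weight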
# Step 2: generate a sampling weight for each sample
samples_weights = weights[train_targets]
# Step 3: instantiate the WeightedRandomSampler
sampler_w = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True)
# Configure the dataloaders: one with the weighted sampler, one without
train_loader_sampler = DataLoader(dataset=train_data, batch_size=16, sampler=sampler_w)
train_loader = DataLoader(dataset=train_data, batch_size=16)
def show_sample(loader):
    for epoch in range(10):
        label_count = []
        for i, (inputs, target) in enumerate(loader):
            label_count.extend(target.tolist())
        print(collections.Counter(label_count))
show_sample(train_loader)
print("\n接下来运用sampler\n")
show_sample(train_loader_sampler)
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Now using the sampler
Counter({3: 60, 2: 60, 8: 59, 7: 58, 9: 56, 4: 56, 5: 52, 0: 52, 6: 49, 1: 48})
Counter({0: 69, 5: 62, 7: 57, 9: 55, 3: 53, 6: 52, 2: 51, 8: 51, 4: 51, 1: 49})
Counter({9: 59, 5: 58, 6: 58, 2: 57, 8: 56, 7: 55, 1: 55, 4: 54, 0: 50, 3: 48})
Counter({0: 69, 4: 62, 2: 61, 3: 58, 7: 57, 1: 54, 8: 52, 5: 49, 9: 44, 6: 44})
Counter({6: 66, 3: 65, 1: 63, 2: 61, 7: 57, 9: 51, 5: 48, 4: 48, 0: 47, 8: 44})
Counter({6: 70, 7: 63, 8: 60, 2: 59, 3: 54, 1: 53, 5: 51, 9: 49, 4: 46, 0: 45})
Counter({8: 69, 9: 69, 5: 58, 6: 56, 7: 55, 0: 53, 1: 50, 2: 48, 3: 46, 4: 46})
Counter({9: 64, 6: 63, 4: 59, 3: 57, 1: 55, 2: 54, 7: 53, 5: 52, 0: 51, 8: 42})
Counter({4: 75, 8: 67, 1: 57, 0: 55, 6: 54, 2: 52, 3: 52, 9: 51, 5: 49, 7: 38})
Counter({2: 72, 4: 65, 0: 64, 6: 57, 9: 53, 1: 52, 8: 50, 3: 48, 7: 45, 5: 44})
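With the plain loader, every epoch reproduces the imbalanced counts exactly (10 to 100 per class). With the weighted sampler, each class is drawn roughly 55 times per epoch (550 samples over 10 classes), so the class distribution is now approximately uniform, at the cost of repeating minority-class samples.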