Skip to main content

第 5 课:数据加载(下)

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from Covid_DataSet import COVID19Dataset
        
root_dir = "./data/cov_19_demo"  # 数据的根目录
img_dir = os.path.join(root_dir, "imgs")
path_txt_train = os.path.join(root_dir, "labels", "train.txt")

# 设置 dataset
transforms_train = transforms.Compose([
    transforms.Resize((4, 4)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
train_data = COVID19Dataset(root_dir=img_dir, txt_path=path_txt_train, transform=transforms_train)

本章主要介绍 Sampler机制

上面的讲的都是在 Dataset 过程中的,而这里的 Sampler 机制则是在 Dataloader 过程中的
在上一节中我们看到 Dataloader 中有sampler 这个API

Sampler 控制了 DataLoader 如何从数据集中抽取数据的方式

为什么需要 Sampler?

DataLoader 默认会按顺序加载数据集中的样本。但有些场景需要特定的采样方式,比如:

  • 类别不平衡:某些类别样本太少,模型训练时可能无法充分学习这些类别,需要对小样本类别进行加权采样。
  • 随机性要求:数据的加载顺序可能影响训练的效果(如数据打乱的效果)。
  • 自定义需求:对某些特定样本进行重点采样。

常见的 Sampler 有以下几种:

1. 按顺序抽取数据样本

数据从索引 0 开始依次加载。(适用于验证集或测试集)

from torch.utils.data import DataLoader, SequentialSampler

sampler = SequentialSampler(train_data)
dataloader = DataLoader(train_data, batch_size=4, sampler=sampler)

2. 随机抽取数据样本。

打乱数据的加载顺序,适用于训练集。

from torch.utils.data import DataLoader, RandomSampler

sampler = RandomSampler(train_data)
dataloader = DataLoader(train_data, batch_size=4, sampler=sampler)

3. 从数据集中随机抽取子集

适用于在数据集中只加载部分样本的场景。

from torch.utils.data import DataLoader, SubsetRandomSampler

indices = [0, 1, 2, 3, 4]  # 指定抽样的索引
sampler = SubsetRandomSampler(indices)
dataloader = DataLoader(train_data, batch_size=2, sampler=sampler)

4. 基于样本的权重进行抽样

常用于处理类别不平衡的问题。为每个样本分配一个权重,权重越高,样本被抽到的概率越大。

  • weights:权重列表或张量,长度与数据集的样本数相同。
  • num_samples:抽取的样本数量。
  • replacement:是否允许重复采样(True 表示有放回采样)。
from torch.utils.data import DataLoader, WeightedRandomSampler

# 每个样本的权重
weights = [0.1, 0.2, 0.3, 0.4]  # 权重列表
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)  # 抽取 10 个样本
dataloader = DataLoader(train_data, batch_size=2, sampler=sampler)

5. 自定义 Sampler

除了内置的 Sampler,还可以通过继承 torch.utils.data.Sampler 来自定义采样逻辑。

from torch.utils.data import Sampler
import random

class CustomSampler(Sampler):
    def __init__(self, data_source, indices=None):
        self.data_source = data_source
        self.indices = indices if indices is not None else list(range(len(data_source)))

    def __iter__(self):
        # 返回一个索引的迭代器,这里我们随机打乱索引
        return iter(random.sample(self.indices, len(self.indices)))

    def __len__(self):
        return len(self.indices)

# 使用自定义采样器
custom_sampler = CustomSampler(train_data)
dataloader = DataLoader(train_data, batch_size=4, sampler=custom_sampler)

6. 使用 WeightedRandomSampler 解决类别不平衡问题

假设有以下类别分布:

  • 类别 0:100 个样本。
  • 类别 1:10 个样本。

可以通过 WeightedRandomSampler 实现均衡采样。

加权随机采样:选择的概率与权重成正比。
这使当某些类样本较少时,我们可以增加它们的采样概率,以平衡训练数据。

from torch.utils.data import WeightedRandomSampler

# 计算每个类的采样权重
weights = torch.tensor([1, 5], dtype=torch.float)
# 样本 0(no-finding)的采样概率是 1。样本 1(covid-19)的采样概率是类 0 的 5 倍。

# 生成每个样本的采样权重
train_targets = [sample[1] for sample in train_data.img_info]  #提取每个样本的标签
samples_weights = weights[train_targets]  #为每个样本生成对应的采样权重

# 实例化WeightedRandomSampler
sampler_w = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),  # 指定采样的样本数,为数据集中的样本数量
    replacement=True)   #样本在每次采样后会被放回池中,允许重复选择相同的样本。

# 设置 dataloader
train_loader = DataLoader(dataset=train_data, batch_size=2, sampler=sampler_w)


# 模拟了训练过程
for epoch in range(10):  # 10 次迭代
    for i, (inputs, target) in enumerate(train_loader):
        print(target.shape, target)
# 由于是有放回采样,并且样本1的采样概率比0高5倍,可以看到很多次出现[1, 1]



torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([0, 1])
torch.Size([2]) tensor([0, 1])
torch.Size([2]) tensor([0, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])
torch.Size([2]) tensor([1, 1])

7. 实例:解决类别不平衡问题

图片的结构如下

├── 0
│   └── 10张图片
├── 1
│   └── 20张图片
├── 2
│   └── 30张图片
├── 3
│   └── 40张图片
├── 4
│   └── 50张图片
├── 5
│   └── 60张图片
├── 6
│   └── 70张图片
├── 7
│   └── 80张图片
├── 8
│   └── 90张图片
└── 9
│   └── 100张图片
import os
import shutil
import collections
import torch
import random
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from PIL import Image
from torchvision.transforms import transforms
class CifarDataset(Dataset):
    names = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    cls_num = len(names)

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_info = []      # 定义list用于存储样本路径、标签
        self._get_img_info()

    def __getitem__(self, index):
        path_img, label = self.img_info[index]
        img = Image.open(path_img).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        if len(self.img_info) == 0:
            raise Exception("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(
                self.root_dir))   # 代码具有友好的提示功能,便于debug
        return len(self.img_info)

    def _get_img_info(self):
        for root, dirs, _ in os.walk(self.root_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.png'), img_names))
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.abspath(os.path.join(root, sub_dir, img_name))
                    label = int(sub_dir)
                    self.img_info.append((path_img, int(label)))
        random.shuffle(self.img_info)   # 将数据顺序打乱

root_dir = r"./data/cifar-unbalance"


transforms_train = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010])
])
train_data = CifarDataset(root_dir=root_dir, transform=transforms_train)

# 第一步:计算各类别的采样权重
# 计算每个类的样本数量
train_targets = [sample[1] for sample in train_data.img_info]
label_counter = collections.Counter(train_targets)
class_sample_counts = [label_counter[k] for k in sorted(label_counter)]  # 需要特别注意,此list的顺序!
# 计算权重,利用倒数即可
weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)


# 第二步:生成每个样本的采样权重
samples_weights = weights[train_targets]

# 第三步:实例化WeightedRandomSampler
sampler_w = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True)

# 配置dataloader
train_loader_sampler = DataLoader(dataset=train_data, batch_size=16, sampler=sampler_w)
train_loader = DataLoader(dataset=train_data, batch_size=16)

def show_sample(loader):
    for epoch in range(10):
        label_count = []
        for i, (inputs, target) in enumerate(loader):
            label_count.extend(target.tolist())
        print(collections.Counter(label_count))


show_sample(train_loader)
print("\n接下来运用sampler\n")
show_sample(train_loader_sampler)


Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})
Counter({9: 100, 8: 90, 7: 80, 6: 70, 5: 60, 4: 50, 3: 40, 2: 30, 1: 20, 0: 10})

接下来运用sampler

Counter({3: 60, 2: 60, 8: 59, 7: 58, 9: 56, 4: 56, 5: 52, 0: 52, 6: 49, 1: 48})
Counter({0: 69, 5: 62, 7: 57, 9: 55, 3: 53, 6: 52, 2: 51, 8: 51, 4: 51, 1: 49})
Counter({9: 59, 5: 58, 6: 58, 2: 57, 8: 56, 7: 55, 1: 55, 4: 54, 0: 50, 3: 48})
Counter({0: 69, 4: 62, 2: 61, 3: 58, 7: 57, 1: 54, 8: 52, 5: 49, 9: 44, 6: 44})
Counter({6: 66, 3: 65, 1: 63, 2: 61, 7: 57, 9: 51, 5: 48, 4: 48, 0: 47, 8: 44})
Counter({6: 70, 7: 63, 8: 60, 2: 59, 3: 54, 1: 53, 5: 51, 9: 49, 4: 46, 0: 45})
Counter({8: 69, 9: 69, 5: 58, 6: 56, 7: 55, 0: 53, 1: 50, 2: 48, 3: 46, 4: 46})
Counter({9: 64, 6: 63, 4: 59, 3: 57, 1: 55, 2: 54, 7: 53, 5: 52, 0: 51, 8: 42})
Counter({4: 75, 8: 67, 1: 57, 0: 55, 6: 54, 2: 52, 3: 52, 9: 51, 5: 49, 7: 38})
Counter({2: 72, 4: 65, 0: 64, 6: 57, 9: 53, 1: 52, 8: 50, 3: 48, 7: 45, 5: 44})