您现在的位置是:首页 >技术教程 >使用 PyTorch 构建神经网络(二):数据集、数据加载器和批大小网站首页技术教程

使用 PyTorch 构建神经网络(二):数据集、数据加载器和批大小

Lemon_wxk 2025-04-11 00:01:02
简介使用 PyTorch 构建神经网络(二):数据集、数据加载器和批大小

在上一节使用 PyTorch 构建一个简单的神经网络中,我们简单介绍了如何使用 PyTorch 训练一个简单的神经网络。这一节我们将继续往下讲,主要介绍数据集(Dataset)、数据加载器(DataLoader)和批大小(Batch Size)。

一、导入用于加载数据和处理数据集的方法

在机器学习和深度学习中,DatasetDataLoader 是 PyTorch 中两个非常重要的概念,它们共同用于管理和加载数据。

1. Dataset

Dataset 是 PyTorch 中的一个抽象类,用于表示数据集。它定义了数据集的基本结构和行为,通常需要用户继承并实现特定的方法。Dataset 类的主要作用是提供一种统一的方式来访问和操作数据。

主要方法:
  • __init__:初始化方法,用于加载数据集。

  • __len__:返回数据集的大小。

  • __getitem__:按索引获取单个数据样本。

示例代码:

Python复制

import torch
from torch.utils.data import Dataset
import torch.nn as nn

# 定义一个简单的 Dataset 类
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# 创建数据集实例
data = torch.randn(100, 10)  # 100 个样本,每个样本 10 个特征
labels = torch.randint(0, 2, (100,))  # 100 个标签,每个标签是 0 或 1
dataset = MyDataset(data, labels)

2. DataLoader

DataLoader 是 PyTorch 中的一个类,用于从 Dataset 中加载数据。它提供了一种便捷的方式来迭代数据集,支持批量加载、数据打乱、多线程加载等功能。

主要参数:
  • dataset:数据集实例。

  • batch_size:每个批次的样本数量。

  • shuffle:是否在每个 epoch 开始时打乱数据,默认为 False

  • num_workers:加载数据时使用的子线程数量,默认为 0(主进程加载数据)。

示例代码:

Python复制

from torch.utils.data import DataLoader

# 创建 DataLoader 实例
batch_size = 10
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 使用 DataLoader 迭代数据
for batch_data, batch_labels in data_loader:
    print(f"Batch Data Shape: {batch_data.shape}")  # 批量数据的形状
    print(f"Batch Labels Shape: {batch_labels.shape}")  # 批量标签的形状

二、完整示例

1. 准备数据

Python复制

x = [[1, 2], [3, 4], [5, 6], [7, 8]]
y = [[3], [7], [11], [15]]

X = torch.tensor(x).float()
Y = torch.tensor(y).float()

device = 'cuda' if torch.cuda.is_available() else 'cpu'
X = X.to(device)
Y = Y.to(device)

2. 定义 Dataset 和 DataLoader

Python复制

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = torch.tensor(x).float()
        self.y = torch.tensor(y).float()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, ix):
        return self.x[ix], self.y[ix]

ds = MyDataset(X, Y)
dl = DataLoader(ds, batch_size=2, shuffle=True)

3. 定义神经网络

Python复制

class MyNeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_to_hidden_layer = nn.Linear(2, 8)
        self.hidden_layer_activation = nn.ReLU()
        self.hidden_to_output_layer = nn.Linear(8, 1)

    def forward(self, x):
        x = self.input_to_hidden_layer(x)
        x = self.hidden_layer_activation(x)
        x = self.hidden_to_output_layer(x)
        return x

4. 训练模型

Python复制

mynet = MyNeuralNet().to(device)
loss_func = nn.MSELoss()
from torch.optim import SGD
opt = SGD(mynet.parameters(), lr=0.001)

import time
loss_history = []
start = time.time()
for _ in range(50):
    for data in dl:
        x, y = data
        opt.zero_grad()
        loss_value = loss_func(mynet(x), y)
        loss_value.backward()
        opt.step()
        loss_history.append(loss_value)
end = time.time()
print(end - start)

5. 测试模型

Python复制

val_x = [[10, 11]]
val_x = torch.tensor(val_x).float().to(device)
print(mynet(val_x))

三、总结

在这一节中,我们介绍了 PyTorch 中的 DatasetDataLoader,以及如何使用它们来管理和加载数据。我们还展示了如何使用 DataLoader 进行批量训练,并测量了训练时间。通过这些方法,我们可以更高效地处理和加载数据,为模型训练提供支持。下一节我们将介绍使用一个简单的数据集对模型进行训练。使用 PyTorch 构建一个简单的神经网络(三)用 CIFAR-10 数据集训练你的第一个神经网络

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。