您现在的位置是:首页 >其他 >【PyTorch】torch.cat() 函数:沿指定维度连接多个张量网站首页其他

【PyTorch】torch.cat() 函数:沿指定维度连接多个张量

彬彬侠 2025-12-19 00:01:02
简介【PyTorch】torch.cat() 函数:沿指定维度连接多个张量

torch.cat 函数

torch.cat 是 PyTorch 中用来沿指定维度连接多个张量的函数。它可以将多个张量拼接在一起,形成一个新的张量。

语法

torch.cat(tensors, dim=0, out=None)

参数说明

  • tensors: 一个张量序列(如列表或元组),其中的每个张量将会在指定的维度上进行拼接。所有张量必须具有相同的形状,除了拼接的维度。
  • dim: 指定拼接的维度,默认为 0。此维度上的大小将是各个输入张量相应维度大小的总和。
  • out: 可选输出张量,指定结果存储的目标张量,默认为 None

返回值

返回一个新的张量,它是按指定维度拼接后的结果。

使用场景

  • 数据扩展:可以将多个小批量数据拼接成一个大批量数据。
  • 模型输入拼接:例如在多模态学习中,可能需要将来自不同来源的特征进行拼接。
  • 生成合并数据:将多个张量合并为一个张量,以便后续的处理或操作。

示例

1. 在第 0 维(行)拼接
import torch

# 创建两个形状为 (2, 3) 的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])

# 在第 0 维上拼接(行拼接)
result = torch.cat((x, y), dim=0)
print("Result of concatenation along dim 0:")
print(result)

输出:

Result of concatenation along dim 0:
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

dim=0 上拼接时,结果张量的行数是两个原始张量行数的和。

2. 在第 1 维(列)拼接
# 在第 1 维上拼接(列拼接)
result = torch.cat((x, y), dim=1)
print("
Result of concatenation along dim 1:")
print(result)

输出:

Result of concatenation along dim 1:
tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]])

dim=1 上拼接时,结果张量的列数是两个原始张量列数的和。

3. 使用更多的张量进行拼接
z = torch.tensor([[13, 14, 15], [16, 17, 18]])

# 在第 0 维上拼接三个张量
result = torch.cat((x, y, z), dim=0)
print("
Result of concatenation of three tensors along dim 0:")
print(result)

输出:

Result of concatenation of three tensors along dim 0:
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [13, 14, 15],
        [16, 17, 18]])
4. 拼接不同维度的张量
# 创建一个形状为 (2, 1) 的张量
a = torch.tensor([[1], [2]])

# 创建一个形状为 (2, 2) 的张量
b = torch.tensor([[3, 4], [5, 6]])

# 在第 1 维上拼接
result = torch.cat((a, b), dim=1)
print("
Result of concatenation along dim 1 with different shapes:")
print(result)

输出:

Result of concatenation along dim 1 with different shapes:
tensor([[1, 3, 4],
        [2, 5, 6]])

在这种情况下,ab 的行数相同,才能在列维度上拼接。

注意事项

  1. 维度匹配:除了拼接的维度外,其他维度必须是相同的。否则,torch.cat 会报错。
  2. 内存效率:拼接操作会创建一个新的张量,而不会修改原有的张量,因此需要足够的内存来存储新的张量。

总结

  • torch.cat 是用于拼接多个张量的操作,支持沿指定维度拼接。
  • 在使用时需要确保除了拼接的维度外,其他维度的大小是一致的。
  • dim 参数指定拼接的维度,常用的是 dim=0(行拼接)和 dim=1(列拼接)。
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。