您现在的位置是:首页 >其他 >【PyTorch】torch.cat() 函数:沿指定维度连接多个张量网站首页其他
【PyTorch】torch.cat() 函数:沿指定维度连接多个张量
简介【PyTorch】torch.cat() 函数:沿指定维度连接多个张量
torch.cat 函数
torch.cat 是 PyTorch 中用来沿指定维度连接多个张量的函数。它可以将多个张量拼接在一起,形成一个新的张量。
语法
torch.cat(tensors, dim=0, out=None)
参数说明
- tensors: 一个张量序列(如列表或元组),其中的每个张量将会在指定的维度上进行拼接。所有张量必须具有相同的形状,除了拼接的维度。
- dim: 指定拼接的维度,默认为 0。此维度上的大小将是各个输入张量相应维度大小的总和。
- out: 可选输出张量,指定结果存储的目标张量,默认为
None。
返回值
返回一个新的张量,它是按指定维度拼接后的结果。
使用场景
- 数据扩展:可以将多个小批量数据拼接成一个大批量数据。
- 模型输入拼接:例如在多模态学习中,可能需要将来自不同来源的特征进行拼接。
- 生成合并数据:将多个张量合并为一个张量,以便后续的处理或操作。
示例
1. 在第 0 维(行)拼接
import torch
# 创建两个形状为 (2, 3) 的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 在第 0 维上拼接(行拼接)
result = torch.cat((x, y), dim=0)
print("Result of concatenation along dim 0:")
print(result)
输出:
Result of concatenation along dim 0:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]])
在 dim=0 上拼接时,结果张量的行数是两个原始张量行数的和。
2. 在第 1 维(列)拼接
# 在第 1 维上拼接(列拼接)
result = torch.cat((x, y), dim=1)
print("
Result of concatenation along dim 1:")
print(result)
输出:
Result of concatenation along dim 1:
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])
在 dim=1 上拼接时,结果张量的列数是两个原始张量列数的和。
3. 使用更多的张量进行拼接
z = torch.tensor([[13, 14, 15], [16, 17, 18]])
# 在第 0 维上拼接三个张量
result = torch.cat((x, y, z), dim=0)
print("
Result of concatenation of three tensors along dim 0:")
print(result)
输出:
Result of concatenation of three tensors along dim 0:
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12],
[13, 14, 15],
[16, 17, 18]])
4. 拼接不同维度的张量
# 创建一个形状为 (2, 1) 的张量
a = torch.tensor([[1], [2]])
# 创建一个形状为 (2, 2) 的张量
b = torch.tensor([[3, 4], [5, 6]])
# 在第 1 维上拼接
result = torch.cat((a, b), dim=1)
print("
Result of concatenation along dim 1 with different shapes:")
print(result)
输出:
Result of concatenation along dim 1 with different shapes:
tensor([[1, 3, 4],
[2, 5, 6]])
在这种情况下,a 和 b 的行数相同,才能在列维度上拼接。
注意事项
- 维度匹配:除了拼接的维度外,其他维度必须是相同的。否则,
torch.cat会报错。 - 内存效率:拼接操作会创建一个新的张量,而不会修改原有的张量,因此需要足够的内存来存储新的张量。
总结
torch.cat是用于拼接多个张量的操作,支持沿指定维度拼接。- 在使用时需要确保除了拼接的维度外,其他维度的大小是一致的。
dim参数指定拼接的维度,常用的是dim=0(行拼接)和dim=1(列拼接)。
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。





U8W/U8W-Mini使用与常见问题解决
QT多线程的5种用法,通过使用线程解决UI主界面的耗时操作代码,防止界面卡死。...
stm32使用HAL库配置串口中断收发数据(保姆级教程)
分享几个国内免费的ChatGPT镜像网址(亲测有效)
Allegro16.6差分等长设置及走线总结