您现在的位置是:首页 >其他 >【PyTorch】contiguous() 方法:确保张量在内存中的存储是连续的网站首页其他
【PyTorch】contiguous() 方法:确保张量在内存中的存储是连续的
简介【PyTorch】contiguous() 方法:确保张量在内存中的存储是连续的
contiguous() 方法
在 PyTorch 中,.contiguous() 是张量(Tensor)对象的一个方法,它的作用是确保张量在内存中的存储是连续的(contiguous),从而使得某些操作(如 view())可以正确执行。
1. 什么是“非连续(non-contiguous)张量”?
PyTorch 中的某些操作(如 transpose()、permute())会改变张量的存储顺序,而不是实际移动数据。因此,这些操作返回的张量在内存中通常是非连续的。
当你尝试对一个非连续的张量执行 .view() 操作时,会报错,因为 .view() 需要数据在内存中是连续的。如果张量是非连续的,必须先调用 .contiguous() 来创建一个新的连续张量。
2. .contiguous() 的作用
- 确保张量在内存中的数据是连续存储的,从而允许
.view()操作正确执行。 - 如果张量已经是连续的,则不会执行任何操作。
- 如果张量是非连续的,则会创建一个新的连续张量,并返回这个新的张量。
3. .contiguous() 的使用场景
示例 1:view() 需要连续张量
import torch
# 创建一个 3x3 的张量
x = torch.randn(3, 3)
print("Original x shape:", x.shape)
# 进行转置操作
x_t = x.transpose(0, 1)
print("Transposed x shape:", x_t.shape)
# 尝试使用 .view() 变形
try:
x_t.view(9)
except RuntimeError as e:
print("RuntimeError:", e)
# 使用 .contiguous() 之后再调用 .view()
x_contiguous = x_t.contiguous().view(9)
print("View after contiguous():", x_contiguous.shape)
输出
RuntimeError: view size is not compatible with input tensor's size and stride
View after contiguous(): torch.Size([9])
在这里:
x_t经过transpose(0,1)操作后,变成了非连续的张量。- 直接对
x_t.view(9)会报错。 - 先调用
x_t.contiguous()使其变成连续张量,然后再view(9)就不会报错。
示例 2:检查张量是否连续
可以使用 is_contiguous() 来检查张量是否连续:
x = torch.randn(2, 3)
print("Is x contiguous?", x.is_contiguous()) # True
x_t = x.transpose(0, 1)
print("Is x_t contiguous?", x_t.is_contiguous()) # False
x_t_cont = x_t.contiguous()
print("Is x_t_cont contiguous?", x_t_cont.is_contiguous()) # True
示例 3:用于 permute()
x = torch.randn(2, 3, 4)
x_permuted = x.permute(2, 0, 1) # 改变维度顺序
print("Is permuted x contiguous?", x_permuted.is_contiguous()) # False
x_contiguous = x_permuted.contiguous()
print("Is x_contiguous contiguous?", x_contiguous.is_contiguous()) # True
在 permute() 之后,张量通常变成非连续的,需要 contiguous() 来恢复连续性。
4. .contiguous() vs .clone()
.contiguous()只会在需要时创建新的张量(如果张量本身是连续的,则不会创建新的张量)。.clone()总是会创建一个新的张量,复制数据,即使张量已经是连续的。
x = torch.randn(3, 3)
x_t = x.transpose(0, 1)
x_cont = x_t.contiguous()
x_clone = x_t.clone()
print(x_cont.data_ptr() == x_t.data_ptr()) # False
print(x_clone.data_ptr() == x_t.data_ptr()) # False
两者都会返回新的张量,但 .clone() 还会创建数据的副本。
5. 总结
.contiguous()确保张量在内存中是连续的,特别是在transpose()或permute()之后使用view()之前必须调用它。- 非连续张量的
.contiguous()会创建一个新的连续张量,但如果张量已经是连续的,它不会做任何操作。 .is_contiguous()可以检查张量是否是连续的。.clone()和.contiguous()都可以创建新的张量,但.clone()复制数据,而.contiguous()仅在必要时创建新张量。
如果在 PyTorch 代码中遇到了 RuntimeError: view size is not compatible with input tensor's size and stride,很可能需要先使用 .contiguous()。
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。





QT多线程的5种用法,通过使用线程解决UI主界面的耗时操作代码,防止界面卡死。...
U8W/U8W-Mini使用与常见问题解决
stm32使用HAL库配置串口中断收发数据(保姆级教程)
分享几个国内免费的ChatGPT镜像网址(亲测有效)
Allegro16.6差分等长设置及走线总结