您现在的位置是:首页 >其他 >【PyTorch】contiguous() 方法:确保张量在内存中的存储是连续的网站首页其他

【PyTorch】contiguous() 方法:确保张量在内存中的存储是连续的

彬彬侠 2026-03-27 12:01:03
简介【PyTorch】contiguous() 方法:确保张量在内存中的存储是连续的

contiguous() 方法

在 PyTorch 中,.contiguous() 是张量(Tensor)对象的一个方法,它的作用是确保张量在内存中的存储是连续的(contiguous),从而使得某些操作(如 view())可以正确执行。


1. 什么是“非连续(non-contiguous)张量”?

PyTorch 中的某些操作(如 transpose()permute())会改变张量的存储顺序,而不是实际移动数据。因此,这些操作返回的张量在内存中通常是非连续的

当你尝试对一个非连续的张量执行 .view() 操作时,会报错,因为 .view() 需要数据在内存中是连续的。如果张量是非连续的,必须先调用 .contiguous() 来创建一个新的连续张量


2. .contiguous() 的作用

  • 确保张量在内存中的数据是连续存储的,从而允许 .view() 操作正确执行。
  • 如果张量已经是连续的,则不会执行任何操作
  • 如果张量是非连续的,则会创建一个新的连续张量,并返回这个新的张量

3. .contiguous() 的使用场景

示例 1:view() 需要连续张量

import torch

# 创建一个 3x3 的张量
x = torch.randn(3, 3)
print("Original x shape:", x.shape)

# 进行转置操作
x_t = x.transpose(0, 1)
print("Transposed x shape:", x_t.shape)

# 尝试使用 .view() 变形
try:
    x_t.view(9)
except RuntimeError as e:
    print("RuntimeError:", e)

# 使用 .contiguous() 之后再调用 .view()
x_contiguous = x_t.contiguous().view(9)
print("View after contiguous():", x_contiguous.shape)

输出

RuntimeError: view size is not compatible with input tensor's size and stride
View after contiguous(): torch.Size([9])

在这里:

  • x_t 经过 transpose(0,1) 操作后,变成了非连续的张量
  • 直接对 x_t.view(9) 会报错。
  • 先调用 x_t.contiguous() 使其变成连续张量,然后再 view(9) 就不会报错。

示例 2:检查张量是否连续

可以使用 is_contiguous() 来检查张量是否连续:

x = torch.randn(2, 3)
print("Is x contiguous?", x.is_contiguous())  # True

x_t = x.transpose(0, 1)
print("Is x_t contiguous?", x_t.is_contiguous())  # False

x_t_cont = x_t.contiguous()
print("Is x_t_cont contiguous?", x_t_cont.is_contiguous())  # True

示例 3:用于 permute()

x = torch.randn(2, 3, 4)
x_permuted = x.permute(2, 0, 1)  # 改变维度顺序
print("Is permuted x contiguous?", x_permuted.is_contiguous())  # False

x_contiguous = x_permuted.contiguous()
print("Is x_contiguous contiguous?", x_contiguous.is_contiguous())  # True

permute() 之后,张量通常变成非连续的,需要 contiguous() 来恢复连续性。


4. .contiguous() vs .clone()

  • .contiguous() 只会在需要时创建新的张量(如果张量本身是连续的,则不会创建新的张量)。
  • .clone() 总是会创建一个新的张量,复制数据,即使张量已经是连续的。
x = torch.randn(3, 3)
x_t = x.transpose(0, 1)

x_cont = x_t.contiguous()
x_clone = x_t.clone()

print(x_cont.data_ptr() == x_t.data_ptr())  # False
print(x_clone.data_ptr() == x_t.data_ptr())  # False

两者都会返回新的张量,但 .clone() 还会创建数据的副本。


5. 总结

  • .contiguous() 确保张量在内存中是连续的,特别是在 transpose()permute() 之后使用 view() 之前必须调用它。
  • 非连续张量的 .contiguous() 会创建一个新的连续张量,但如果张量已经是连续的,它不会做任何操作。
  • .is_contiguous() 可以检查张量是否是连续的
  • .clone().contiguous() 都可以创建新的张量,但 .clone() 复制数据,而 .contiguous() 仅在必要时创建新张量。

如果在 PyTorch 代码中遇到了 RuntimeError: view size is not compatible with input tensor's size and stride,很可能需要先使用 .contiguous()

风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。