1. 简介
在pyTorch中,torch.gather函数用于根据指定的index,从输入的张量中提取对应的值。
2. torch.gather函数的用法
torch.gather函数的函数原型如下:
torch.gather(input, dim, index, out=None, sparse_grad=False) -> Tensor
参数说明:
input:输入张量
dim:指定维度,用于确定在哪个维度上进行索引
index:索引张量,用于提取对应的值
out:输出张量,用于保存提取的值,默认为None
sparse_grad:是否使用稀疏的梯度,默认为False
3. 使用示例
3.1 示例1
在本示例中,我们使用torch.gather函数从输入张量中提取特定的值。
import torch
# 创建输入张量
input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
index_tensor = torch.tensor([[0, 0], [2, 1]])
# 在dim=1的维度上进行索引,提取对应的值
output_tensor = torch.gather(input_tensor, 1, index_tensor)
print("input_tensor:")
print(input_tensor)
print("index_tensor:")
print(index_tensor)
print("output_tensor:")
print(output_tensor)
运行以上代码,输出结果为:
input_tensor:
tensor([[1, 2],
[3, 4],
[5, 6]])
index_tensor:
tensor([[0, 0],
[2, 1]])
output_tensor:
tensor([[1, 1],
[5, 4]])
在以上示例中,我们首先创建了一个输入张量input_tensor,大小为3x2。然后,我们创建了一个索引张量index_tensor,大小为2x2。接下来,我们使用torch.gather函数从input_tensor中提取了对应的值,将结果保存在output_tensor中。
在输出结果中,可以看到output_tensor中的第一行是input_tensor中第一行的第一个元素,第二个元素;第二行是input_tensor中第三行的第二个元素,第二行的第一个元素。
3.2 示例2
在本示例中,我们使用torch.gather函数将一个向量中的多个元素替换成指定的值。
import torch
# 创建输入向量
input_tensor = torch.tensor([1, 2, 3, 4, 5, 6])
index_tensor = torch.tensor([0, 2, 4])
# 替换指定位置的元素
output_tensor = torch.gather(input_tensor, 0, index_tensor)
print("input_tensor:")
print(input_tensor)
print("index_tensor:")
print(index_tensor)
print("output_tensor:")
print(output_tensor)
运行以上代码,输出结果为:
input_tensor:
tensor([1, 2, 3, 4, 5, 6])
index_tensor:
tensor([0, 2, 4])
output_tensor:
tensor([1, 3, 5])
在以上示例中,我们首先创建了一个输入向量input_tensor。然后,我们创建了一个索引向量index_tensor,其中包含了需要替换的位置。接下来,我们使用torch.gather函数将input_tensor中的指定位置的元素替换成指定的值,将结果保存在output_tensor中。
在输出结果中,可以看到output_tensor中的第一个元素、第三个元素、第五个元素分别是input_tensor中对应位置的元素。
4. 总结
在本文中,我们介绍了pyTorch中torch.gather函数的用法,并通过示例代码演示了其具体的应用。通过使用torch.gather函数,我们可以方便地从输入的张量中提取特定的值,或者将指定位置的元素替换成指定的值。如有兴趣,你可以进一步阅读官方文档,了解更多关于torch.gather函数的信息。