pytorch人工智能之torch.gather算子用法示例

1. 简介

在pyTorch中,torch.gather函数用于根据指定的index,从输入的张量中提取对应的值。

2. torch.gather函数的用法

torch.gather函数的函数原型如下:

torch.gather(input, dim, index, out=None, sparse_grad=False) -> Tensor

参数说明:

input:输入张量

dim:指定维度,用于确定在哪个维度上进行索引

index:索引张量,用于提取对应的值

out:输出张量,用于保存提取的值,默认为None

sparse_grad:是否使用稀疏的梯度,默认为False

3. 使用示例

3.1 示例1

在本示例中,我们使用torch.gather函数从输入张量中提取特定的值。

import torch

# 创建输入张量

input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])

index_tensor = torch.tensor([[0, 0], [2, 1]])

# 在dim=1的维度上进行索引,提取对应的值

output_tensor = torch.gather(input_tensor, 1, index_tensor)

print("input_tensor:")

print(input_tensor)

print("index_tensor:")

print(index_tensor)

print("output_tensor:")

print(output_tensor)

运行以上代码,输出结果为:

input_tensor:

tensor([[1, 2],

[3, 4],

[5, 6]])

index_tensor:

tensor([[0, 0],

[2, 1]])

output_tensor:

tensor([[1, 1],

[5, 4]])

在以上示例中,我们首先创建了一个输入张量input_tensor,大小为3x2。然后,我们创建了一个索引张量index_tensor,大小为2x2。接下来,我们使用torch.gather函数从input_tensor中提取了对应的值,将结果保存在output_tensor中。

在输出结果中,可以看到output_tensor中的第一行是input_tensor中第一行的第一个元素,第二个元素;第二行是input_tensor中第三行的第二个元素,第二行的第一个元素。

3.2 示例2

在本示例中,我们使用torch.gather函数将一个向量中的多个元素替换成指定的值。

import torch

# 创建输入向量

input_tensor = torch.tensor([1, 2, 3, 4, 5, 6])

index_tensor = torch.tensor([0, 2, 4])

# 替换指定位置的元素

output_tensor = torch.gather(input_tensor, 0, index_tensor)

print("input_tensor:")

print(input_tensor)

print("index_tensor:")

print(index_tensor)

print("output_tensor:")

print(output_tensor)

运行以上代码,输出结果为:

input_tensor:

tensor([1, 2, 3, 4, 5, 6])

index_tensor:

tensor([0, 2, 4])

output_tensor:

tensor([1, 3, 5])

在以上示例中,我们首先创建了一个输入向量input_tensor。然后,我们创建了一个索引向量index_tensor,其中包含了需要替换的位置。接下来,我们使用torch.gather函数将input_tensor中的指定位置的元素替换成指定的值,将结果保存在output_tensor中。

在输出结果中,可以看到output_tensor中的第一个元素、第三个元素、第五个元素分别是input_tensor中对应位置的元素。

4. 总结

在本文中,我们介绍了pyTorch中torch.gather函数的用法,并通过示例代码演示了其具体的应用。通过使用torch.gather函数,我们可以方便地从输入的张量中提取特定的值,或者将指定位置的元素替换成指定的值。如有兴趣,你可以进一步阅读官方文档,了解更多关于torch.gather函数的信息。

后端开发标签