Pytorch高阶OP操作where,gather原理

1. Pytorch高阶OP操作where,gather原理

Pytorch是一种常用的深度学习框架,具有很多高阶操作(OP),其中where和gather是常见操作。本文将详细介绍这两个高阶操作的原理和用法。

1.1 where操作原理

Pytorch中的where操作能够根据特定条件在两个数组之间选择元素,并返回一个mask张量,该张量指示原数组中在特定位置上选择了哪个元素。where操作的语法如下:

torch.where(condition, x, y)

其中,condition是条件张量,x和y是输入张量。如果condition中的元素为True,则返回的张量中对应位置的元素为x中的元素;否则,则为y中的元素。下面是where操作的示例代码:

import torch

x = torch.tensor([[1, 2],

[3, 4]])

y = torch.tensor([[5, 6],

[7, 8]])

condition = torch.tensor([[True, False],

[False, True]])

result = torch.where(condition, x, y)

print(result)

#输出结果:

# tensor([[1, 6],

# [7, 4]])

在这个示例中,condition中(0,0)位置的元素为True,(0,1)位置的元素为False,(1,0)位置的元素为False,(1,1)位置的元素为True。因此,选择的元素为x中的(0,0)位置的元素1,y中的(0,1)位置的元素6,y中的(1,0)位置的元素7和x中的(1,1)位置的元素4,结果为tensor([[1, 6],[7, 4]])。

1.2 gather操作原理

Pytorch中的gather操作能够沿着指定维按照给定的索引选择张量中的元素,并返回一个新张量。其语法如下:

torch.gather(input, dim, index, out=None)

其中,input是输入张量,dim是沿着哪个维度选择元素,index是一个包含索引值的张量。out是可选的输出张量。下面是一个示例代码,演示如何使用gather操作:

import torch

input = torch.tensor([[1, 2],

[3, 4]])

index = torch.tensor([[0, 0],

[1, 0]])

result = torch.gather(input, 1, index)

print(result)

#输出结果:

# tensor([[1, 1],

# [4, 3]])

在这个示例中,input是一个2x2的输入张量,index是一个2x2的索引张量。因为dim参数的值为1,所以按照索引张量中的值沿着第1个维度对输入张量进行选择。在(0,0)位置选择的是input中的第0行第0列的元素1,(0,1)位置同样选择的是input中的第0行第0列的元素1,(1,0)位置选择的是input中的第1行第1列的元素4,(1,1)位置选择的是input中的第1行第0列的元素3。结果如下:tensor([[1, 1],[4, 3]])。

2. where 和 gather 在深度学习中的应用

where和gather操作在深度学习中经常被使用。其中,where操作常用于掩膜操作,例如在图像分割任务中,需要将每个像素分到特定的类别中。此时,可以使用where操作将已知类别的像素对应位置上的特定值保留,而将其他位置上的像素值设为0。

下面是基于where操作的图像掩膜代码:

import torch

import numpy as np

from PIL import Image

# 加载图像

image_path = "example.jpg"

image = Image.open(image_path)

image_data = np.array(image)

# 创建掩膜

mask = np.zeros_like(image_data)

mask[100:200, 100:200] = 1

# 转换为 Torch 张量

image_tensor = torch.from_numpy(image_data).float()

mask_tensor = torch.from_numpy(mask).bool()

# 掩膜操作

result = torch.where(mask_tensor, image_tensor, torch.zeros_like(image_tensor))

# 显示结果

result = result.int().numpy()

result_image = Image.fromarray(result)

result_image.show()

上述代码中,首先加载了一张图像,然后创建了一个掩膜张量。将掩膜张量中(100,100)到(200,200)之间的所有像素设为1,其他位置为0。接下来,根据掩膜张量和图像张量执行where操作,对已知位置上的像素值进行保留,而将其他位置上的像素值设为0。最后,将操作结果显示为一张图像。如下图所示:

而gather操作则常用于序列化数据的索引。例如,在机器翻译中,需要将源语言中的单词按照其在词汇表中的索引值转换为张量。

下面是基于gather操作的机器翻译代码:

import torch

# 假设源语言词汇表大小为 20,目标语言词汇表大小为 30

source_vocab_size = 20

target_vocab_size = 30

# 假设源语言句子为 ["hello", "world", "!"]

source_sentence = ["hello", "world", "!"]

# 将源语言句子中的单词转换为在词汇表中的索引

source_indexes = [index for word in source_sentence if word in source_vocab for index, token in enumerate(source_vocab) if word == token]

# 将索引转换为词向量

source_embeddings = torch.randn(len(source_indexes), 10)

# 进行gather操作,转换为目标语言索引

target_indexes = [index for index in range(len(source_indexes)) if source_indexes[index] <= target_vocab_size]

target_embeddings = torch.gather(source_embeddings, 1, torch.tensor(target_indexes).view(-1, 1))

print(target_embeddings)

上述代码中,首先假设源语言词汇表大小为20,目标语言词汇表大小为30。源语言句子为["hello", "world", "!"]。将这个句子中的单词转换为在词汇表中的索引。如果单词不在词汇表中,则跳过。接下来使用随机生成的10维词向量表示每个单词。然后使用gather操作按照源语言中单词在词汇表中的索引,将这些词向量转换为目标语言中的张量。它将源语言句子中的第一个单词"hello"映射为目标语言词汇表中的第3个单词,将第二个单词"world"映射为第7个单词。映射结果如下:

tensor([[-0.6745,  1.2639, -0.4718, -0.7016, -1.7440,  0.7187,  0.3235, -1.4861,

-0.0193, -0.2036],

[-1.1052, 0.0052, -0.8629, 0.5254, -0.7089, 0.4641, -0.6773, -0.3078,

-1.7685, -1.5248]])

值得注意的是,在进行Gather操作时,需要手动将输入张量和索引张量广播为相同的维数,以进行匹配。在本例中,使用.view(-1,1)将1维索引张量转换为2维张量。

3. 总结

本文介绍了Pytorch中的where和gather操作的原理和用法,并提供了基于这两个操作的示例代码。在深度学习中,where操作常用于掩膜操作,而gather操作则常用于序列化数据的索引。更多具体应用需要根据具体情况定义。

后端开发标签