1. Pytorch高阶OP操作where,gather原理
Pytorch是一种常用的深度学习框架,具有很多高阶操作(OP),其中where和gather是常见操作。本文将详细介绍这两个高阶操作的原理和用法。
1.1 where操作原理
Pytorch中的where操作能够根据特定条件在两个数组之间选择元素,并返回一个mask张量,该张量指示原数组中在特定位置上选择了哪个元素。where操作的语法如下:
torch.where(condition, x, y)
其中,condition是条件张量,x和y是输入张量。如果condition中的元素为True,则返回的张量中对应位置的元素为x中的元素;否则,则为y中的元素。下面是where操作的示例代码:
import torch
x = torch.tensor([[1, 2],
[3, 4]])
y = torch.tensor([[5, 6],
[7, 8]])
condition = torch.tensor([[True, False],
[False, True]])
result = torch.where(condition, x, y)
print(result)
#输出结果:
# tensor([[1, 6],
# [7, 4]])
在这个示例中,condition中(0,0)位置的元素为True,(0,1)位置的元素为False,(1,0)位置的元素为False,(1,1)位置的元素为True。因此,选择的元素为x中的(0,0)位置的元素1,y中的(0,1)位置的元素6,y中的(1,0)位置的元素7和x中的(1,1)位置的元素4,结果为tensor([[1, 6],[7, 4]])。
1.2 gather操作原理
Pytorch中的gather操作能够沿着指定维按照给定的索引选择张量中的元素,并返回一个新张量。其语法如下:
torch.gather(input, dim, index, out=None)
其中,input是输入张量,dim是沿着哪个维度选择元素,index是一个包含索引值的张量。out是可选的输出张量。下面是一个示例代码,演示如何使用gather操作:
import torch
input = torch.tensor([[1, 2],
[3, 4]])
index = torch.tensor([[0, 0],
[1, 0]])
result = torch.gather(input, 1, index)
print(result)
#输出结果:
# tensor([[1, 1],
# [4, 3]])
在这个示例中,input是一个2x2的输入张量,index是一个2x2的索引张量。因为dim参数的值为1,所以按照索引张量中的值沿着第1个维度对输入张量进行选择。在(0,0)位置选择的是input中的第0行第0列的元素1,(0,1)位置同样选择的是input中的第0行第0列的元素1,(1,0)位置选择的是input中的第1行第1列的元素4,(1,1)位置选择的是input中的第1行第0列的元素3。结果如下:tensor([[1, 1],[4, 3]])。
2. where 和 gather 在深度学习中的应用
where和gather操作在深度学习中经常被使用。其中,where操作常用于掩膜操作,例如在图像分割任务中,需要将每个像素分到特定的类别中。此时,可以使用where操作将已知类别的像素对应位置上的特定值保留,而将其他位置上的像素值设为0。
下面是基于where操作的图像掩膜代码:
import torch
import numpy as np
from PIL import Image
# 加载图像
image_path = "example.jpg"
image = Image.open(image_path)
image_data = np.array(image)
# 创建掩膜
mask = np.zeros_like(image_data)
mask[100:200, 100:200] = 1
# 转换为 Torch 张量
image_tensor = torch.from_numpy(image_data).float()
mask_tensor = torch.from_numpy(mask).bool()
# 掩膜操作
result = torch.where(mask_tensor, image_tensor, torch.zeros_like(image_tensor))
# 显示结果
result = result.int().numpy()
result_image = Image.fromarray(result)
result_image.show()
上述代码中,首先加载了一张图像,然后创建了一个掩膜张量。将掩膜张量中(100,100)到(200,200)之间的所有像素设为1,其他位置为0。接下来,根据掩膜张量和图像张量执行where操作,对已知位置上的像素值进行保留,而将其他位置上的像素值设为0。最后,将操作结果显示为一张图像。如下图所示:
而gather操作则常用于序列化数据的索引。例如,在机器翻译中,需要将源语言中的单词按照其在词汇表中的索引值转换为张量。
下面是基于gather操作的机器翻译代码:
import torch
# 假设源语言词汇表大小为 20,目标语言词汇表大小为 30
source_vocab_size = 20
target_vocab_size = 30
# 假设源语言句子为 ["hello", "world", "!"]
source_sentence = ["hello", "world", "!"]
# 将源语言句子中的单词转换为在词汇表中的索引
source_indexes = [index for word in source_sentence if word in source_vocab for index, token in enumerate(source_vocab) if word == token]
# 将索引转换为词向量
source_embeddings = torch.randn(len(source_indexes), 10)
# 进行gather操作,转换为目标语言索引
target_indexes = [index for index in range(len(source_indexes)) if source_indexes[index] <= target_vocab_size]
target_embeddings = torch.gather(source_embeddings, 1, torch.tensor(target_indexes).view(-1, 1))
print(target_embeddings)
上述代码中,首先假设源语言词汇表大小为20,目标语言词汇表大小为30。源语言句子为["hello", "world", "!"]。将这个句子中的单词转换为在词汇表中的索引。如果单词不在词汇表中,则跳过。接下来使用随机生成的10维词向量表示每个单词。然后使用gather操作按照源语言中单词在词汇表中的索引,将这些词向量转换为目标语言中的张量。它将源语言句子中的第一个单词"hello"映射为目标语言词汇表中的第3个单词,将第二个单词"world"映射为第7个单词。映射结果如下:
tensor([[-0.6745, 1.2639, -0.4718, -0.7016, -1.7440, 0.7187, 0.3235, -1.4861,
-0.0193, -0.2036],
[-1.1052, 0.0052, -0.8629, 0.5254, -0.7089, 0.4641, -0.6773, -0.3078,
-1.7685, -1.5248]])
值得注意的是,在进行Gather操作时,需要手动将输入张量和索引张量广播为相同的维数,以进行匹配。在本例中,使用.view(-1,1)将1维索引张量转换为2维张量。
3. 总结
本文介绍了Pytorch中的where和gather操作的原理和用法,并提供了基于这两个操作的示例代码。在深度学习中,where操作常用于掩膜操作,而gather操作则常用于序列化数据的索引。更多具体应用需要根据具体情况定义。