pytorch中的select by mask

#select by mask
x = torch.randn(3,4)
print(x)
# tensor([[ 1.1132,  0.8882, -1.4683,  1.4100],
#         [-0.4903, -0.8422,  0.3576,  0.6806],
#         [-0.7180, -0.8218, -0.5010, -0.0607]])

mask = x.ge(0.5)
print(mask)
# tensor([[1, 0, 1, 0],
#         [1, 0, 0, 0],
#         [0, 0, 0, 0]], dtype=torch.uint8)
y = torch.masked_select(x,mask)
print(y)  #tensor([1.0361, 0.6217, 0.6854])
print(y.shape)    #torch.Size([3])
print(y.share_memory_())  #tensor([0.8596, 0.6594, 1.3755])
print(y.is_shared())    #True
torch.ge
torch.ge(input, other, out=None) → Tensor
逐元素比较input和other,即是否 input>=otherinput>=other。

如果两个张量有相同的形状和元素值,则返回True ,否则 False。 第二个参数可以为一个数或与第一个参数相同形状和类型的张量

参数:

input (Tensor) – 待对比的张量
other (Tensor or float) – 对比的张量或float值
out (Tensor, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。
返回值: 一个 torch.ByteTensor 张量,包含了每个位置的比较结果(是否 input >= other )。 返回类型: Tensor

例子:

>>> torch.ge(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 1  1
 0  1
[torch.ByteTensor of size 2x2]
torch.gt
torch.gt(input, other, out=None) → Tensor
逐元素比较input和other , 即是否input>otherinput>other 如果两个张量有相同的形状和元素值,则返回True ,否则 False。 第二个参数可以为一个数或与第一个参数相同形状和类型的张量

参数:

input (Tensor) – 要对比的张量
other (Tensor or float) – 要对比的张量或float值
out (Tensor, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。
返回值: 一个 torch.ByteTensor 张量,包含了每个位置的比较结果(是否 input >= other )。 返回类型: Tensor

例子:

>>> torch.gt(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]]))
 0  1
 0  0
[torch.ByteTensor of size 2x2]

torch.le torch.le(input, other, out=None) → Tensor

  逐元素比较input和other , 即是否input<=otherinput<=other 第二个参数可以为一个数或与第一个参数相同形状和类型的张量

参数: input (Tensor) – 要对比的张量

other (Tensor or float ) – 对比的张量或float值

out (Tensor, optional) – 输出张量。

必须为ByteTensor或者与第一个参数tensor相同类型。 返回值: 一个 torch.ByteTensor 张量,包含了每个位置的比较结果(是否 input >= other )。 返回类型: Tensor

例子: >>> torch.le(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 1 0 1 1 [torch.ByteTensor of size 2x2] torch.lt torch.lt(input, other, out=None) → Tensor 逐元素比较input和other , 即是否 input<otherinput<other 第二个参数可以为一个数或与第一个参数相同形状和类型的张量 参数: input (Tensor) – 要对比的张量 other (Tensor or float ) – 对比的张量或float值 out (Tensor, optional) – 输出张量。必须为ByteTensor或者与第一个参数tensor相同类型。 input: 一个 torch.ByteTensor 张量,包含了每个位置的比较结果(是否 tensor >= other )。 返回类型: Tensor 例子: >>> torch.lt(torch.Tensor([[1, 2], [3, 4]]), torch.Tensor([[1, 1], [4, 4]])) 0 0 1 0 [torch.ByteTensor of size 2x2]