关键词

pytorch索引与切片

@

目录

index索引

torch会自动从左向右索引

例子:

a = torch.randn(4,3,28,28)

表示类似一个CNN 的图片的输入数据,4表示这个batch一共有4张照片,而3表示图片的通道数为3(RGB),(28,28)表示图片的大小

基本索引

索引1:表示第零张图片的shape

print(a[0].shape)
#torch.Size([3,28,28])

索引2:第零张图片的第零个通道的size

print(a[0,0].shape)
#torch.Size([28,28])

索引3:表示第零张图片的第零个通道的第二行第四列的像素点的值

print(a[0,0,2,4])
#tensor(0.8082)

连续选取

⭐索引4:连续取两张图片(取第0张以及第一张图片,不包括第二张)

print(a[:2].shape
#torch.Size([2,3,28,28])
#由于是两张图片,所以第一维变为2

⭐索引5:前两张图片上的第一个通道上的数据(所以通道数变为了1)

print(a[:2,:1,:,:].shape)
print(a[:2,:1].shape)
#torch.Size(2,1,28,28)

⭐索引6:从后面取(-1表示最后一个,从最后一个取到最后,也就是一个通道)

print(a[:2,-1:,:,:].shape)

#torch.Size(2,1,28,28)

规则间隔索引

⭐索引7:在图片的矩阵进行隔行与隔列索引 0:28:2表示从0到28(不包括28),间隔数为2

print(a[:,:,0:28:2,0:28:2].shape)
print(a[:,:,::2,::2].shape)
#torch.Size([4,3,14,14])

索引总结

start : end : step

都取

x:从x取到最后 :x 从开始取到x x:y从x取到y

x:y:z从x到y每隔z个点采样一次

不规则间隔索引

使用index_select()函数

第一个参数表示你对哪个维度进行操作;第二个参数是index(必须是tensor类型):对第0张与第2张图片进行操作

a.index_select(0,torch.tensor([0,2])).shape
#【2,3,28,28】

同理:选择了两个通道

a.index_select(1,torch.tensor([1,2])).shape
#【4,2,28,28】

同理:只取8行

a.index_select(2,torch.arange(8)).shape
#【4,2,8,28】

任意多的维度索引

使用符号:...

例子:

a[...].shape
#[4,3,28,28]

a[0,...].shape
#[3,28,28]

a[0,1,...].shape
#[4,28,28]

a[...,2].shape
#[4,3,28,2]

使用掩码来索引

函数:.masked_select()会将筛选出来的元素打平(因为无法维护原来的shape)

x = torch.randn(2,3)
print(x)

tensor([[-1.3081, -0.5651, -0.9843],
        [ 1.0051, -0.3829,  0.6300]])

mask = x.ge(0.5)#大于等于0.5的元素
print(mask)

tensor([[False, False, False],
        [ True, False,  True]])

z = torch.masked_select(x,mask)
print(z)

tensor([1.0051, 0.6300])

打平后的索引

例子:使用take函数:是将输入的tensor打平之后进行index的选择

src = torch.tensor([[4,3,5],[6,7,8]])
torch.take(src,torch.tensor([0,2,8]))
#tensor([4,5,8])

本文链接:http://task.lmcjl.com/news/5567.html

展开阅读全文