PyTorch中的torch.squeeze()和torch.unsqueeze()是两个常用的张量处理函数,它们都是在张量维度上进行操作,但是功能有所不同。
torch.squeeze()函数用于删除张量中维度为1的维度,即把维度为1的维度从张量中去掉。例如,一个形状为(2, 1, 3, 1)的张量,使用torch.squeeze()函数后,形状变为(2, 3),即将维度为1的维度去掉。
a = torch.randn(2, 1, 3, 1) print(a.shape) # torch.Size([2, 1, 3, 1]) b = torch.squeeze(a) print(b.shape) # torch.Size([2, 3])
torch.unsqueeze()函数用于在张量的指定位置插入维度为1的维度,即在指定位置插入维度为1的维度。例如,一个形状为(2, 3)的张量,使用torch.unsqueeze(a, dim=1)函数后,形状变为(2, 1, 3),即在第1维位置插入了维度为1的维度。
a = torch.randn(2, 3) print(a.shape) # torch.Size([2, 3]) b = torch.unsqueeze(a, dim=1) print(b.shape) # torch.Size([2, 1, 3])
本文链接:http://task.lmcjl.com/news/7076.html