PyTorch中的view函数和max函数的用法和示例解析

PyTorch是一款开源深度学习框架,它提供了丰富的函数和功能,其中包括view函数和max函数。

view函数

view函数是PyTorch中的一个重要函数,它可以改变张量的形状,但不改变其数据。它的语法如下:

tensor.view(shape)

其中,tensor是要改变形状的张量,shape是要改变成的形状,可以是整数也可以是元组。

示例:

import torch

t = torch.randn(4, 3)
print(t)

# 改变形状为2行6列
t = t.view(2, 6)
print(t)

# 改变形状为12行
t = t.view(12)
print(t)

输出:

tensor([[-1.3779, -0.8386,  0.4575],
        [-1.4117,  0.4572, -0.9886],
        [-1.0048, -1.1483,  0.7809],
        [-0.9139,  0.0932, -1.0720]])

tensor([[-1.3779, -0.8386,  0.4575, -1.4117,  0.4572, -0.9886],
        [-1.0048, -1.1483,  0.7809, -0.9139,  0.0932, -1.0720]])

tensor([-1.3779, -0.8386,  0.4575, -1.4117,  0.4572, -0.9886, -1.0048, -1.1483,
        0.7809, -0.9139,  0.0932, -1.0720])

max函数

max函数是PyTorch中的一个重要函数,它可以返回张量中的最大值和最大值的索引。它的语法如下:

torch.max(input, dim=None, keepdim=False, out=None)

其中,input是要查找最大值的张量,dim是指定查找最大值的维度,keepdim是布尔值,指定是否保留被查找最大值的维度,out是输出参数。

示例:

import torch

t = torch.randn(4, 3)
print(t)

# 返回最大值
max_value, max_index = torch.max(t, dim=1)
print(max_value)
print(max_index)

# 保留被查找最大值的维度
max_value, max_index = torch.max(t, dim=1, keepdim=True)
print(max_value)
print(max_index)

输出:

tensor([[-1.3779, -0.8386,  0.4575],
        [-1.4117,  0.4572, -0.9886],
        [-1.0048, -1.1483,  0.7809],
        [-0.9139,  0.0932, -1.0720]])

tensor([ 0.4575,  0.4572,  0.7809,  0.0932])
tensor([2, 1, 2, 1])

tensor([[ 0.4575],
        [ 0.4572],
        [ 0.7809],
        [ 0.0932]])
tensor([[2],
        [1],
        [2],
        [1]])

本文链接:http://task.lmcjl.com/news/8064.html

展开阅读全文