关键词

用法 注意事项

PyTorch中Variable的用法和注意事项

PyTorch中的Variable是一种封装Tensor的容器,可以跟踪其上的操作,从而实现自动求导。它的使用方法如下:

1. 创建Variable

# 将Tensor封装成Variable
x = Variable(torch.Tensor([1,2,3]))

# 从已有的Variable创建
x = Variable(x.data, requires_grad=True)

2. 运算

# 运算
y = x + 2
z = y * y * 3
out = z.mean()

3. 求导

# 求导
out.backward()

# 查看梯度
x.grad

注意事项

  • Variable的requires_grad属性默认值为False,如果想要求导,需要在创建时将其设置为True。
  • 在计算过程中,只有Variable上的操作才会被记录,Tensor上的操作不会被记录。
  • 在求导之前,需要调用Variable的backward函数,计算出所有的梯度。
  • 求导之后,梯度将存储在Variable的grad属性中。

本文链接:http://task.lmcjl.com/news/2666.html

展开阅读全文