Usage examples of torch.autograd.grad() in PyTorch
Contents
1. Function explanation
2. Code example (y = x^2)
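1. Function explanation
Given an input x and an output y, torch.autograd.grad() computes the derivative (gradient) of y with respect to x. The definition below is quoted from an older PyTorch release. In that version only_inputs is already deprecated: the body only warns when it is False and otherwise ignores it, even though the docstring still describes the old behavior. Newer releases also add parameters such as is_grads_batched.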
```python
# Excerpt from torch/autograd/__init__.py (older PyTorch release);
# torch, warnings, Variable and _make_grads are module-level names there.
def grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False,
         only_inputs=True, allow_unused=False):
    r"""Computes and returns the sum of gradients of outputs w.r.t. the inputs.

    ``grad_outputs`` should be a sequence of length matching ``output``
    containing the pre-computed gradients w.r.t. each of the outputs. If an
    output doesn't require_grad, then the gradient can be ``None``).

    If ``only_inputs`` is ``True``, the function will only return a list of gradients
    w.r.t the specified inputs. If it's ``False``, then gradient w.r.t. all remaining
    leaves will still be computed, and will be accumulated into their ``.grad``
    attribute.

    Arguments:
        outputs (sequence of Tensor): outputs of the differentiated function.
        inputs (sequence of Tensor): Inputs w.r.t. which the gradient will be
            returned (and not accumulated into ``.grad``).
        grad_outputs (sequence of Tensor): Gradients w.r.t. each output.
            None values can be specified for scalar Tensors or ones that don't require
            grad. If a None value would be acceptable for all grad_tensors, then this
            argument is optional. Default: None.
        retain_graph (bool, optional): If ``False``, the graph used to compute the grad
            will be freed. Note that in nearly all cases setting this option to ``True``
            is not needed and often can be worked around in a much more efficient
            way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative products.
            Default: ``False``.
        allow_unused (bool, optional): If ``False``, specifying inputs that were not
            used when computing outputs (and therefore their grad is always zero)
            is an error. Defaults to ``False``.
    """
    if not only_inputs:
        warnings.warn("only_inputs argument is deprecated and is ignored now "
                      "(defaults to True). To accumulate gradient for other "
                      "parts of the graph, please use torch.autograd.backward.")

    outputs = (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
    if grad_outputs is None:
        grad_outputs = [None] * len(outputs)
    elif isinstance(grad_outputs, torch.Tensor):
        grad_outputs = [grad_outputs]
    else:
        grad_outputs = list(grad_outputs)

    grad_outputs = _make_grads(outputs, grad_outputs)
    if retain_graph is None:
        retain_graph = create_graph

    return Variable._execution_engine.run_backward(
        outputs, grad_outputs, retain_graph, create_graph,
        inputs, allow_unused)
```
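The docstring's description of grad_outputs can be made concrete with a small example. The following is a minimal sketch assuming a recent PyTorch install; the tensors x and v are illustrative. It shows that a non-scalar output requires grad_outputs, that grad_outputs acts as the vector v in the vector-Jacobian product v^T J, that a scalar output needs no grad_outputs, and that grad() returns gradients instead of accumulating them into .grad.

```python
import torch

# Illustrative input; y = x ** 2 acts elementwise, so its Jacobian is diagonal.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2

# A non-scalar output has no implicit gradient, so grad_outputs is required.
try:
    torch.autograd.grad(y, x, retain_graph=True)
except RuntimeError as err:
    print("needs grad_outputs:", err)

# grad_outputs is the vector v in the vector-Jacobian product v^T J,
# which for this elementwise y is simply v * 2x.
v = torch.tensor([1.0, 0.0, 2.0])
(vjp,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)
print(vjp)  # values: v * 2x = [2, 0, 12]

# A scalar output needs no grad_outputs: d/dx sum(x^2) = 2x.
(g,) = torch.autograd.grad(y.sum(), x)
print(g)  # values: [2, 4, 6]

# The gradients are returned, not accumulated into x.grad.
print(x.grad)  # None
```

The first two calls pass retain_graph=True so that the graph of y is still available for the later y.sum() call.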
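The retain_graph and allow_unused arguments described in the docstring can be exercised the same way. This is a minimal sketch, again assuming a recent PyTorch; z is a hypothetical input that the output never uses.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)
z = torch.tensor(5.0, requires_grad=True)  # hypothetical: deliberately unused

# retain_graph defaults to create_graph (False here), so the graph of y
# is freed once grad() returns ...
y = x ** 2
(dy,) = torch.autograd.grad(y, x)
print(dy)  # tensor(6.)

# ... and differentiating through it a second time raises an error.
try:
    torch.autograd.grad(y, x)
except RuntimeError:
    print("graph already freed; pass retain_graph=True to reuse it")

# With allow_unused=False (the default), asking for the gradient of an
# input that y does not depend on is an error ...
y = x ** 2  # rebuild the graph
try:
    torch.autograd.grad(y, [x, z], retain_graph=True)
except RuntimeError:
    print("z was not used to compute y")

# ... while allow_unused=True returns None for that input.
dx, dz = torch.autograd.grad(y, [x, z], allow_unused=True)
print(dx, dz)  # tensor(6.) None
```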
2. Code example (y = x^2)
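```python
import torch

# Set the values x[i][j] = i + j before enabling grad: writing into a leaf
# tensor that already requires grad (x[i][j] = i + j after requires_grad_)
# raises a RuntimeError on current PyTorch releases.
x = torch.tensor([[i + j for j in range(4)] for i in range(3)],
                 dtype=torch.float32).requires_grad_(True)
y = x ** 2
print(x)
print(y)

# weight is passed as grad_outputs: one entry per element of y.
weight = torch.ones(y.size())
print(weight)

# create_graph=True keeps the derivative differentiable so it can be
# differentiated again below. (only_inputs is deprecated and ignored,
# so it is not passed.)
dydx = torch.autograd.grad(outputs=y,
                           inputs=x,
                           grad_outputs=weight,
                           retain_graph=True,
                           create_graph=True)
# (x**2)' = 2*x
print(dydx[0])

# Differentiate the first derivative again: (2*x)' = 2
d2ydx2 = torch.autograd.grad(outputs=dydx[0],
                             inputs=x,
                             grad_outputs=weight,
                             retain_graph=True,
                             create_graph=True)
print(d2ydx2[0])
```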
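The printed output below comes from the original run on an older PyTorch release, so the grad_fn annotations may differ on current versions. x:

```text
tensor([[0., 1., 2., 3.],
        [1., 2., 3., 4.],
        [2., 3., 4., 5.]], grad_fn=<CopySlices>)
```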
y = x squared:

```text
tensor([[ 0., 1., 4., 9.],
        [ 1., 4., 9., 16.],
        [ 4., 9., 16., 25.]], grad_fn=<PowBackward0>)
```
weight:

```text
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
```
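dydx[0] is the first derivative, 2x. Since y = x ** 2 is elementwise, grad() actually returns this derivative multiplied elementwise by weight (a vector-Jacobian product); with weight all ones, the values are unchanged:

```text
tensor([[ 0., 2., 4., 6.],
        [ 2., 4., 6., 8.],
        [ 4., 6., 8., 10.]], grad_fn=<ThMulBackward>)
```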
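The claim that the result is multiplied by weight is easiest to check with a weight that is not all ones. This is a minimal sketch assuming a recent PyTorch; it rebuilds x with the same values as above, and the row-wise weight is a hypothetical choice for illustration.

```python
import torch

# Same values as x in the example: x[i][j] = i + j.
x = torch.tensor([[i + j for j in range(4)] for i in range(3)],
                 dtype=torch.float32, requires_grad=True)
y = x ** 2

# Hypothetical non-uniform weight: every entry in row i equals i + 1.
weight = torch.arange(1.0, 4.0).reshape(3, 1).repeat(1, 4)

(dydx,) = torch.autograd.grad(outputs=y, inputs=x, grad_outputs=weight)
print(dydx)

# y is elementwise, so the result is weight * 2x, entry by entry.
print(torch.allclose(dydx, weight * 2 * x.detach()))  # True
```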
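d2ydx2[0] is the second derivative, 2 everywhere. It also carries the weight factors from both grad() calls, which are all ones here:

```text
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]], grad_fn=<ThMulBackward>)
```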
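Because create_graph=True keeps each derivative differentiable, the same call can be chained to any order. This is a minimal sketch assuming a recent PyTorch, using the illustrative function y = x ** 3, whose first three derivatives are 3x^2, 6x and 6.

```python
import torch

# Illustrative choice: y = x ** 3, so y' = 3x^2, y'' = 6x, y''' = 6.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = x ** 3

derivs = []
for order in range(1, 4):
    # create_graph=True records the derivative computation itself,
    # so the result can be differentiated again on the next iteration.
    (out,) = torch.autograd.grad(outputs=out, inputs=x,
                                 grad_outputs=torch.ones_like(out),
                                 create_graph=True)
    derivs.append(out)
    print(order, out.detach())

# Values: order 1 -> [3, 12, 27], order 2 -> [6, 12, 18], order 3 -> [6, 6, 6]
```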
Pretty simple, right?
Source: nickhuang1996.blog.csdn.net, author: 悲恋花丶无心之人. Copyright belongs to the original author; please contact the author before reposting.
Original article: nickhuang1996.blog.csdn.net/article/details/91982925
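[Copyright notice] This article was reposted by a Huawei Cloud community user. If you find content in this community that you suspect is plagiarized, please report it by email with supporting evidence; once verified, the community will remove the infringing content immediately. Report email:
cloudbbs@huaweicloud.com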