Unverified 提交 243836e9 authored 作者: Hsin Fan's avatar Hsin Fan 提交者: GitHub

DOC: Fix docstrings in gradient.py (#415)

上级 49acbc5e
...@@ -196,12 +196,13 @@ def Rop( ...@@ -196,12 +196,13 @@ def Rop(
Returns Returns
------- -------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression such obeying A symbolic expression such obeying
``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``, ``R_op[i] = sum_j (d f[i] / d wrt[j]) eval_point[j]``,
where the indices in that expression are magic multidimensional where the indices in that expression are magic multidimensional
indices that specify both the position within a list and all indices that specify both the position within a list and all
coordinates of the tensor elements. coordinates of the tensor elements.
If `wrt` is a list/tuple, then return a list/tuple with the results. If `f` is a list/tuple, then return a list/tuple with the results.
""" """
if not isinstance(wrt, (list, tuple)): if not isinstance(wrt, (list, tuple)):
...@@ -384,6 +385,7 @@ def Lop( ...@@ -384,6 +385,7 @@ def Lop(
Returns Returns
------- -------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression satisfying A symbolic expression satisfying
``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]`` ``L_op[i] = sum_i (d f[i] / d wrt[j]) eval_point[i]``
where the indices in that expression are magic multidimensional where the indices in that expression are magic multidimensional
...@@ -481,10 +483,10 @@ def grad( ...@@ -481,10 +483,10 @@ def grad(
Returns Returns
------- -------
:class:`~pytensor.graph.basic.Variable` or list/tuple of Variables
A symbolic expression for the gradient of `cost` with respect to each A symbolic expression for the gradient of `cost` with respect to each
of the `wrt` terms. If an element of `wrt` is not differentiable with of the `wrt` terms. If an element of `wrt` is not differentiable with
respect to the output, then a zero variable is returned. respect to the output, then a zero variable is returned.
""" """
t0 = time.perf_counter() t0 = time.perf_counter()
...@@ -701,7 +703,6 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False): ...@@ -701,7 +703,6 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
Parameters Parameters
---------- ----------
wrt : list of variables wrt : list of variables
Gradients are computed with respect to `wrt`. Gradients are computed with respect to `wrt`.
...@@ -876,7 +877,6 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -876,7 +877,6 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
(A variable in consider_constant is not a function of (A variable in consider_constant is not a function of
anything) anything)
""" """
# Validate and format consider_constant # Validate and format consider_constant
...@@ -1035,7 +1035,6 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1035,7 +1035,6 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
------- -------
list of Variables list of Variables
A list of gradients corresponding to `wrt` A list of gradients corresponding to `wrt`
""" """
# build a dict mapping node to the terms node contributes to each of # build a dict mapping node to the terms node contributes to each of
# its inputs' gradients # its inputs' gradients
...@@ -1423,8 +1422,9 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1423,8 +1422,9 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
def _float_zeros_like(x): def _float_zeros_like(x):
"""Like zeros_like, but forces the object to have a """Like zeros_like, but forces the object to have
a floating point dtype""" a floating point dtype
"""
rval = x.zeros_like() rval = x.zeros_like()
...@@ -1436,7 +1436,8 @@ def _float_zeros_like(x): ...@@ -1436,7 +1436,8 @@ def _float_zeros_like(x):
def _float_ones_like(x): def _float_ones_like(x):
"""Like ones_like, but forces the object to have a """Like ones_like, but forces the object to have a
floating point dtype""" floating point dtype
"""
dtype = x.type.dtype dtype = x.type.dtype
if dtype not in pytensor.tensor.type.float_dtypes: if dtype not in pytensor.tensor.type.float_dtypes:
...@@ -1613,7 +1614,6 @@ class numeric_grad: ...@@ -1613,7 +1614,6 @@ class numeric_grad:
Corresponding ndarrays in `g_pt` and `self.gf` must have the same Corresponding ndarrays in `g_pt` and `self.gf` must have the same
shape or ValueError is raised. shape or ValueError is raised.
""" """
if len(g_pt) != len(self.gf): if len(g_pt) != len(self.gf):
raise ValueError("argument has wrong number of elements", len(g_pt)) raise ValueError("argument has wrong number of elements", len(g_pt))
...@@ -1740,7 +1740,6 @@ def verify_grad( ...@@ -1740,7 +1740,6 @@ def verify_grad(
This function does not support multiple outputs. In `tests.scan.test_basic` This function does not support multiple outputs. In `tests.scan.test_basic`
there is an experimental `verify_grad` that covers that case as well by there is an experimental `verify_grad` that covers that case as well by
using random projections. using random projections.
""" """
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.sharedvalue import shared from pytensor.compile.sharedvalue import shared
...@@ -2267,7 +2266,6 @@ def grad_clip(x, lower_bound, upper_bound): ...@@ -2267,7 +2266,6 @@ def grad_clip(x, lower_bound, upper_bound):
----- -----
We register an opt in tensor/opt.py that remove the GradClip. We register an opt in tensor/opt.py that remove the GradClip.
So it have 0 cost in the forward and only do work in the grad. So it have 0 cost in the forward and only do work in the grad.
""" """
return GradClip(lower_bound, upper_bound)(x) return GradClip(lower_bound, upper_bound)(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论