Unverified commit b2d8bc24, authored by: Ricardo Vieira, committed by: GitHub

Do not coerce gradients to TensorVariable (#1685)

* Do not coerce gradients to TensorVariable. This could cause spurious disconnected errors, because the tensorified variable was not in the graph of the cost. * Type-consistent checks --------- Co-authored-by: jessegrabowski <jessegrabowski@gmail.com>
Parent commit: 945e9799
...@@ -494,22 +494,25 @@ def Lop( ...@@ -494,22 +494,25 @@ def Lop(
coordinates of the tensor elements. coordinates of the tensor elements.
If `f` is a list/tuple, then return a list/tuple with the results. If `f` is a list/tuple, then return a list/tuple with the results.
""" """
if not isinstance(eval_points, list | tuple): from pytensor.tensor import as_tensor_variable
_eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
if not isinstance(f, list | tuple): if not isinstance(eval_points, Sequence):
_f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)] eval_points = [eval_points]
else: _eval_points = [
_f = [pytensor.tensor.as_tensor_variable(x) for x in f] x if isinstance(x, Variable) else as_tensor_variable(x) for x in eval_points
]
if not isinstance(f, Sequence):
f = [f]
_f = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in f]
grads = list(_eval_points) grads = list(_eval_points)
if not isinstance(wrt, list | tuple): using_list = isinstance(wrt, list)
_wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)] using_tuple = isinstance(wrt, tuple)
else: if not isinstance(wrt, Sequence):
_wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt] wrt = [wrt]
_wrt = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in wrt]
assert len(_f) == len(grads) assert len(_f) == len(grads)
known = dict(zip(_f, grads, strict=True)) known = dict(zip(_f, grads, strict=True))
...@@ -523,8 +526,6 @@ def Lop( ...@@ -523,8 +526,6 @@ def Lop(
return_disconnected=return_disconnected, return_disconnected=return_disconnected,
) )
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
return as_list_or_tuple(using_list, using_tuple, ret) return as_list_or_tuple(using_list, using_tuple, ret)
......
...@@ -11,6 +11,7 @@ from pytensor.gradient import ( ...@@ -11,6 +11,7 @@ from pytensor.gradient import (
DisconnectedType, DisconnectedType,
GradClip, GradClip,
GradScale, GradScale,
Lop,
NullTypeGradError, NullTypeGradError,
Rop, Rop,
UndefinedGrad, UndefinedGrad,
...@@ -32,6 +33,7 @@ from pytensor.graph.basic import Apply ...@@ -32,6 +33,7 @@ from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.traversal import graph_inputs from pytensor.graph.traversal import graph_inputs
from pytensor.scalar import float64
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
...@@ -1207,3 +1209,13 @@ class TestHessianVectorProduct: ...@@ -1207,3 +1209,13 @@ class TestHessianVectorProduct:
hessp_x_eval, hessp_y_eval = hessp_fn(**test) hessp_x_eval, hessp_y_eval = hessp_fn(**test)
np.testing.assert_allclose(hessp_x_eval, [2, 4, 6]) np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2]) np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
def test_scalar_Lop():
    """Lop on a scalar (non-tensor) variable keeps the scalar type.

    Regression check: gradients are no longer coerced to TensorVariable,
    which previously could raise spurious disconnected-input errors when
    the eval point is a plain scalar variable.
    """
    x_prev = float64("xtm1")
    x_next = x_prev**2
    upstream = float64("dout_dxt")
    lop_result = Lop(x_next, wrt=x_prev, eval_points=upstream)
    # Result must remain a scalar of the same type as the eval point,
    # not be promoted to a tensor type.
    assert lop_result.type == upstream.type
    # d(x**2)/dx * upstream = 2 * x * upstream, evaluated at x=3.0, upstream=1.5.
    assert lop_result.eval({x_prev: 3.0, upstream: 1.5}) == 2 * 3.0 * 1.5
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment