Unverified commit b2d8bc24 authored by Ricardo Vieira, committed by GitHub

Do not coerce gradients to TensorVariable (#1685)

* Do not coerce gradients to TensorVariable — this could cause spurious disconnected errors, because the tensorified variable was not in the graph of the cost
* Type-consistent checks

Co-authored-by: jessegrabowski <jessegrabowski@gmail.com>
Parent: 945e9799
......@@ -494,22 +494,25 @@ def Lop(
coordinates of the tensor elements.
If `f` is a list/tuple, then return a list/tuple with the results.
"""
if not isinstance(eval_points, list | tuple):
_eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
else:
_eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
from pytensor.tensor import as_tensor_variable
if not isinstance(f, list | tuple):
_f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
else:
_f = [pytensor.tensor.as_tensor_variable(x) for x in f]
if not isinstance(eval_points, Sequence):
eval_points = [eval_points]
_eval_points = [
x if isinstance(x, Variable) else as_tensor_variable(x) for x in eval_points
]
if not isinstance(f, Sequence):
f = [f]
_f = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in f]
grads = list(_eval_points)
if not isinstance(wrt, list | tuple):
_wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
else:
_wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
if not isinstance(wrt, Sequence):
wrt = [wrt]
_wrt = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in wrt]
assert len(_f) == len(grads)
known = dict(zip(_f, grads, strict=True))
......@@ -523,8 +526,6 @@ def Lop(
return_disconnected=return_disconnected,
)
using_list = isinstance(wrt, list)
using_tuple = isinstance(wrt, tuple)
return as_list_or_tuple(using_list, using_tuple, ret)
......
......@@ -11,6 +11,7 @@ from pytensor.gradient import (
DisconnectedType,
GradClip,
GradScale,
Lop,
NullTypeGradError,
Rop,
UndefinedGrad,
......@@ -32,6 +33,7 @@ from pytensor.graph.basic import Apply
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.graph.traversal import graph_inputs
from pytensor.scalar import float64
from pytensor.scan.op import Scan
from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
from pytensor.tensor.math import sum as pt_sum
......@@ -1207,3 +1209,13 @@ class TestHessianVectorProduct:
hessp_x_eval, hessp_y_eval = hessp_fn(**test)
np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
def test_scalar_Lop():
    """Lop on scalar (non-tensor) variables must not coerce to TensorVariable.

    The returned gradient keeps the scalar type of the eval point, and the
    pulled-back value equals d(x**2)/dx * eval_point = 2*x * eval_point.
    """
    x_prev = float64("xtm1")
    squared = x_prev**2
    upstream = float64("dout_dxt")

    pullback = Lop(squared, wrt=x_prev, eval_points=upstream)

    # Type must be preserved (no coercion to a tensor type).
    assert pullback.type == upstream.type

    # Numeric check: 2 * x * upstream at x=3.0, upstream=1.5.
    expected = 2 * 3.0 * 1.5
    assert pullback.eval({x_prev: 3.0, upstream: 1.5}) == expected
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment