Unverified commit 0ea61bcc, authored by Jesse Grabowski, committed by GitHub

Use `grad` to compute jacobian when input shape is known to be (1,) (#1454)

* More robust shape check for `grad` fallback in `jacobian` * Update scalar test
Parent commit: ff98ab8f
...@@ -2069,13 +2069,13 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise ...@@ -2069,13 +2069,13 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
else: else:
wrt = [wrt] wrt = [wrt]
if expression.ndim == 0: if all(expression.type.broadcastable):
# expression is just a scalar, use grad # expression is just a scalar, use grad
return as_list_or_tuple( return as_list_or_tuple(
using_list, using_list,
using_tuple, using_tuple,
grad( grad(
expression, expression.squeeze(),
wrt, wrt,
consider_constant=consider_constant, consider_constant=consider_constant,
disconnected_inputs=disconnected_inputs, disconnected_inputs=disconnected_inputs,
......
...@@ -30,6 +30,7 @@ from pytensor.gradient import ( ...@@ -30,6 +30,7 @@ from pytensor.gradient import (
from pytensor.graph.basic import Apply, graph_inputs from pytensor.graph.basic import Apply, graph_inputs
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.scan.op import Scan
from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.random import RandomStream from pytensor.tensor.random import RandomStream
...@@ -1036,6 +1037,17 @@ def test_jacobian_scalar(): ...@@ -1036,6 +1037,17 @@ def test_jacobian_scalar():
vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
assert np.allclose(f(vx), 2) assert np.allclose(f(vx), 2)
# test when input is a shape (1,) vector -- should still be treated as a scalar
Jx = jacobian(y[None], x)
f = pytensor.function([x], Jx)
# Ensure we hit the scalar grad case (doesn't use scan)
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Scan) for node in nodes)
vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
assert np.allclose(f(vx), 2)
# test when the jacobian is called with a tuple as wrt # test when the jacobian is called with a tuple as wrt
Jx = jacobian(y, (x,)) Jx = jacobian(y, (x,))
assert isinstance(Jx, tuple) assert isinstance(Jx, tuple)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论