Unverified commit 0ea61bcc, authored by Jesse Grabowski, committed via GitHub

Use `grad` to compute jacobian when input shape is known to be (1,) (#1454)

* More robust shape check for `grad` fallback in `jacobian` * Update scalar test
Parent commit: ff98ab8f
......@@ -2069,13 +2069,13 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
else:
wrt = [wrt]
if expression.ndim == 0:
if all(expression.type.broadcastable):
# expression is just a scalar, use grad
return as_list_or_tuple(
using_list,
using_tuple,
grad(
expression,
expression.squeeze(),
wrt,
consider_constant=consider_constant,
disconnected_inputs=disconnected_inputs,
......
......@@ -30,6 +30,7 @@ from pytensor.gradient import (
from pytensor.graph.basic import Apply, graph_inputs
from pytensor.graph.null_type import NullType
from pytensor.graph.op import Op
from pytensor.scan.op import Scan
from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.random import RandomStream
......@@ -1036,6 +1037,17 @@ def test_jacobian_scalar():
vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
assert np.allclose(f(vx), 2)
# test when input is a shape (1,) vector -- should still be treated as a scalar
Jx = jacobian(y[None], x)
f = pytensor.function([x], Jx)
# Ensure we hit the scalar grad case (doesn't use scan)
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Scan) for node in nodes)
vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
assert np.allclose(f(vx), 2)
# test when the jacobian is called with a tuple as wrt
Jx = jacobian(y, (x,))
assert isinstance(Jx, tuple)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment