提交 90c6f980 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Switch to new Scan API in jacobian

上级 3cec6432
......@@ -2104,16 +2104,13 @@ def jacobian(
idx, expr, *wrt = args
return grad(expr[idx], wrt, **grad_kwargs)
jacobian_matrices, updates = pytensor.scan(
jacobian_matrices = pytensor.scan(
inner_function,
sequences=pytensor.tensor.arange(expression.size),
non_sequences=[expression.ravel(), *wrt],
return_updates=False,
return_list=True,
)
if updates:
raise ValueError(
"The scan used to build the jacobian matrices returned a list of updates"
)
if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
# There was some raveling or squeezing done prior to getting the jacobians
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论