提交 90c6f980 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Switch to new Scan API in jacobian

上级 3cec6432
...@@ -2104,16 +2104,13 @@ def jacobian( ...@@ -2104,16 +2104,13 @@ def jacobian(
idx, expr, *wrt = args idx, expr, *wrt = args
return grad(expr[idx], wrt, **grad_kwargs) return grad(expr[idx], wrt, **grad_kwargs)
jacobian_matrices, updates = pytensor.scan( jacobian_matrices = pytensor.scan(
inner_function, inner_function,
sequences=pytensor.tensor.arange(expression.size), sequences=pytensor.tensor.arange(expression.size),
non_sequences=[expression.ravel(), *wrt], non_sequences=[expression.ravel(), *wrt],
return_updates=False,
return_list=True, return_list=True,
) )
if updates:
raise ValueError(
"The scan used to build the jacobian matrices returned a list of updates"
)
if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim): if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
# There was some raveling or squeezing done prior to getting the jacobians # There was some raveling or squeezing done prior to getting the jacobians
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论