提交 083c21d9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Pass over additional arguments to grad

上级 94f34a8e
......@@ -819,7 +819,11 @@ def hessian(cost, wrt, consider_constant=None, warn_type=False,
"tensor.hessian expects a (list of) 1 dimensional variable"\
"as `wrt`"
expr = grad(cost, input)
hess, _ = scan(lambda i, y, x: grad(y[i], x),
hess, _ = scan(lambda i, y, x: grad(y[i],
x,
consider_constant=consider_constant,
warn_type=warn_type,
disconnected_inputs=disconnected_inputs),
sequences=arange(expr.shape[0]),
non_sequences=[expr, input])
hessians.append(hess)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论