Commit 00c84ed2 authored by Razvan Pascanu

allow scalars as well (case in which we just call grad)

Parent: e3c2a9d4
@@ -731,8 +731,9 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False,
     # Check inputs have the right format
     assert isinstance(expression, TensorVariable), \
         "tensor.jacobian expects a Tensor Variable as `expression`"
-    assert expression.ndim == 1, \
-        "tensor.jacobian expects a 1 dimensional variable as `expression`"
+    assert expression.ndim < 2, \
+        ("tensor.jacobian expects a 1 dimensional variable as "
+         "`expression`. If not use flatten to make it a vector")
     using_list = isinstance(wrt, list)
     using_tuple = isinstance(wrt, tuple)
@@ -742,6 +743,10 @@ def jacobian(expression, wrt, consider_constant=None, warn_type=False,
     else:
         wrt = [wrt]
+    if expression.ndim == 0:
+        # expression is just a scalar, use grad
+        return format_as(using_list, using_tuple, grad(expression, wrt))
     def inner_function(*args):
         idx = args[0]
         expr = args[1]
...
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment