Commit 0e127704 authored by Razvan Pascanu

Fixed some bugs in the L operator function

Parent 95f550fd
@@ -5365,7 +5365,12 @@ def Lop(f, wrt, eval_points, consider_constant=[], warn_type=False,
     if not isinstance(f, TensorVariable):
         raise TypeError('In tensor.Lop(), cost argument should be a TensorVariable.', f)
-    inputs = gof.graph.inputs([cost])
+    if type(eval_points) not in (list, tuple):
+        eval_points = [eval_points]
+    if type(f) not in (list, tuple):
+        f = [f]
+    inputs = gof.graph.inputs(f)
     gmap = gradient.grad_sources_inputs(
             zip(f,eval_points),
             list(inputs) + list(consider_constant),
......