提交 c37e8c25 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

do not assume that by default you are dealing only with tensors

上级 001185db
...@@ -250,21 +250,27 @@ def Rop(f, wrt, eval_points): ...@@ -250,21 +250,27 @@ def Rop(f, wrt, eval_points):
for pack in enumerate(zip(wrt, eval_points)): for pack in enumerate(zip(wrt, eval_points)):
i = pack[0] i = pack[0]
wrt_elem, eval_point = pack[1] wrt_elem, eval_point = pack[1]
if not isinstance(wrt_elem, gof.Variable):
wrt_elem = as_tensor_variable(wrt_elem)
if not isinstance(eval_point, gof.Variable):
eval_point = as_tensor_variable(eval_point)
wrt_elem = as_tensor_variable(wrt_elem) try:
eval_point = as_tensor_variable(eval_point) wrt_dim = len(wrt_elem.type.broadcastable)
eval_dim = len(eval_point.type.broadcastable)
wrt_dim = len(wrt_elem.type.broadcastable)
eval_dim = len(eval_point.type.broadcastable) if wrt_dim != eval_dim:
raise ValueError('Element ' +
if wrt_dim != eval_dim: str(i) +
raise ValueError('Element ' + ' of wrt/eval_point have mismatched ' +
str(i) + 'dimensionality: ' +
' of wrt/eval_point have mismatched ' + str(wrt_dim) +
'dimensionality: ' + ' versus ' +
str(wrt_dim) + str(eval_dim))
' versus ' + except:
str(eval_dim)) # wrt_elem and eval_point can be non-tensor variable which do
# not have broadcastable flags
pass
seen_nodes = {} seen_nodes = {}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论