提交 d09587bc authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Do not wrap in as_tensor_variable unless needed

上级 e9264ec6
...@@ -295,12 +295,20 @@ def Rop(f, wrt, eval_points): ...@@ -295,12 +295,20 @@ def Rop(f, wrt, eval_points):
_traverse(inp.owner) _traverse(inp.owner)
local_eval_points.append( local_eval_points.append(
seen_nodes[inp.owner][inp.owner.outputs.index(inp)]) seen_nodes[inp.owner][inp.owner.outputs.index(inp)])
same_type_eval_points = []
for x, y in zip(inputs, local_eval_points): for x, y in zip(inputs, local_eval_points):
if y is not None: if y is not None:
assert (as_tensor_variable(x).type == if not isinstance(x, gof.Variable):
as_tensor_variable(y).type) x = as_tensor_variable(x)
if not isinstance(y, gof.Variable):
y = as_tensor_variable(y)
y = x.type.filter_variable(y)
assert x.type == y.type
same_type_eval_points.append(y)
else:
same_type_eval_points.append(y)
seen_nodes[node] = op.R_op(node.inputs, local_eval_points) seen_nodes[node] = op.R_op(node.inputs, same_type_eval_points)
return None return None
# Populate the dictionary # Populate the dictionary
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论