提交 613f89b9 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed a nasty bug in MergeOptimizer

上级 a5a5cdc9
......@@ -209,7 +209,7 @@ class MergeOptimizer(Optimizer):
for node in _list_of_nodes(env):
node_cid = (node.op, tuple([symbol_idx[input] for input in node.inputs]))
print 'NODE', node, node_cid
#print 'NODE', node, node_cid
dup = symbol_idx_inv.get(node_cid, None)
success = False
if dup is not None:
......@@ -247,6 +247,8 @@ class MergeOptimizer(Optimizer):
nodes_seen.add(node)
#print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op:
assert node is not candidate
......
......@@ -144,7 +144,7 @@ def local_shape_lift_elemwise(node):
return False
register_canonicalize(local_shape_lift_elemwise)
register_canonicalize(local_shape_lift_elemwise, 'shape_lift')
@gof.local_optimizer([T.shape, None])
......@@ -165,7 +165,7 @@ def local_shape_lift_sum(node):
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
register_canonicalize(local_shape_lift_sum)
register_canonicalize(local_shape_lift_sum, 'shape_lift')
@gof.local_optimizer([T.shape, T.dot])
......@@ -178,7 +178,7 @@ def local_shape_lift_dot(node):
a, b = node.inputs[0].owner.inputs
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
register_canonicalize(local_shape_lift_dot)
register_canonicalize(local_shape_lift_dot, 'shape_lift')
# local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise,
......@@ -233,7 +233,7 @@ def local_fill_lift(node):
return False
register_canonicalize(local_fill_lift)
register_canonicalize(local_fill_lift, 'fill_lift')
##################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论