提交 613f89b9 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed a nasty bug in MergeOptimizer

上级 a5a5cdc9
...@@ -209,7 +209,7 @@ class MergeOptimizer(Optimizer): ...@@ -209,7 +209,7 @@ class MergeOptimizer(Optimizer):
for node in _list_of_nodes(env): for node in _list_of_nodes(env):
node_cid = (node.op, tuple([symbol_idx[input] for input in node.inputs])) node_cid = (node.op, tuple([symbol_idx[input] for input in node.inputs]))
print 'NODE', node, node_cid #print 'NODE', node, node_cid
dup = symbol_idx_inv.get(node_cid, None) dup = symbol_idx_inv.get(node_cid, None)
success = False success = False
if dup is not None: if dup is not None:
...@@ -247,6 +247,8 @@ class MergeOptimizer(Optimizer): ...@@ -247,6 +247,8 @@ class MergeOptimizer(Optimizer):
nodes_seen.add(node) nodes_seen.add(node)
#print 'NODE', node, merge_candidates, node.inputs[0].clients #print 'NODE', node, merge_candidates, node.inputs[0].clients
for candidate in merge_candidates: for candidate in merge_candidates:
if len(node.inputs) != len(candidate.inputs):
continue
inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs)) inputs_match = all(node_in is cand_in for node_in, cand_in in zip(node.inputs, candidate.inputs))
if inputs_match and node.op == candidate.op: if inputs_match and node.op == candidate.op:
assert node is not candidate assert node is not candidate
......
...@@ -144,7 +144,7 @@ def local_shape_lift_elemwise(node): ...@@ -144,7 +144,7 @@ def local_shape_lift_elemwise(node):
return False return False
register_canonicalize(local_shape_lift_elemwise) register_canonicalize(local_shape_lift_elemwise, 'shape_lift')
@gof.local_optimizer([T.shape, None]) @gof.local_optimizer([T.shape, None])
...@@ -165,7 +165,7 @@ def local_shape_lift_sum(node): ...@@ -165,7 +165,7 @@ def local_shape_lift_sum(node):
return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs return T.make_lvector.make_node(*(ish[i] for i in xrange(input.type.ndim) if i not in axis)).outputs
# return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs # return T.vertical_stack.make_node(ish[:axis], ish[axis+1:]).outputs
register_canonicalize(local_shape_lift_sum) register_canonicalize(local_shape_lift_sum, 'shape_lift')
@gof.local_optimizer([T.shape, T.dot]) @gof.local_optimizer([T.shape, T.dot])
...@@ -178,7 +178,7 @@ def local_shape_lift_dot(node): ...@@ -178,7 +178,7 @@ def local_shape_lift_dot(node):
a, b = node.inputs[0].owner.inputs a, b = node.inputs[0].owner.inputs
return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs return T.make_lvector.make_node(T.shape(a)[0], T.shape(b)[1]).outputs
register_canonicalize(local_shape_lift_dot) register_canonicalize(local_shape_lift_dot, 'shape_lift')
# local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise, # local_shape_lift = opt.LocalOptGroup(local_shape_lift_elemwise,
...@@ -233,7 +233,7 @@ def local_fill_lift(node): ...@@ -233,7 +233,7 @@ def local_fill_lift(node):
return False return False
register_canonicalize(local_fill_lift) register_canonicalize(local_fill_lift, 'fill_lift')
################## ##################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论