提交 20530fa2 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the new node on a graph not depend on the old node.

上级 b3f54e3a
......@@ -3753,8 +3753,21 @@ def local_useless_switch(node):
if out.type.broadcastable != node.outputs[0].type.broadcastable:
# We need to copy data to the new dimensions during execution
out = T.alloc(out, *[node.outputs[0].shape[i] for i
in xrange(out.ndim)])
# We should not depend on node.outputs as this would
# make the new node depend on the old one that will
# get optimized again. So this create a cycle.
shps = []
for idx, (b1, b2), in enumerate(zip(out.type.broadcastable,
node.outputs[0].type.broadcastable)):
if b1 == b2:
shps.append(out.shape[idx])
elif not node.inputs[1].type.broadcastable[idx]:
shps.append(node.inputs[1].shape[idx])
else:
shps.append(node.inputs[2].shape[idx])
out = T.alloc(out, *shps)
return False
else:
out = out
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论