提交 8495c63a authored 作者: James Bergstra's avatar James Bergstra

dimshuffle_lift - fixed a potential optimizer error when replacing two dimshuffles in a row

上级 b69b8747
...@@ -180,11 +180,12 @@ def local_dimshuffle_lift(node): ...@@ -180,11 +180,12 @@ def local_dimshuffle_lift(node):
new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order] new_order = [x == 'x' and 'x' or inode.op.new_order[x] for x in op.new_order]
inplace = op.inplace and inode.op.inplace inplace = op.inplace and inode.op.inplace
iinput = inode.inputs[0] iinput = inode.inputs[0]
if new_order == range(len(new_order)): if new_order == range(len(new_order)) and (len(new_order) == iinput.type.ndim):
return [iinput] return [iinput]
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
register_canonicalize(local_dimshuffle_lift) register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift) register_specialize(local_dimshuffle_lift)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论