提交 69e5ee16 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make local_dimshuffle_lift work with subclasses for DimShuffle.

上级 dd91c021
...@@ -410,9 +410,9 @@ def local_dimshuffle_lift(node): ...@@ -410,9 +410,9 @@ def local_dimshuffle_lift(node):
inode = input.owner inode = input.owner
if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1): if inode and isinstance(inode.op, Elemwise) and (len(input.clients) == 1):
# Don't use make_node to have tag.test_value set. # Don't use make_node to have tag.test_value set.
ret = inode.op(*[DimShuffle(inp.type.broadcastable, ret = inode.op(*[op.__class__(inp.type.broadcastable,
op.new_order, op.new_order,
op.inplace)(inp) for inp in op.inplace)(inp) for inp in
inode.inputs], **dict(return_list=True)) inode.inputs], **dict(return_list=True))
return ret return ret
if inode and isinstance(inode.op, DimShuffle): if inode and isinstance(inode.op, DimShuffle):
...@@ -424,8 +424,8 @@ def local_dimshuffle_lift(node): ...@@ -424,8 +424,8 @@ def local_dimshuffle_lift(node):
iinput.type.ndim): iinput.type.ndim):
return [iinput] return [iinput]
else: else:
ret = DimShuffle(iinput.type.broadcastable, new_order, ret = op.__class__(iinput.type.broadcastable, new_order,
inplace)(iinput, **dict(return_list=True)) inplace)(iinput, **dict(return_list=True))
return ret return ret
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论