提交 1a620aad authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add Reshape track to local_dimshuffle_lift

Sso it can match reshape(dimshuffle(...), ...)
上级 48263896
......@@ -559,7 +559,7 @@ def is_dimshuffle_useless(new_order, input):
return is_useless
@gof.local_optimizer([DimShuffle])
@gof.local_optimizer([DimShuffle, Reshape])
def local_dimshuffle_lift(node):
"""
"Lifts" DimShuffle through Elemwise operations and merges
......@@ -573,9 +573,16 @@ def local_dimshuffle_lift(node):
After this transform, clusters of Elemwise operations are
void of DimShuffle operations.
Also removes useless DimShuffle operation inside Reshape:
reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
reshape(col.dimshuffle(0), shp) => reshape(col, shp)
"""
op = node.op
if (isinstance(op, T.Reshape) and
if (isinstance(op, Reshape) and
node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, DimShuffle)):
new_order = node.inputs[0].owner.op.new_order
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论