提交 81ec8a02 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make local_elemwise_fusion_op more flexible.

上级 94012662
...@@ -5258,7 +5258,8 @@ for i in xrange(1,len(p64)): print i, 64[i]-p64[i-1] ...@@ -5258,7 +5258,8 @@ for i in xrange(1,len(p64)): print i, 64[i]-p64[i-1]
# ############### # ###############
# # Loop fusion # # # Loop fusion #
# ############### # ###############
def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024,
maker=None):
""" """
We parametrize it to make it work for Elemwise and GpuElemwise op. We parametrize it to make it work for Elemwise and GpuElemwise op.
...@@ -5277,6 +5278,9 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -5277,6 +5278,9 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
enough that if we hit it, I'm not sure it enough that if we hit it, I'm not sure it
will affect performance. will affect performance.
""" """
if maker is None:
def maker(node, scalar_op):
return OP(scalar_op)
def local_fuse(node): def local_fuse(node):
""" """
As part of specialization, we fuse two consecutive elemwise Ops of the As part of specialization, we fuse two consecutive elemwise Ops of the
...@@ -5458,7 +5462,7 @@ your code will run correctly, but may be slower.""") ...@@ -5458,7 +5462,7 @@ your code will run correctly, but may be slower.""")
# create the new node. # create the new node.
# Do not call make_node to have test_value # Do not call make_node to have test_value
n = OP(C)(*inputs).owner n = maker(node, C)(*inputs).owner
assert len(n.outputs) == 1 assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype assert node.outputs[0].dtype == n.outputs[0].dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论