提交 81ec8a02 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make local_elemwise_fusion_op more flexible.

上级 94012662
......@@ -5258,7 +5258,8 @@ for i in xrange(1,len(p64)): print i, 64[i]-p64[i-1]
# ###############
# # Loop fusion #
# ###############
def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024,
maker=None):
"""
We parametrize it to make it work for Elemwise and GpuElemwise op.
......@@ -5277,6 +5278,9 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
enough that if we hit it, I'm not sure it
will affect performance.
"""
if maker is None:
def maker(node, scalar_op):
return OP(scalar_op)
def local_fuse(node):
"""
As part of specialization, we fuse two consecutive elemwise Ops of the
......@@ -5458,7 +5462,7 @@ your code will run correctly, but may be slower.""")
# create the new node.
# Do not call make_node to have test_value
n = OP(C)(*inputs).owner
n = maker(node, C)(*inputs).owner
assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论