提交 dba48c9d authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Small coding style improvements

上级 a0545039
......@@ -2958,13 +2958,14 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1]
# ###############
# # Loop fusion #
# ###############
def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
"""
We parametrise it to make it work for Elemwise and GpuElemwise op.
We parametrize it to make it work for Elemwise and GpuElemwise op.
:param OP: GpuElemwise or Elemwise class (the one that we want to fuse)
:param max_input_fct: a fct that return the maximum number of input that this elemwise can take(usefull for the GpuElemwise)
:param max_input_fct: a function that returns the maximum number of inputs
that this elemwise can take (useful for GpuElemwise)
"""
def local_fuse(node):
"""
......@@ -2998,7 +2999,8 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
s_inputs = []#inputs of the new scalar op.
s_g=[]#graph of scalar, what will by done in the inner loop.
# There is a hard limit of 256 bytes for the formal argument list to a GPU kernel function.
# There is a hard limit of 256 bytes for the formal argument list to a
# GPU kernel function.
max_nb_input = max_input_fct(node)
#print len(node.inputs),max_nb_input
new_nb_input = len(node.inputs)
......@@ -3008,10 +3010,10 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
catch = False
tmp_input=[]#used to remove duplicate input.
tmp_scalar=[]
if ((new_nb_input+1)<=max_nb_input
and i.owner
and isinstance(i.owner.op, OP)
and len(i.clients)==1):
if ((new_nb_input+1)<=max_nb_input and
i.owner and
isinstance(i.owner.op, OP) and
len(i.clients)==1):
#if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops.
do_fusion=True
try:
......@@ -3054,7 +3056,7 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
#if no inputs have are an elemwise, there is nothing to fuse.
if new_nb_input==len(node.inputs):
# print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
#print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
return False
assert len(s_inputs)==len(inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论