提交 dba48c9d 作者: Olivier Delalleau

Small coding style improvements

上级 a0545039
...@@ -2958,13 +2958,14 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1] ...@@ -2958,13 +2958,14 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1]
# ############### # ###############
# # Loop fusion # # # Loop fusion #
# ############### # ###############
def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024): def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
""" """
We parametrise it to make it work for Elemwise and GpuElemwise op. We parametrize it to make it work for Elemwise and GpuElemwise op.
:param OP: GpuElemwise or Elemwise class (the one that we want to fuse) :param OP: GpuElemwise or Elemwise class (the one that we want to fuse)
:param max_input_fct: a fct that return the maximum number of input that this elemwise can take(usefull for the GpuElemwise) :param max_input_fct: a function that returns the maximum number of inputs
that this elemwise can take (useful for GpuElemwise)
""" """
def local_fuse(node): def local_fuse(node):
""" """
...@@ -2998,7 +2999,8 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024): ...@@ -2998,7 +2999,8 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
s_inputs = []#inputs of the new scalar op. s_inputs = []#inputs of the new scalar op.
s_g=[]#graph of scalar, what will by done in the inner loop. s_g=[]#graph of scalar, what will by done in the inner loop.
# There is a hard limit of 256 bytes for the formal argument list to a GPU kernel function. # There is a hard limit of 256 bytes for the formal argument list to a
# GPU kernel function.
max_nb_input = max_input_fct(node) max_nb_input = max_input_fct(node)
#print len(node.inputs),max_nb_input #print len(node.inputs),max_nb_input
new_nb_input = len(node.inputs) new_nb_input = len(node.inputs)
...@@ -3008,10 +3010,10 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024): ...@@ -3008,10 +3010,10 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
catch = False catch = False
tmp_input=[]#used to remove duplicate input. tmp_input=[]#used to remove duplicate input.
tmp_scalar=[] tmp_scalar=[]
if ((new_nb_input+1)<=max_nb_input if ((new_nb_input+1)<=max_nb_input and
and i.owner i.owner and
and isinstance(i.owner.op, OP) isinstance(i.owner.op, OP) and
and len(i.clients)==1): len(i.clients)==1):
#if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops. #if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops.
do_fusion=True do_fusion=True
try: try:
...@@ -3054,7 +3056,7 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024): ...@@ -3054,7 +3056,7 @@ def local_elemwise_fusion_op(OP, max_input_fct = lambda node: 1024):
#if no inputs have are an elemwise, there is nothing to fuse. #if no inputs have are an elemwise, there is nothing to fuse.
if new_nb_input==len(node.inputs): if new_nb_input==len(node.inputs):
# print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse." #print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
return False return False
assert len(s_inputs)==len(inputs) assert len(s_inputs)==len(inputs)
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论