提交 342546b2 作者: Olivier Delalleau

A few clarifications in the comments

上级 d5faa5be
...@@ -3008,8 +3008,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3008,8 +3008,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
if not isinstance(node.op, OP): if not isinstance(node.op, OP):
return False return False
inputs=[]#inputs of the new Elemwise op. inputs=[]#inputs of the new Elemwise op.
s_inputs = []#inputs of the new scalar op. s_inputs = []#inputs of the new scalar op used by the Composite.
s_g=[]#graph of scalar, what will by done in the inner loop. s_g=[] # Inputs of the new scalar op that represents the current node.
# There is a hard limit of 256 bytes for the formal argument list to a # There is a hard limit of 256 bytes for the formal argument list to a
# GPU kernel function. # GPU kernel function.
...@@ -3025,7 +3025,10 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3025,7 +3025,10 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
for i in node.inputs: for i in node.inputs:
do_fusion = False do_fusion = False
catch = False catch = False
tmp_input=[]#used to remove duplicate input. # Will store inputs of the fused node that are not currently inputs
# of the node we want to create (to avoid duplicating inputs).
tmp_input=[]
# Same as tmp_input, but for scalars.
tmp_scalar=[] tmp_scalar=[]
# We should not check the number of inputs here # We should not check the number of inputs here
...@@ -3061,11 +3064,16 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3061,11 +3064,16 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
_logger.info("%s does not implement the c_code function. As well as being potentially slow, this disables loop fusion of this op." % str(i.owner.op.scalar_op)) _logger.info("%s does not implement the c_code function. As well as being potentially slow, this disables loop fusion of this op." % str(i.owner.op.scalar_op))
do_fusion=False do_fusion=False
# Compute the number of inputs in case we fuse this input # Compute the number of inputs in case we fuse this input.
# The -1 is that replace an existing input with others. # We subtract 1 because we replace the existing input with the new
# inputs from `tmp_input`.
new_nb_input_ = new_nb_input + len(tmp_input) - 1 new_nb_input_ = new_nb_input + len(tmp_input) - 1
# If the new input is already an input of the current node, it was already counted.
# When new_nb_input was initialized to len(node.inputs) # If the new input is already an input of the current node, it was
# already counted when `new_nb_input` was initialized to
# len(node.inputs).
# This can happen when a variable is used both by the Elemwise to
# fuse and the current node.
for x in tmp_input: for x in tmp_input:
if x in node.inputs: if x in node.inputs:
new_nb_input_ -= 1 new_nb_input_ -= 1
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论