提交 74ee82b5 作者: Brandon T. Willard

Improve variable names and comments/docstring in local_elemwise_fusion_op

上级 9df70fda
...@@ -7609,43 +7609,52 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1] ...@@ -7609,43 +7609,52 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1]
""" """
# ############### def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
# # Loop fusion # """Create a recursive function that fuses `Elemwise` `Op`s.
# ###############
def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None): The basic idea is that we loop through an `Elemwise` node's inputs, find
""" other `Elemwise` nodes, determine the scalars input types for all of the
We parametrize it to make it work for Elemwise and GpuElemwise op. `Elemwise` `Op`s, construct a new scalar `Op` using the scalar input types
and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
new "fused" `Elemwise`.
It's parameterized in order to work for `Elemwise` and `GpuElemwise` `Op`s.
Parameters Parameters
---------- ----------
OP op_class : type
GpuElemwise or Elemwise class (the one that we want to fuse) `GpuElemwise` or `Elemwise` class (the one that we want to fuse)
max_input_fct max_input_fct : callable
A function that returns the maximum number of inputs A function that returns the maximum number of inputs that this `Elemwise`
that this elemwise can take (useful for GpuElemwise). can take (useful for `GpuElemwise`). The GPU kernel currently has a
GPU kernel currently has a limit of 256 bytes for limit of 256 bytes for the size of all parameters passed to it. As
the size of all parameters passed to it. As currently currently we pass a lot of information only by parameter, we must limit how
we pass many information only by parameter, we must many `Op`s we fuse together to avoid busting that 256 limit.
limit how many ops we fuse together to avoid busting
that 256 limit.
On the CPU we limit to 32 input variables On the CPU we limit to 32 input variables since that is the maximum
since that is the maximum numpy support. NumPy support.
maker: callable
A function with the signature `(node, *args)` that constructs an
`op_class` instance (e.g. `op_class(*args)`).
""" """
if maker is None: if maker is None:
def maker(node, scalar_op): def maker(node, scalar_op):
return OP(scalar_op) return op_class(scalar_op)
def local_fuse(node): def local_fuse(node):
""" """Fuse `Elemwise` `Op`s in a node.
As part of specialization, we fuse two consecutive elemwise Ops of the
As part of specialization, we fuse two consecutive elemwise `Op`s of the
same shape. same shape.
For mixed dtype, we let the Composite op do the cast. It lets the C For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
compiler do the cast. compiler do the cast.
The number of dimensions is validated at call time by theano itself.
The number of dimensions is validated at call time by Theano itself.
""" """
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!! # META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
...@@ -7672,12 +7681,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None): ...@@ -7672,12 +7681,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# worthwhile if the summation axis doesn't line up with a # worthwhile if the summation axis doesn't line up with a
# contiguous dimension) # contiguous dimension)
if type(node.op) is not OP: if type(node.op) is not op_class:
return False return False
if len(node.outputs) > 1: if len(node.outputs) > 1:
# We don't support the fusion for node with multiple outputs. # We don't support fusion for nodes with multiple outputs.
return return
inputs = [] # inputs of the new Elemwise op. inputs = [] # inputs of the new Elemwise op.
s_inputs = [] # inputs of the new scalar op used by the Composite. s_inputs = [] # inputs of the new scalar op used by the Composite.
# Inputs of the new scalar op that represents the current node. # Inputs of the new scalar op that represents the current node.
...@@ -7710,7 +7720,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None): ...@@ -7710,7 +7720,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# we still want to fusion. So we take the set. # we still want to fusion. So we take the set.
if ( if (
i.owner i.owner
and isinstance(i.owner.op, OP) and isinstance(i.owner.op, op_class)
and len(set([n for n, idx in i.clients])) == 1 and len(set([n for n, idx in i.clients])) == 1
and and
# Do not merge elemwise that don't have the same # Do not merge elemwise that don't have the same
...@@ -7736,9 +7746,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None): ...@@ -7736,9 +7746,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
tmp.tag.test_value = tv tmp.tag.test_value = tv
except AttributeError: except AttributeError:
pass pass
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1]) tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True) s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
# if the scalar_op don't have a c implementation, # if the scalar_op don't have a c implementation,
...@@ -7786,8 +7798,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None): ...@@ -7786,8 +7798,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
s_inputs.extend(tmp_scalar) s_inputs.extend(tmp_scalar)
s_g.extend(s_op) s_g.extend(s_op)
else: else:
# We must support the case where the same variable appear many # We must support the case where the same variable appears many
# time in the inputs # times within the inputs
if inputs.count(i) == node.inputs.count(i): if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)] s = s_inputs[inputs.index(i)]
else: else:
...@@ -7834,15 +7846,16 @@ your code will run correctly, but may be slower.""" ...@@ -7834,15 +7846,16 @@ your code will run correctly, but may be slower."""
) )
# create the composite op. # create the composite op.
C = scalar.Composite(s_inputs, s_new_out) composite_op = scalar.Composite(s_inputs, s_new_out)
# create the new node. # create the new node.
# Do not call make_node to have test_value # Do not call make_node to have test_value
n = maker(node, C)(*inputs).owner new_node = maker(node, composite_op)(*inputs).owner
assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype assert len(new_node.outputs) == 1
assert node.outputs[0].dtype == new_node.outputs[0].dtype
if len(n.inputs) > max_nb_input: if len(new_node.inputs) > max_nb_input:
_logger.warning( _logger.warning(
"loop fusion failed because Op would exceed" " kernel argument limit." "loop fusion failed because Op would exceed" " kernel argument limit."
) )
...@@ -7851,16 +7864,15 @@ your code will run correctly, but may be slower.""" ...@@ -7851,16 +7864,15 @@ your code will run correctly, but may be slower."""
# we fuse as many that we can at the same time to make debug mode faster # we fuse as many that we can at the same time to make debug mode faster
# debug mode will be faster as it won't test all intermediate step. # debug mode will be faster as it won't test all intermediate step.
while True: while True:
ret = local_fuse(n) ret = local_fuse(new_node)
if ret is not False and ret is not None: if ret is not False and ret is not None:
# print n,ret assert len(ret) == len(new_node.outputs)
assert len(ret) == len(n.outputs)
assert len(ret) == 1 assert len(ret) == 1
n = ret[0].owner new_node = ret[0].owner
else: else:
break break
return n.outputs return new_node.outputs
return local_fuse return local_fuse
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论