提交 74ee82b5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Improve variable names and comments/docstring in local_elemwise_fusion_op

上级 9df70fda
......@@ -7609,43 +7609,52 @@ for i in range(1,len(p64)): print i, 64[i]-p64[i-1]
"""
# ###############
# # Loop fusion #
# ###############
def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
"""
We parametrize it to make it work for Elemwise and GpuElemwise op.
def local_elemwise_fusion_op(op_class, max_input_fct=lambda node: 32, maker=None):
"""Create a recursive function that fuses `Elemwise` `Op`s.
The basic idea is that we loop through an `Elemwise` node's inputs, find
other `Elemwise` nodes, determine the scalars input types for all of the
`Elemwise` `Op`s, construct a new scalar `Op` using the scalar input types
and each `Elemwise`'s scalar `Op`, and use the composite scalar `Op` in a
new "fused" `Elemwise`.
It's parameterized in order to work for `Elemwise` and `GpuElemwise` `Op`s.
Parameters
----------
OP
GpuElemwise or Elemwise class (the one that we want to fuse)
max_input_fct
A function that returns the maximum number of inputs
that this elemwise can take (useful for GpuElemwise).
GPU kernel currently has a limit of 256 bytes for
the size of all parameters passed to it. As currently
we pass many information only by parameter, we must
limit how many ops we fuse together to avoid busting
that 256 limit.
op_class : type
`GpuElemwise` or `Elemwise` class (the one that we want to fuse)
max_input_fct : callable
A function that returns the maximum number of inputs that this `Elemwise`
can take (useful for `GpuElemwise`). The GPU kernel currently has a
limit of 256 bytes for the size of all parameters passed to it. As
currently we pass a lot of information only by parameter, we must limit how
many `Op`s we fuse together to avoid busting that 256 limit.
On the CPU we limit to 32 input variables
since that is the maximum numpy support.
On the CPU we limit to 32 input variables since that is the maximum
NumPy support.
maker: callable
A function with the signature `(node, *args)` that constructs an
`op_class` instance (e.g. `op_class(*args)`).
"""
if maker is None:
def maker(node, scalar_op):
return OP(scalar_op)
return op_class(scalar_op)
def local_fuse(node):
"""
As part of specialization, we fuse two consecutive elemwise Ops of the
"""Fuse `Elemwise` `Op`s in a node.
As part of specialization, we fuse two consecutive elemwise `Op`s of the
same shape.
For mixed dtype, we let the Composite op do the cast. It lets the C
For mixed dtype, we let the `Composite` `Op` do the cast. It lets the C
compiler do the cast.
The number of dimensions is validated at call time by theano itself.
The number of dimensions is validated at call time by Theano itself.
"""
# META TODO: PUT THESE THINGS IN TRAC, NOT TODO NOTES!!
......@@ -7672,12 +7681,13 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# worthwhile if the summation axis doesn't line up with a
# contiguous dimension)
if type(node.op) is not OP:
if type(node.op) is not op_class:
return False
if len(node.outputs) > 1:
# We don't support the fusion for node with multiple outputs.
# We don't support fusion for nodes with multiple outputs.
return
inputs = [] # inputs of the new Elemwise op.
s_inputs = [] # inputs of the new scalar op used by the Composite.
# Inputs of the new scalar op that represents the current node.
......@@ -7710,7 +7720,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
# we still want to fusion. So we take the set.
if (
i.owner
and isinstance(i.owner.op, OP)
and isinstance(i.owner.op, op_class)
and len(set([n for n, idx in i.clients])) == 1
and
# Do not merge elemwise that don't have the same
......@@ -7736,9 +7746,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
tmp.tag.test_value = tv
except AttributeError:
pass
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input, return_list=True)
# if the scalar_op don't have a c implementation,
......@@ -7786,8 +7798,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, maker=None):
s_inputs.extend(tmp_scalar)
s_g.extend(s_op)
else:
# We must support the case where the same variable appear many
# time in the inputs
# We must support the case where the same variable appears many
# times within the inputs
if inputs.count(i) == node.inputs.count(i):
s = s_inputs[inputs.index(i)]
else:
......@@ -7834,15 +7846,16 @@ your code will run correctly, but may be slower."""
)
# create the composite op.
C = scalar.Composite(s_inputs, s_new_out)
composite_op = scalar.Composite(s_inputs, s_new_out)
# create the new node.
# Do not call make_node to have test_value
n = maker(node, C)(*inputs).owner
assert len(n.outputs) == 1
assert node.outputs[0].dtype == n.outputs[0].dtype
new_node = maker(node, composite_op)(*inputs).owner
assert len(new_node.outputs) == 1
assert node.outputs[0].dtype == new_node.outputs[0].dtype
if len(n.inputs) > max_nb_input:
if len(new_node.inputs) > max_nb_input:
_logger.warning(
"loop fusion failed because Op would exceed" " kernel argument limit."
)
......@@ -7851,16 +7864,15 @@ your code will run correctly, but may be slower."""
# we fuse as many that we can at the same time to make debug mode faster
# debug mode will be faster as it won't test all intermediate step.
while True:
ret = local_fuse(n)
ret = local_fuse(new_node)
if ret is not False and ret is not None:
# print n,ret
assert len(ret) == len(n.outputs)
assert len(ret) == len(new_node.outputs)
assert len(ret) == 1
n = ret[0].owner
new_node = ret[0].owner
else:
break
return n.outputs
return new_node.outputs
return local_fuse
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论