提交 e4abbe90 authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Make inplace elemwise opt support multiple output

上级 1a6e03c6
......@@ -296,9 +296,6 @@ def inplace_elemwise_optimizer_op(OP):
# gpuarray GpuElemwise inherit from Elemwise
if not type(op) == OP:
continue
# TODO support this case
if len(node.outputs) > 1:
return
baseline = op.inplace_pattern
protected_inputs = [
......@@ -335,8 +332,8 @@ def inplace_elemwise_optimizer_op(OP):
if hasattr(op.scalar_op, "make_new_inplace"):
new_scal = op.scalar_op.make_new_inplace(
scalar.transfer_type(
*[inplace_pattern.get(i, None)
for i in xrange(len(node.outputs))]))
*[inplace_pattern.get(i, o.dtype)
for i, o in enumerate(node.outputs)]))
else:
new_scal = op.scalar_op.__class__(
scalar.transfer_type(
......@@ -5871,15 +5868,17 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
tmp_s_input.append(tmp)
tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input)
s_op = i.owner.op.scalar_op(*tmp_s_input,
return_list=True)
# if the scalar_op don't have a c implementation,
# we skip its fusion to allow the fusion of the
# other ops.
i.owner.op.scalar_op.c_code(s_op.owner,
i.owner.op.scalar_op.c_code(s_op[0].owner,
"test_presence_of_c_code",
["x" for x in i.owner.inputs],
"z", {})
["z" for z in i.owner.outputs],
{})
except MethodNotDefined:
catch = True
except NotImplementedError:
......@@ -5910,7 +5909,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
new_nb_input = new_nb_input_
inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar)
s_g.append(s_op)
s_g.extend(s_op)
else:
# We must support the case where the same variable appear many
# time in the inputs
......@@ -5938,25 +5937,26 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
fusion optimization. We skip this optimization. You can ignore this message,
your code will run correctly, but may be slower.""")
s_new_out = node.op.scalar_op(*s_g)
s_new_out = node.op.scalar_op(*s_g, return_list=True)
try:
s_new_out.owner.op.c_code(s_new_out.owner,
s_new_out[0].owner.op.c_code(s_new_out[0].owner,
"test_presence_of_c_code",
["x" for x in s_g],
"z", {})
["z" for x in s_new_out], {})
except MethodNotDefined:
_logger.info(("%s does not implement the c_code function."
" As well as being potentially slow, this disables "
"loop fusion of this op.") % str(s_new_out.owner.op))
"loop fusion of this op.") % str(
s_new_out[0].owner.op))
return False
except NotImplementedError:
_logger.info(("%s does not implement the c_code function. As well"
" as being potentially slow, this disables loop"
" fusion of this op.") % str(s_new_out.owner.op))
" fusion of this op.") % str(s_new_out[0].owner.op))
return False
# create the composite op.
C = scalar.Composite(s_inputs, [s_new_out])
C = scalar.Composite(s_inputs, s_new_out)
# create the new node.
# Do not call make_node to have test_value
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论