提交 e4abbe90 作者: Frederic 提交者: Arnaud Bergeron

Make inplace elemwise opt support multiple outputs

上级 1a6e03c6
...@@ -296,9 +296,6 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -296,9 +296,6 @@ def inplace_elemwise_optimizer_op(OP):
# gpuarray GpuElemwise inherit from Elemwise # gpuarray GpuElemwise inherit from Elemwise
if not type(op) == OP: if not type(op) == OP:
continue continue
# TODO support this case
if len(node.outputs) > 1:
return
baseline = op.inplace_pattern baseline = op.inplace_pattern
protected_inputs = [ protected_inputs = [
...@@ -335,8 +332,8 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -335,8 +332,8 @@ def inplace_elemwise_optimizer_op(OP):
if hasattr(op.scalar_op, "make_new_inplace"): if hasattr(op.scalar_op, "make_new_inplace"):
new_scal = op.scalar_op.make_new_inplace( new_scal = op.scalar_op.make_new_inplace(
scalar.transfer_type( scalar.transfer_type(
*[inplace_pattern.get(i, None) *[inplace_pattern.get(i, o.dtype)
for i in xrange(len(node.outputs))])) for i, o in enumerate(node.outputs)]))
else: else:
new_scal = op.scalar_op.__class__( new_scal = op.scalar_op.__class__(
scalar.transfer_type( scalar.transfer_type(
...@@ -5871,15 +5868,17 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5871,15 +5868,17 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(tmp_s_input[-1]) tmp_scalar.append(tmp_s_input[-1])
s_op = i.owner.op.scalar_op(*tmp_s_input) s_op = i.owner.op.scalar_op(*tmp_s_input,
return_list=True)
# if the scalar_op don't have a c implementation, # if the scalar_op don't have a c implementation,
# we skip its fusion to allow the fusion of the # we skip its fusion to allow the fusion of the
# other ops. # other ops.
i.owner.op.scalar_op.c_code(s_op.owner, i.owner.op.scalar_op.c_code(s_op[0].owner,
"test_presence_of_c_code", "test_presence_of_c_code",
["x" for x in i.owner.inputs], ["x" for x in i.owner.inputs],
"z", {}) ["z" for z in i.owner.outputs],
{})
except MethodNotDefined: except MethodNotDefined:
catch = True catch = True
except NotImplementedError: except NotImplementedError:
...@@ -5910,7 +5909,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5910,7 +5909,7 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
new_nb_input = new_nb_input_ new_nb_input = new_nb_input_
inputs.extend(tmp_input) inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar) s_inputs.extend(tmp_scalar)
s_g.append(s_op) s_g.extend(s_op)
else: else:
# We must support the case where the same variable appear many # We must support the case where the same variable appear many
# time in the inputs # time in the inputs
...@@ -5938,25 +5937,26 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32, ...@@ -5938,25 +5937,26 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 32,
fusion optimization. We skip this optimization. You can ignore this message, fusion optimization. We skip this optimization. You can ignore this message,
your code will run correctly, but may be slower.""") your code will run correctly, but may be slower.""")
s_new_out = node.op.scalar_op(*s_g) s_new_out = node.op.scalar_op(*s_g, return_list=True)
try: try:
s_new_out.owner.op.c_code(s_new_out.owner, s_new_out[0].owner.op.c_code(s_new_out[0].owner,
"test_presence_of_c_code", "test_presence_of_c_code",
["x" for x in s_g], ["x" for x in s_g],
"z", {}) ["z" for x in s_new_out], {})
except MethodNotDefined: except MethodNotDefined:
_logger.info(("%s does not implement the c_code function." _logger.info(("%s does not implement the c_code function."
" As well as being potentially slow, this disables " " As well as being potentially slow, this disables "
"loop fusion of this op.") % str(s_new_out.owner.op)) "loop fusion of this op.") % str(
s_new_out[0].owner.op))
return False return False
except NotImplementedError: except NotImplementedError:
_logger.info(("%s does not implement the c_code function. As well" _logger.info(("%s does not implement the c_code function. As well"
" as being potentially slow, this disables loop" " as being potentially slow, this disables loop"
" fusion of this op.") % str(s_new_out.owner.op)) " fusion of this op.") % str(s_new_out[0].owner.op))
return False return False
# create the composite op. # create the composite op.
C = scalar.Composite(s_inputs, [s_new_out]) C = scalar.Composite(s_inputs, s_new_out)
# create the new node. # create the new node.
# Do not call make_node to have test_value # Do not call make_node to have test_value
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论