提交 8897fcf1 authored 作者: --global's avatar --global

Avoid inserting useless nodes in the graph during optimization

上级 4bfa10de
...@@ -3906,12 +3906,17 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -3906,12 +3906,17 @@ def local_sum_prod_mul_by_scalar(node):
# If node.op is a T.elemwise.Prod, then the scalars need to be # If node.op is a T.elemwise.Prod, then the scalars need to be
# raised to the power of the number of elements in the input # raised to the power of the number of elements in the input
# to the Prod # to the Prod
if isinstance(node.op, T.elemwise.Prod): if (isinstance(node.op, T.elemwise.Prod) and
new_op_input_nb_elements != 1):
scalars = [s ** new_op_input_nb_elements for s in scalars] scalars = [s ** new_op_input_nb_elements for s in scalars]
# Scale the output of the op by the scalars and return as # Scale the output of the op by the scalars and return as
# replacement for the original output # replacement for the original output
mul_inputs = scalars + [new_op_output] mul_inputs = scalars
if new_op_input_nb_elements != 1:
mul_inputs.append(new_op_output)
return [T.mul(*mul_inputs)] return [T.mul(*mul_inputs)]
if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg: if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论