提交 7f80efca authored 作者: --global's avatar --global

Update local_sum_prod_mul_by_scalar to work correctly with products

上级 fbef8859
......@@ -3886,24 +3886,34 @@ def local_sum_prod_mul_by_scalar(node):
scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if scalars:
if len(scalars) > 1:
if len(non_scalars) > 1:
return [T.mul(T.mul(*scalars),
node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1:
return [T.mul(T.mul(*scalars),
node.op(non_scalars[0]))]
else:
return [T.mul(*scalars)]
else:
if len(non_scalars) > 1:
return [T.mul(scalars[0],
node.op(T.mul(*non_scalars)))]
if len(scalars) == 0:
# Nothing to optimize here
return
# Perform the op only on the non-scalar inputs, if applicable
if len(non_scalars) == 0:
new_op_input_nb_elements = 1
new_op_output = 1
elif len(non_scalars) == 1:
return [T.mul(scalars[0], node.op(non_scalars[0]))]
new_op_input_nb_elements = T.prod(non_scalars[0].shape)
new_op_output = node.op(non_scalars[0])
else:
return [scalars[0]]
new_op_input = T.mul(*non_scalars)
new_op_input_nb_elements = T.prod(new_op_input.shape)
new_op_output = node.op(new_op_input)
# If node.op is a T.elemwise.Prod, then the scalars need to be
# raised to the power of the number of elements in the input
# to the Prod
if isinstance(node.op, T.elemwise.Prod):
scalars = [s ** new_op_input_nb_elements for s in scalars]
# Scale the output of the op by the scalars and return as
# replacement for the original output
mul_inputs = scalars + [new_op_output]
return [T.mul(*mul_inputs)]
if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(node_inps.owner.inputs[0]))]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论