提交 7f80efca 作者: --global

Update local_sum_prod_mul_by_scalar to work correctly with products

上级 fbef8859
...@@ -3886,24 +3886,34 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -3886,24 +3886,34 @@ def local_sum_prod_mul_by_scalar(node):
scalars = [t.dimshuffle() for t in terms if scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)] numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)] non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if scalars:
if len(scalars) > 1: if len(scalars) == 0:
if len(non_scalars) > 1: # Nothing to optimize here
return [T.mul(T.mul(*scalars), return
node.op(T.mul(*non_scalars)))]
elif len(non_scalars) == 1: # Perform the op only on the non-scalar inputs, if applicable
return [T.mul(T.mul(*scalars), if len(non_scalars) == 0:
node.op(non_scalars[0]))] new_op_input_nb_elements = 1
else: new_op_output = 1
return [T.mul(*scalars)] elif len(non_scalars) == 1:
else: new_op_input_nb_elements = T.prod(non_scalars[0].shape)
if len(non_scalars) > 1: new_op_output = node.op(non_scalars[0])
return [T.mul(scalars[0], else:
node.op(T.mul(*non_scalars)))] new_op_input = T.mul(*non_scalars)
elif len(non_scalars) == 1: new_op_input_nb_elements = T.prod(new_op_input.shape)
return [T.mul(scalars[0], node.op(non_scalars[0]))] new_op_output = node.op(new_op_input)
else:
return [scalars[0]] # If node.op is a T.elemwise.Prod, then the scalars need to be
# raised to the power of the number of elements in the input
# to the Prod
if isinstance(node.op, T.elemwise.Prod):
scalars = [s ** new_op_input_nb_elements for s in scalars]
# Scale the output of the op by the scalars and return as
# replacement for the original output
mul_inputs = scalars + [new_op_output]
return [T.mul(*mul_inputs)]
if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg: if isinstance(node.op, T.Sum) and node_inps.owner and node_inps.owner.op == T.neg:
return [T.neg(node.op(node_inps.owner.inputs[0]))] return [T.neg(node.op(node_inps.owner.inputs[0]))]
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论