提交 7b7b6618 authored 作者: Frederic's avatar Frederic

Don't introduce useless prod op

上级 219d33bf
......@@ -4081,28 +4081,29 @@ def local_sum_prod_mul_by_scalar(node):
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
node_inps, = node.inputs
if node_inps.owner and node_inps.owner.op == T.mul:
terms = node_inps.owner.inputs
scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if len(scalars) == 0:
# Nothing to optimize here
return
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
# Perform the op only on the non-scalar inputs, if applicable
if len(non_scalars) == 0:
new_op_input_nb_elements = 1
new_op_output = 1
elif len(non_scalars) == 1:
new_op_input_nb_elements = T.prod(non_scalars[0].shape)
new_op_input_nb_elements = non_scalars[0].size
new_op_output = node.op(non_scalars[0])
else:
new_op_input = T.mul(*non_scalars)
new_op_input_nb_elements = T.prod(new_op_input.shape)
new_op_input_nb_elements = new_op_input.size
new_op_output = node.op(new_op_input)
# If node.op is a T.elemwise.Prod, then the scalars need to be
......
......@@ -4568,7 +4568,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 2
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod,
(s1_val * v_val).prod(), 2)
(s1_val * v_val).prod(), 1)
# Case 3
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
......@@ -4581,7 +4581,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 5
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.elemwise.Prod, (s1_val * s2_val * v_val).prod(),
2)
1)
# Case 6
test_reduction_opt([vect, mat, scalar1, scalar2],
......
......@@ -280,7 +280,8 @@ class _tensor_py_operators:
shape = property(lambda self: theano.tensor.basic.shape(self))
size = property(lambda self: theano.tensor.basic.prod(self.shape))
size = property(lambda self: self.shape[0] if self.ndim == 1 else
theano.tensor.basic.prod(self.shape))
# We can't implement __len__ to provide a better error message.
def any(self, axis=None, keepdims=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论