提交 7b7b6618 authored 作者: Frederic's avatar Frederic

Don't introduce useless prod op

上级 219d33bf
...@@ -4081,28 +4081,29 @@ def local_sum_prod_mul_by_scalar(node): ...@@ -4081,28 +4081,29 @@ def local_sum_prod_mul_by_scalar(node):
""" """
# TODO: if the the thing inside the Sum is a division, # TODO: if the the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod): if isinstance(node.op, (T.Sum, T.elemwise.Prod)):
node_inps, = node.inputs node_inps, = node.inputs
if node_inps.owner and node_inps.owner.op == T.mul: if node_inps.owner and node_inps.owner.op == T.mul:
terms = node_inps.owner.inputs terms = node_inps.owner.inputs
scalars = [t.dimshuffle() for t in terms if scalars = [t.dimshuffle() for t in terms if
numpy.all(t.type.broadcastable)] numpy.all(t.type.broadcastable)]
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
if len(scalars) == 0: if len(scalars) == 0:
# Nothing to optimize here # Nothing to optimize here
return return
non_scalars = [t for t in terms if not numpy.all(t.broadcastable)]
# Perform the op only on the non-scalar inputs, if applicable # Perform the op only on the non-scalar inputs, if applicable
if len(non_scalars) == 0: if len(non_scalars) == 0:
new_op_input_nb_elements = 1 new_op_input_nb_elements = 1
new_op_output = 1 new_op_output = 1
elif len(non_scalars) == 1: elif len(non_scalars) == 1:
new_op_input_nb_elements = T.prod(non_scalars[0].shape) new_op_input_nb_elements = non_scalars[0].size
new_op_output = node.op(non_scalars[0]) new_op_output = node.op(non_scalars[0])
else: else:
new_op_input = T.mul(*non_scalars) new_op_input = T.mul(*non_scalars)
new_op_input_nb_elements = T.prod(new_op_input.shape) new_op_input_nb_elements = new_op_input.size
new_op_output = node.op(new_op_input) new_op_output = node.op(new_op_input)
# If node.op is a T.elemwise.Prod, then the scalars need to be # If node.op is a T.elemwise.Prod, then the scalars need to be
......
...@@ -4568,7 +4568,7 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -4568,7 +4568,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 2 # Case 2
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod, test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod,
(s1_val * v_val).prod(), 2) (s1_val * v_val).prod(), 1)
# Case 3 # Case 3
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val], test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
...@@ -4581,7 +4581,7 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -4581,7 +4581,7 @@ class T_local_sum_prod(unittest.TestCase):
# Case 5 # Case 5
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val], test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.elemwise.Prod, (s1_val * s2_val * v_val).prod(), T.elemwise.Prod, (s1_val * s2_val * v_val).prod(),
2) 1)
# Case 6 # Case 6
test_reduction_opt([vect, mat, scalar1, scalar2], test_reduction_opt([vect, mat, scalar1, scalar2],
......
...@@ -280,7 +280,8 @@ class _tensor_py_operators: ...@@ -280,7 +280,8 @@ class _tensor_py_operators:
shape = property(lambda self: theano.tensor.basic.shape(self)) shape = property(lambda self: theano.tensor.basic.shape(self))
size = property(lambda self: theano.tensor.basic.prod(self.shape)) size = property(lambda self: self.shape[0] if self.ndim == 1 else
theano.tensor.basic.prod(self.shape))
# We can't implement __len__ to provide a better error message. # We can't implement __len__ to provide a better error message.
def any(self, axis=None, keepdims=False): def any(self, axis=None, keepdims=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论