提交 4bfa10de authored 作者: --global's avatar --global

Refactor code

上级 d1244668
......@@ -4519,69 +4519,73 @@ class T_local_sum_prod(unittest.TestCase):
s1_val = numpy.random.rand()
s2_val = numpy.random.rand()
# Test sum
def test_sum_opt(inputs, inputs_val, expected_output):
def test_reduction_opt(inputs, inputs_val, reduction_op,
expected_output, nb_expected_sum_nodes):
mul_out = T.mul(*inputs)
f = theano.function(inputs, T.sum(mul_out), mode=self.mode)
f = theano.function(inputs, reduction_op()(mul_out),
mode=self.mode)
out = f(*inputs_val)
utt.assert_allclose(out, expected_output)
# Ensure that the optimization has been applied properly by
# ensuring that the optimized graph contains the expected number
# of apply nodes for the sum op
prod_nodes = [n for n in f.maker.fgraph.toposort()
if isinstance(n.op, reduction_op)]
assert len(prod_nodes) == nb_expected_sum_nodes
# Test sum
# Case 1
test_sum_opt([scalar1], [s1_val], s1_val)
test_reduction_opt([scalar1], [s1_val], T.Sum, s1_val, 0)
# Case 2
test_sum_opt([vect, scalar1], [v_val, s1_val],
s1_val * v_val.sum())
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.Sum,
s1_val * v_val.sum(), 1)
# Case 3
test_sum_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
s1_val * (v_val * m_val).sum())
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val], T.Sum,
s1_val * (v_val * m_val).sum(), 1)
# Case 4
test_sum_opt([scalar1, scalar2], [s1_val, s2_val],
s1_val * s2_val)
test_reduction_opt([scalar1, scalar2], [s1_val, s2_val], T.Sum,
s1_val * s2_val, 0)
# Case 5
test_sum_opt([vect, scalar1, scalar2],
[v_val, s1_val, s2_val],
s1_val * s2_val * v_val.sum())
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.Sum, s1_val * s2_val * v_val.sum(), 1)
# Case 6
test_sum_opt([vect, mat, scalar1, scalar2],
[v_val, m_val, s1_val, s2_val],
s1_val * s2_val * (v_val * m_val).sum())
test_reduction_opt([vect, mat, scalar1, scalar2],
[v_val, m_val, s1_val, s2_val], T.Sum,
s1_val * s2_val * (v_val * m_val).sum(), 1)
# Test prod
def test_prod_opt(inputs, inputs_val, expected_output):
mul_out = T.mul(*inputs)
f = theano.function(inputs, T.prod(mul_out), mode=self.mode)
out = f(*inputs_val)
utt.assert_allclose(out, expected_output)
# Case 1
test_prod_opt([scalar1], [s1_val], s1_val)
test_reduction_opt([scalar1], [s1_val], T.elemwise.Prod, s1_val, 0)
# Case 2
test_prod_opt([vect, scalar1], [v_val, s1_val],
(s1_val * v_val).prod())
test_reduction_opt([vect, scalar1], [v_val, s1_val], T.elemwise.Prod,
(s1_val * v_val).prod(), 2)
# Case 3
test_prod_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
(s1_val * v_val * m_val).prod())
test_reduction_opt([vect, mat, scalar1], [v_val, m_val, s1_val],
T.elemwise.Prod, (s1_val * v_val * m_val).prod(), 2)
# Case 4
test_prod_opt([scalar1, scalar2], [s1_val, s2_val],
s1_val * s2_val)
test_reduction_opt([scalar1, scalar2], [s1_val, s2_val],
T.elemwise.Prod, s1_val * s2_val, 0)
# Case 5
test_prod_opt([vect, scalar1, scalar2],
[v_val, s1_val, s2_val],
(s1_val * s2_val * v_val).prod())
test_reduction_opt([vect, scalar1, scalar2], [v_val, s1_val, s2_val],
T.elemwise.Prod, (s1_val * s2_val * v_val).prod(),
2)
# Case 6
test_prod_opt([vect, mat, scalar1, scalar2],
[v_val, m_val, s1_val, s2_val],
(s1_val * s2_val * v_val * m_val).prod())
test_reduction_opt([vect, mat, scalar1, scalar2],
[v_val, m_val, s1_val, s2_val], T.elemwise.Prod,
(s1_val * s2_val * v_val * m_val).prod(), 2)
def test_local_sum_prod_all_to_none(self):
a = T.tensor3()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论