提交 2a21580d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor sum_prod_mul rewrite test and add failing case

Rewrite from prod of mul was not correct when only some axes were reduced by prod
上级 b5cff61d
...@@ -2512,6 +2512,7 @@ class TestLocalSumProd: ...@@ -2512,6 +2512,7 @@ class TestLocalSumProd:
# 4-the inputs to the mul contain two scalars and no non-scalar # 4-the inputs to the mul contain two scalars and no non-scalar
# 5-the inputs to the mul contain two scalars and one non-scalar # 5-the inputs to the mul contain two scalars and one non-scalar
# 6-the inputs to the mul contain two scalars and two non-scalars # 6-the inputs to the mul contain two scalars and two non-scalars
# 7-the reduction happens across only the first of two axes
vect = dvector() vect = dvector()
mat = dmatrix() mat = dmatrix()
...@@ -2524,10 +2525,15 @@ class TestLocalSumProd: ...@@ -2524,10 +2525,15 @@ class TestLocalSumProd:
s2_val = np.random.random() s2_val = np.random.random()
def test_reduction_rewrite( def test_reduction_rewrite(
inputs, inputs_val, reduction_op, expected_output, nb_expected_sum_nodes inputs,
inputs_val,
reduction_op,
expected_output,
nb_expected_sum_nodes,
axis=None,
): ):
mul_out = mul(*inputs) mul_out = mul(*inputs)
f = function(inputs, reduction_op()(mul_out), mode=self.mode) f = function(inputs, reduction_op(axis=axis)(mul_out), mode=self.mode)
out = f(*inputs_val) out = f(*inputs_val)
utt.assert_allclose(out, expected_output) utt.assert_allclose(out, expected_output)
...@@ -2581,6 +2587,16 @@ class TestLocalSumProd: ...@@ -2581,6 +2587,16 @@ class TestLocalSumProd:
1, 1,
) )
# Case 7
test_reduction_rewrite(
[mat, scalar1, scalar2],
[m_val, s1_val, s2_val],
Sum,
(s1_val * s2_val * m_val).sum(0),
1,
axis=(0,),
)
# Test prod # Test prod
# Case 1 # Case 1
...@@ -2627,6 +2643,16 @@ class TestLocalSumProd: ...@@ -2627,6 +2643,16 @@ class TestLocalSumProd:
2, 2,
) )
# Case 7
test_reduction_rewrite(
[mat, scalar1, scalar2],
[m_val, s1_val, s2_val],
Prod,
(s1_val * s2_val * m_val).prod(0),
1,
axis=(0,),
)
def test_local_sum_prod_all_to_none(self): def test_local_sum_prod_all_to_none(self):
a = tensor3() a = tensor3()
input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) input = np.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论