提交 5558269e authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Remove `warn__sum_sum_bug` flag

上级 9fad1430
......@@ -1454,19 +1454,6 @@ def add_deprecated_configvars():
in_c_key=False,
)
config.add(
"warn__sum_sum_bug",
(
"Warn if we are in a case where Aesara version between version "
"9923a40c7b7a and the 2 august 2010 (fixed date), generated an "
"error in that case. This happens when there are 2 consecutive "
"sums in the graph, bad code was generated. "
"Was fixed 2 August 2010"
),
BoolParam(_warn_default("0.3")),
in_c_key=False,
)
config.add(
"warn__sum_div_dimshuffle_bug",
(
......
......@@ -1578,38 +1578,6 @@ def local_op_of_op(fgraph, node):
list(node_inps.owner.op.axis) + list(node.op.axis)
)
# The old bugged logic. We keep it there to generate a warning
# when we generated bad code.
alldims = list(range(node_inps.owner.inputs[0].type.ndim))
alldims = [
d for i, d in enumerate(alldims) if i in node_inps.owner.op.axis
]
alldims = [d for i, d in enumerate(alldims) if i in node.op.axis]
newaxis_old = [
i
for i in range(node_inps.owner.inputs[0].type.ndim)
if i not in alldims
]
if (
config.warn__sum_sum_bug
and newaxis != newaxis_old
and len(newaxis) == len(newaxis_old)
):
_logger.warning(
"(YOUR CURRENT CODE IS FINE): Aesara "
"versions between version 9923a40c7b7a and August "
"2nd, 2010 generated bugged code in this case. "
"This happens when there are two consecutive sums "
"in the graph and the intermediate sum is not "
"used elsewhere in the code. Some safeguard "
"removed some bad code, but not in all cases. You "
"are in one such case. To disable this warning "
"(that you can safely ignore since this bug has "
"been fixed) set the aesara flag "
"`warn__sum_sum_bug` to False."
)
combined = opt_type(newaxis, dtype=out_dtype)
return [combined(node_inps.owner.inputs[0])]
......
......@@ -3234,10 +3234,9 @@ class TestLocalSumProd:
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.prod())
with config.change_flags(warn__sum_sum_bug=False):
f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.sum())
f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
utt.assert_allclose(f(input), input.sum())
def test_local_sum_sum_prod_prod(self):
a = tensor3()
......@@ -3289,23 +3288,22 @@ class TestLocalSumProd:
dd = sorted(dd)
return data.sum(d).prod(dd[1]).prod(dd[0])
with config.change_flags(warn__sum_sum_bug=False):
for d, dd in dims:
expected = my_sum(input, d, dd)
f = function([a], a.sum(d).sum(dd), mode=self.mode)
utt.assert_allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]:
f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = function([a], a.sum(d).sum(None), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.apply_nodes) == 1
f = function([a], a.sum(None).sum(), mode=self.mode)
utt.assert_allclose(f(input), input.sum())
for d, dd in dims:
expected = my_sum(input, d, dd)
f = function([a], a.sum(d).sum(dd), mode=self.mode)
utt.assert_allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]:
f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = function([a], a.sum(d).sum(None), mode=self.mode)
utt.assert_allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.apply_nodes) == 1
f = function([a], a.sum(None).sum(), mode=self.mode)
utt.assert_allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1
# test prod
for d, dd in dims:
......@@ -3401,14 +3399,13 @@ class TestLocalSumProd:
assert topo[-1].op == aet.alloc
assert not any([isinstance(node.op, Prod) for node in topo])
with config.change_flags(warn__sum_sum_bug=False):
for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]:
f = function([a], t_like(a).sum(d).sum(dd), mode=mode)
utt.assert_allclose(f(input), n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == aet.alloc
assert not any([isinstance(node.op, Sum) for node in topo])
for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]:
f = function([a], t_like(a).sum(d).sum(dd), mode=mode)
utt.assert_allclose(f(input), n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == aet.alloc
assert not any([isinstance(node.op, Sum) for node in topo])
def test_local_sum_sum_int8(self):
# Test that local_sum_sum works when combining two sums on an int8 array.
......@@ -3653,9 +3650,7 @@ class TestLocalSumProdDimshuffle:
c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX)
d_val = np.asarray(rng.standard_normal(), config.floatX)
with config.change_flags(
warn__sum_sum_bug=False, warn__sum_div_dimshuffle_bug=False
):
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
for i, s in enumerate(sums):
f = function([a, b, c, d], s, mode=self.mode, on_unused_input="ignore")
g = f.maker.fgraph.toposort()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论