提交 47586163 authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Remove `warn__sum_div_dimshuffle_bug` flag

上级 5558269e
...@@ -1454,18 +1454,6 @@ def add_deprecated_configvars(): ...@@ -1454,18 +1454,6 @@ def add_deprecated_configvars():
in_c_key=False, in_c_key=False,
) )
config.add(
"warn__sum_div_dimshuffle_bug",
(
"Warn if previous versions of Aesara (between rev. "
"3bd9b789f5e8, 2010-06-16, and cfc6322e5ad4, 2010-08-03) "
"would have given incorrect result. This bug was triggered by "
"sum of division of dimshuffled tensors."
),
BoolParam(_warn_default("0.3")),
in_c_key=False,
)
config.add( config.add(
"warn__subtensor_merge_bug", "warn__subtensor_merge_bug",
"Warn if previous versions of Aesara (before 0.5rc2) could have given " "Warn if previous versions of Aesara (before 0.5rc2) could have given "
......
...@@ -1415,33 +1415,6 @@ def local_sum_prod_div_dimshuffle(fgraph, node): ...@@ -1415,33 +1415,6 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
if node_input.owner and node_input.owner.op == true_div: if node_input.owner and node_input.owner.op == true_div:
numerator, denominator = node_input.owner.inputs numerator, denominator = node_input.owner.inputs
# Old, bugged logic, reproduced here only to warn users
if (
config.warn__sum_div_dimshuffle_bug
and isinstance(node.op, Sum)
and numerator.owner
and isinstance(numerator.owner.op, DimShuffle)
):
# Check compatibility
new_order = numerator.owner.op.new_order
compatible_dims = True
for ax in axis:
if len(new_order) <= ax or new_order[ax] != "x":
compatible_dims = False
break
if compatible_dims:
_logger.warning(
"Your current code is fine, but"
" Aesara versions between "
"rev. 3bd9b789f5e8 (2010-06-16) and"
" cfc6322e5ad4 (2010-08-03) would "
"have given an incorrect result. "
"To disable this warning, set the Aesara"
" flag warn__sum_div_dimshuffle_bug to"
" False."
)
if denominator.owner and isinstance(denominator.owner.op, DimShuffle): if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
dimshuffle_input = denominator.owner.inputs[0] dimshuffle_input = denominator.owner.inputs[0]
dimshuffle_order = denominator.owner.op.new_order dimshuffle_order = denominator.owner.op.new_order
...@@ -1483,21 +1456,6 @@ def local_sum_prod_div_dimshuffle(fgraph, node): ...@@ -1483,21 +1456,6 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
optimized_dimshuffle_order, optimized_dimshuffle_order,
)(dimshuffle_input) )(dimshuffle_input)
if config.warn__sum_div_dimshuffle_bug and isinstance(
node.op, Sum
):
_logger.warning(
"Your current code is fine,"
" but Aesara versions between "
"rev. 3bd9b789f5e8 (2010-06-16) and"
" cfc6322e5ad4 (2010-08-03) would "
"have given an incorrect result. "
"To disable this warning, set the"
" Aesara flag "
"warn__sum_div_dimshuffle_bug"
" to False."
)
if isinstance(node.op, Sum): if isinstance(node.op, Sum):
op_on_compatible_dims = aet_sum(numerator, axis=compatible_dims) op_on_compatible_dims = aet_sum(numerator, axis=compatible_dims)
rval = true_div(op_on_compatible_dims, optimized_dimshuffle) rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
......
...@@ -900,9 +900,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -900,9 +900,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
with config.change_flags(warn__sum_div_dimshuffle_bug=False): fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6 assert len(ops) <= 6
...@@ -937,9 +936,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -937,9 +936,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
with config.change_flags(warn__sum_div_dimshuffle_bug=False): fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6 assert len(ops) <= 6
...@@ -975,9 +973,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester): ...@@ -975,9 +973,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)] assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
with config.change_flags(warn__sum_div_dimshuffle_bug=False): fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
fgraph = FunctionGraph([x, b, y], [grad(expr, x)]) optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()] ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6 assert len(ops) <= 6
...@@ -1223,8 +1220,7 @@ class TestSoftmaxOpt: ...@@ -1223,8 +1220,7 @@ class TestSoftmaxOpt:
# test that function contains softmax and softmaxgrad # test that function contains softmax and softmaxgrad
w = matrix() w = matrix()
with config.change_flags(warn__sum_div_dimshuffle_bug=False): g = aesara.function([c, w], grad((p_y * w).sum(), c))
g = aesara.function([c, w], grad((p_y * w).sum(), c))
g_ops = [n.op for n in g.maker.fgraph.toposort()] g_ops = [n.op for n in g.maker.fgraph.toposort()]
...@@ -1244,8 +1240,7 @@ class TestSoftmaxOpt: ...@@ -1244,8 +1240,7 @@ class TestSoftmaxOpt:
aesara.function([c], p_y) aesara.function([c], p_y)
# test that function contains softmax and no div. # test that function contains softmax and no div.
with config.change_flags(warn__sum_div_dimshuffle_bug=False): aesara.function([c], grad(p_y.sum(), c))
aesara.function([c], grad(p_y.sum(), c))
@pytest.mark.skip(reason="Optimization not enabled for the moment") @pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_1d_basic(self): def test_1d_basic(self):
...@@ -1257,8 +1252,7 @@ class TestSoftmaxOpt: ...@@ -1257,8 +1252,7 @@ class TestSoftmaxOpt:
aesara.function([c], p_y) aesara.function([c], p_y)
# test that function contains softmax and no div. # test that function contains softmax and no div.
with config.change_flags(warn__sum_div_dimshuffle_bug=False): aesara.function([c], grad(p_y.sum(), c))
aesara.function([c], grad(p_y.sum(), c))
def test_softmax_graph(): def test_softmax_graph():
......
...@@ -3650,12 +3650,11 @@ class TestLocalSumProdDimshuffle: ...@@ -3650,12 +3650,11 @@ class TestLocalSumProdDimshuffle:
c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX) c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX)
d_val = np.asarray(rng.standard_normal(), config.floatX) d_val = np.asarray(rng.standard_normal(), config.floatX)
with config.change_flags(warn__sum_div_dimshuffle_bug=False): for i, s in enumerate(sums):
for i, s in enumerate(sums): f = function([a, b, c, d], s, mode=self.mode, on_unused_input="ignore")
f = function([a, b, c, d], s, mode=self.mode, on_unused_input="ignore") g = f.maker.fgraph.toposort()
g = f.maker.fgraph.toposort() assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv)
assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv) f(a_val, b_val, c_val, d_val)
f(a_val, b_val, c_val, d_val)
def test_local_prod_div_dimshuffle(self): def test_local_prod_div_dimshuffle(self):
a = matrix("a") a = matrix("a")
......
...@@ -609,30 +609,24 @@ def makeTester( ...@@ -609,30 +609,24 @@ def makeTester(
@pytest.mark.skipif(skip, reason="Skipped") @pytest.mark.skipif(skip, reason="Skipped")
def test_grad(self): def test_grad(self):
# Disable old warning that may be triggered by this test. for testname, inputs in self.grad.items():
backup = config.warn__sum_div_dimshuffle_bug inputs = [copy(input) for input in inputs]
config.warn__sum_div_dimshuffle_bug = False try:
try: utt.verify_grad(
for testname, inputs in self.grad.items(): self.op,
inputs = [copy(input) for input in inputs] inputs,
try: mode=self.mode,
utt.verify_grad( rel_tol=_grad_rtol,
self.op, eps=_grad_eps,
inputs, )
mode=self.mode, except Exception as exc:
rel_tol=_grad_rtol, err_msg = (
eps=_grad_eps, "Test %s::%s: Error occurred while"
) " computing the gradient on the following"
except Exception as exc: " inputs: %s"
err_msg = ( ) % (self.op, testname, inputs)
"Test %s::%s: Error occurred while" exc.args += (err_msg,)
" computing the gradient on the following" raise
" inputs: %s"
) % (self.op, testname, inputs)
exc.args += (err_msg,)
raise
finally:
config.warn__sum_div_dimshuffle_bug = backup
@pytest.mark.skipif(skip, reason="Skipped") @pytest.mark.skipif(skip, reason="Skipped")
def test_grad_none(self): def test_grad_none(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论