提交 47586163 作者: Ricardo 提交者: Thomas Wiecki

Remove `warn__sum_div_dimshuffle_bug` flag

上级 5558269e
......@@ -1454,18 +1454,6 @@ def add_deprecated_configvars():
in_c_key=False,
)
config.add(
"warn__sum_div_dimshuffle_bug",
(
"Warn if previous versions of Aesara (between rev. "
"3bd9b789f5e8, 2010-06-16, and cfc6322e5ad4, 2010-08-03) "
"would have given incorrect result. This bug was triggered by "
"sum of division of dimshuffled tensors."
),
BoolParam(_warn_default("0.3")),
in_c_key=False,
)
config.add(
"warn__subtensor_merge_bug",
"Warn if previous versions of Aesara (before 0.5rc2) could have given "
......
......@@ -1415,33 +1415,6 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
if node_input.owner and node_input.owner.op == true_div:
numerator, denominator = node_input.owner.inputs
# Old, bugged logic, reproduced here only to warn users
if (
config.warn__sum_div_dimshuffle_bug
and isinstance(node.op, Sum)
and numerator.owner
and isinstance(numerator.owner.op, DimShuffle)
):
# Check compatibility
new_order = numerator.owner.op.new_order
compatible_dims = True
for ax in axis:
if len(new_order) <= ax or new_order[ax] != "x":
compatible_dims = False
break
if compatible_dims:
_logger.warning(
"Your current code is fine, but"
" Aesara versions between "
"rev. 3bd9b789f5e8 (2010-06-16) and"
" cfc6322e5ad4 (2010-08-03) would "
"have given an incorrect result. "
"To disable this warning, set the Aesara"
" flag warn__sum_div_dimshuffle_bug to"
" False."
)
if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
dimshuffle_input = denominator.owner.inputs[0]
dimshuffle_order = denominator.owner.op.new_order
......@@ -1483,21 +1456,6 @@ def local_sum_prod_div_dimshuffle(fgraph, node):
optimized_dimshuffle_order,
)(dimshuffle_input)
if config.warn__sum_div_dimshuffle_bug and isinstance(
node.op, Sum
):
_logger.warning(
"Your current code is fine,"
" but Aesara versions between "
"rev. 3bd9b789f5e8 (2010-06-16) and"
" cfc6322e5ad4 (2010-08-03) would "
"have given an incorrect result. "
"To disable this warning, set the"
" Aesara flag "
"warn__sum_div_dimshuffle_bug"
" to False."
)
if isinstance(node.op, Sum):
op_on_compatible_dims = aet_sum(numerator, axis=compatible_dims)
rval = true_div(op_on_compatible_dims, optimized_dimshuffle)
......
......@@ -900,9 +900,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
......@@ -937,9 +936,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
......@@ -975,9 +973,8 @@ class TestCrossEntropyCategorical1Hot(utt.InferShapeTester):
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops if isinstance(o, AdvancedSubtensor)]
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
fgraph = FunctionGraph([x, b, y], [grad(expr, x)])
optdb.query(OPT_FAST_RUN).optimize(fgraph)
ops = [node.op for node in fgraph.toposort()]
assert len(ops) <= 6
......@@ -1223,8 +1220,7 @@ class TestSoftmaxOpt:
# test that function contains softmax and softmaxgrad
w = matrix()
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
g = aesara.function([c, w], grad((p_y * w).sum(), c))
g = aesara.function([c, w], grad((p_y * w).sum(), c))
g_ops = [n.op for n in g.maker.fgraph.toposort()]
......@@ -1244,8 +1240,7 @@ class TestSoftmaxOpt:
aesara.function([c], p_y)
# test that function contains softmax and no div.
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
aesara.function([c], grad(p_y.sum(), c))
aesara.function([c], grad(p_y.sum(), c))
@pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_1d_basic(self):
......@@ -1257,8 +1252,7 @@ class TestSoftmaxOpt:
aesara.function([c], p_y)
# test that function contains softmax and no div.
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
aesara.function([c], grad(p_y.sum(), c))
aesara.function([c], grad(p_y.sum(), c))
def test_softmax_graph():
......
......@@ -3650,12 +3650,11 @@ class TestLocalSumProdDimshuffle:
c_val = rng.standard_normal((2, 2, 2)).astype(config.floatX)
d_val = np.asarray(rng.standard_normal(), config.floatX)
with config.change_flags(warn__sum_div_dimshuffle_bug=False):
for i, s in enumerate(sums):
f = function([a, b, c, d], s, mode=self.mode, on_unused_input="ignore")
g = f.maker.fgraph.toposort()
assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv)
f(a_val, b_val, c_val, d_val)
for i, s in enumerate(sums):
f = function([a, b, c, d], s, mode=self.mode, on_unused_input="ignore")
g = f.maker.fgraph.toposort()
assert isinstance(g[-1].op.scalar_op, aes.basic.TrueDiv)
f(a_val, b_val, c_val, d_val)
def test_local_prod_div_dimshuffle(self):
a = matrix("a")
......
......@@ -609,30 +609,24 @@ def makeTester(
@pytest.mark.skipif(skip, reason="Skipped")
def test_grad(self):
# Disable old warning that may be triggered by this test.
backup = config.warn__sum_div_dimshuffle_bug
config.warn__sum_div_dimshuffle_bug = False
try:
for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs]
try:
utt.verify_grad(
self.op,
inputs,
mode=self.mode,
rel_tol=_grad_rtol,
eps=_grad_eps,
)
except Exception as exc:
err_msg = (
"Test %s::%s: Error occurred while"
" computing the gradient on the following"
" inputs: %s"
) % (self.op, testname, inputs)
exc.args += (err_msg,)
raise
finally:
config.warn__sum_div_dimshuffle_bug = backup
for testname, inputs in self.grad.items():
inputs = [copy(input) for input in inputs]
try:
utt.verify_grad(
self.op,
inputs,
mode=self.mode,
rel_tol=_grad_rtol,
eps=_grad_eps,
)
except Exception as exc:
err_msg = (
"Test %s::%s: Error occurred while"
" computing the gradient on the following"
" inputs: %s"
) % (self.op, testname, inputs)
exc.args += (err_msg,)
raise
@pytest.mark.skipif(skip, reason="Skipped")
def test_grad_none(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论