提交 74944e78 authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Remove `warn__argmax_pushdown_bug` flag

上级 19f1566f
......@@ -1453,17 +1453,6 @@ def add_deprecated_configvars():
BoolParam(True),
in_c_key=False,
)
# TODO: most of these bugfix-related warnings can probably be removed
config.add(
"warn__argmax_pushdown_bug",
(
"Warn if in past version of Aesara we generated a bug with the "
"aesara.tensor.nnet.basic.local_argmax_pushdown optimization. "
"Was fixed 27 may 2010"
),
BoolParam(_warn_default("0.3")),
in_c_key=False,
)
config.add(
"warn__gpusum_01_011_0111_bug",
......
......@@ -13,7 +13,6 @@ revisited later when all the intermediate part are on the GPU.
"""
import logging
import warnings
import numpy as np
......@@ -22,7 +21,6 @@ import aesara
from aesara import scalar as aes
from aesara.assert_op import Assert
from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op
......@@ -1674,31 +1672,6 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
@register_specialize("fast_compile_gpu")
@local_optimizer([MaxAndArgmax])
def local_argmax_pushdown(fgraph, node):
if (
isinstance(node.op, MaxAndArgmax)
and node.inputs[0].owner
and len(fgraph.clients[node.outputs[0]]) > 0
and node.inputs[0].owner.op
in (
softmax_op,
softplus,
exp,
log,
tanh,
sigmoid,
softmax_with_bias,
)
):
if config.warn__argmax_pushdown_bug:
logging.getLogger("aesara.tensor.nnet.basic").warn(
"There was a bug in Aesara fixed on May 27th, 2010 in this case."
" I.E. when we take the max of a softplus, softmax, exp, "
"log, tanh, sigmoid, softmax_with_bias op, we were doing "
"the max of the parent of the input. To remove this "
"warning set the Aesara flags 'warn__argmax_pushdown_bug' "
"to False"
)
if (
isinstance(node.op, MaxAndArgmax)
and node.inputs[0].owner
......
......@@ -1083,8 +1083,7 @@ def test_argmax_pushdown():
assert hasattr(fgraph.outputs[0].tag, "trace")
with config.change_flags(warn__argmax_pushdown_bug=False):
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
# print 'AFTER'
# for node in fgraph.toposort():
......@@ -1119,8 +1118,7 @@ def test_argmax_pushdown_bias():
out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
fgraph = FunctionGraph([x, b], [out])
with config.change_flags(warn__argmax_pushdown_bug=False):
optdb.query(OPT_FAST_RUN).optimize(fgraph)
optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert len(fgraph.toposort()) == 2
assert isinstance(fgraph.toposort()[0].op, SoftmaxWithBias)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论