提交 74944e78 authored 作者: Ricardo's avatar Ricardo 提交者: Thomas Wiecki

Remove `warn__argmax_pushdown_bug` flag

上级 19f1566f
...@@ -1453,17 +1453,6 @@ def add_deprecated_configvars(): ...@@ -1453,17 +1453,6 @@ def add_deprecated_configvars():
BoolParam(True), BoolParam(True),
in_c_key=False, in_c_key=False,
) )
# TODO: most of these bugfix-related warnings can probably be removed
config.add(
"warn__argmax_pushdown_bug",
(
"Warn if in past version of Aesara we generated a bug with the "
"aesara.tensor.nnet.basic.local_argmax_pushdown optimization. "
"Was fixed 27 may 2010"
),
BoolParam(_warn_default("0.3")),
in_c_key=False,
)
config.add( config.add(
"warn__gpusum_01_011_0111_bug", "warn__gpusum_01_011_0111_bug",
......
...@@ -13,7 +13,6 @@ revisited later when all the intermediate part are on the GPU. ...@@ -13,7 +13,6 @@ revisited later when all the intermediate part are on the GPU.
""" """
import logging
import warnings import warnings
import numpy as np import numpy as np
...@@ -22,7 +21,6 @@ import aesara ...@@ -22,7 +21,6 @@ import aesara
from aesara import scalar as aes from aesara import scalar as aes
from aesara.assert_op import Assert from aesara.assert_op import Assert
from aesara.compile import optdb from aesara.compile import optdb
from aesara.configdefaults import config
from aesara.gradient import DisconnectedType, grad_not_implemented from aesara.gradient import DisconnectedType, grad_not_implemented
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.op import COp, Op from aesara.graph.op import COp, Op
...@@ -1674,31 +1672,6 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node): ...@@ -1674,31 +1672,6 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(fgraph, node):
@register_specialize("fast_compile_gpu") @register_specialize("fast_compile_gpu")
@local_optimizer([MaxAndArgmax]) @local_optimizer([MaxAndArgmax])
def local_argmax_pushdown(fgraph, node): def local_argmax_pushdown(fgraph, node):
if (
isinstance(node.op, MaxAndArgmax)
and node.inputs[0].owner
and len(fgraph.clients[node.outputs[0]]) > 0
and node.inputs[0].owner.op
in (
softmax_op,
softplus,
exp,
log,
tanh,
sigmoid,
softmax_with_bias,
)
):
if config.warn__argmax_pushdown_bug:
logging.getLogger("aesara.tensor.nnet.basic").warn(
"There was a bug in Aesara fixed on May 27th, 2010 in this case."
" I.E. when we take the max of a softplus, softmax, exp, "
"log, tanh, sigmoid, softmax_with_bias op, we were doing "
"the max of the parent of the input. To remove this "
"warning set the Aesara flags 'warn__argmax_pushdown_bug' "
"to False"
)
if ( if (
isinstance(node.op, MaxAndArgmax) isinstance(node.op, MaxAndArgmax)
and node.inputs[0].owner and node.inputs[0].owner
......
...@@ -1083,7 +1083,6 @@ def test_argmax_pushdown(): ...@@ -1083,7 +1083,6 @@ def test_argmax_pushdown():
assert hasattr(fgraph.outputs[0].tag, "trace") assert hasattr(fgraph.outputs[0].tag, "trace")
with config.change_flags(warn__argmax_pushdown_bug=False):
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
# print 'AFTER' # print 'AFTER'
...@@ -1119,7 +1118,6 @@ def test_argmax_pushdown_bias(): ...@@ -1119,7 +1118,6 @@ def test_argmax_pushdown_bias():
out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0] out = max_and_argmax(softmax_with_bias(x, b), axis=-1)[0]
fgraph = FunctionGraph([x, b], [out]) fgraph = FunctionGraph([x, b], [out])
with config.change_flags(warn__argmax_pushdown_bug=False):
optdb.query(OPT_FAST_RUN).optimize(fgraph) optdb.query(OPT_FAST_RUN).optimize(fgraph)
assert len(fgraph.toposort()) == 2 assert len(fgraph.toposort()) == 2
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论