Commit 900546ef authored by Rémi Louf, committed by Brandon T. Willard

Fix the rewrite tests

Parent ff1a3a9d
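In outline, the diff below reworks the rewrite tests so that they no longer compile functions through a Mode/OPT_FAST_RUN setup; instead they build a FunctionGraph and run the "fast_run" rewrites on it directly. A minimal sketch of that pattern, with import paths assumed from the Aesara modules referenced in the diff rather than taken from the commit itself:

import aesara.tensor as at
from aesara.compile import optdb
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.tensor.math import log
from aesara.tensor.special import LogSoftmax, softmax

x = at.matrix("x")
fgraph = FunctionGraph([x], [log(softmax(x, axis=-1))])
# Apply every rewrite tagged "fast_run" to the graph in place.
optdb.query(RewriteDatabaseQuery(include=["fast_run"])).rewrite(fgraph)
# log(softmax(x)) should now be computed by a single LogSoftmax node.
assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)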
- from aesara.tensor.rewriting.basic import (
-     register_specialize,
- )
from aesara import scalar as aes
- from aesara.tensor.math import true_div, exp, Sum
- from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+ from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
+ from aesara.tensor.elemwise import DimShuffle, Elemwise
+ from aesara.tensor.math import Sum, exp
+ from aesara.tensor.math import sum as at_sum
+ from aesara.tensor.math import true_div
+ from aesara.tensor.rewriting.basic import register_specialize
+ from aesara.tensor.rewriting.math import local_mul_canonizer
- from aesara.graph.rewriting.basic import node_rewriter, copy_stack_trace
+ from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.subtensor import AdvancedIncSubtensor
- from aesara.tensor.elemwise import Elemwise, DimShuffle
from aesara.tensor.type import values_eq_approx_remove_inf, values_eq_approx_remove_nan
# This is not registered in stabilize, as it causes some crossentropy
@@ -63,7 +64,7 @@ def local_logsoftmax_grad(fgraph, node):
node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
)
# the rewrite only applies to legacy SoftmaxGrad
- and node.op == softmax_grad_legacy
+ and node.op == SoftmaxGrad(axis=-1)
and node.inputs[0].owner.inputs[1].ndim == 2
)
):
......
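For context, an illustrative sketch (not part of this commit) of how a node rewriter such as local_logsoftmax_grad is wired up: node_rewriter limits the Ops the rewriter is tried on, and register_specialize places it in the "specialize"/"fast_run" database rather than "stabilize", as the comment above notes.

from aesara.graph.rewriting.basic import node_rewriter
from aesara.tensor.rewriting.basic import register_specialize
from aesara.tensor.special import SoftmaxGrad


@register_specialize
@node_rewriter([SoftmaxGrad])
def local_logsoftmax_grad_sketch(fgraph, node):
    # Inspect node and its inputs; when the log(softmax) gradient pattern
    # matches, return a list of replacement outputs (copying the stack trace
    # with copy_stack_trace); otherwise return None to leave the node alone.
    return None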
@@ -2,11 +2,9 @@ import numpy as np
import pytest
import aesara
import aesara.tensor as at
from aesara import shared
from aesara.compile import optdb
from aesara.compile.function import function
- from aesara.compile.mode import OPT_FAST_RUN, Mode, get_mode
+ from aesara.compile.mode import get_mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import check_stack_trace
@@ -17,6 +15,10 @@ from aesara.tensor.type import matrix
from tests import unittest_tools as utt
+ _fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"])
+ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
class TestLogSoftmaxRewrites:
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_rewrite(self, axis):
@@ -29,9 +31,10 @@ class TestLogSoftmaxRewrites:
x = matrix("x")
sm = softmax(x, axis=axis)
logsm = log(sm)
- f = function([x], logsm)
- assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
- assert check_stack_trace(f, ops_to_check=LogSoftmax)
+ fgraph = FunctionGraph([x], [logsm])
+ _fast_run_rewrites.rewrite(fgraph)
+ assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
+ assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
@pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -58,7 +61,8 @@ class TestLogSoftmaxRewrites:
# We set step to 0.1 because for big values we need a big epsilon
utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
sa = shared(a)
- f = function([], myfunc(sa))
+ f = FunctionGraph([sa], [myfunc(sa)])
+ _fast_run_rewrites(f)
assert check_stack_trace(f, ops_to_check="all")
def test_logsoftmax_grad_true_div_elemwise(self):
@@ -76,38 +80,17 @@ class TestLogSoftmaxRewrites:
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = SoftmaxGrad(axis=-1)(
add(*true_div_node.inputs), softmax_grad_node.inputs[1]
)
fgraph = FunctionGraph([x], [new_g])
- optdb.query(OPT_FAST_RUN).rewrite(fgraph)
+ _fast_run_rewrites.rewrite(fgraph)
assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
- def test_log1mexp_stabilization():
- mode = Mode("py").including("stabilize")
- x = vector()
- f = function([x], log(1 - exp(x)), mode=mode)
- nodes = [node.op for node in f.maker.fgraph.toposort()]
- assert nodes == [at.log1mexp]
- # Check values that would under or overflow without rewriting
- assert f([-(2.0**-55)]) != -np.inf
- overflow_value = -500.0 if config.floatX == "float64" else -100.0
- assert f([overflow_value]) < 0
- # Check values around the switch point np.log(0.5)
- assert np.allclose(
- f(np.array([-0.8, -0.6], dtype=config.floatX)),
- np.log(1 - np.exp([-0.8, -0.6])),
- )
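The test removed above covers the log1mexp stabilization: log(1 - exp(x)) is rewritten to at.log1mexp(x), which avoids the underflow and overflow the comments mention. A sketch of the standard numerically stable evaluation (assumed here, not taken from Aesara's implementation), with the switch at log(0.5) that the test probes:

import numpy as np

def log1mexp_sketch(x):
    # log(1 - exp(x)) for x <= 0: above log(0.5), exp(x) is close to 1 and
    # 1 - exp(x) cancels, so use expm1; below it, exp(x) is tiny and
    # 1 - exp(x) rounds to 1, so use log1p to keep the small negative result.
    x = np.asarray(x, dtype=float)
    return np.where(x > np.log(0.5), np.log(-np.expm1(x)), np.log1p(-np.exp(x)))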
def test_log_softmax_stabilization():
mode = aesara.compile.mode.get_default_mode()
mode = mode.including("local_log_softmax", "specialize")
@@ -116,15 +99,17 @@ def test_log_softmax_stabilization():
y = softmax(x)
z = log(y)
- f = aesara.function([x], z, mode=mode)
- assert check_stack_trace(f, ops_to_check="all")
+ fgraph = FunctionGraph([x], [z])
+ _fast_run_rewrites(fgraph)
+ assert check_stack_trace(fgraph, ops_to_check="all")
# Check that the softmax has been rewritten
- for node in f.maker.fgraph.toposort():
- assert not isinstance(node.op, y.owner.op.__class__)
+ for node in fgraph.toposort():
+ assert not isinstance(node.op, Softmax)
# Call the function so debug mode can verify the rewritten version matches
# the un-rewritten version
+ f = aesara.function([x], z, mode=mode)
rng = np.random.default_rng(utt.fetch_seed())
f(np.cast[config.floatX](rng.random((2, 3))))
......