Commit 900546ef authored by Rémi Louf, committed by Brandon T. Willard

Fix the rewrite tests

Parent ff1a3a9d
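The gist of the fix: the rewrite tests below stop compiling full Aesara functions and inspecting f.maker.fgraph, and instead build a FunctionGraph directly and apply the "fast_run" rewrite query to it. A minimal sketch of that pattern follows; import paths not visible in the hunks (notably RewriteDatabaseQuery, log, softmax) are assumptions.

from aesara.compile import optdb
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.db import RewriteDatabaseQuery  # assumed import path
from aesara.tensor.math import log
from aesara.tensor.special import LogSoftmax, softmax
from aesara.tensor.type import matrix

# Build the graph to rewrite without compiling an Aesara function.
x = matrix("x")
fgraph = FunctionGraph([x], [log(softmax(x, axis=-1))])

# Pull the "fast_run" rewrites out of the global rewrite database and apply them.
fast_run_rewrites = optdb.query(RewriteDatabaseQuery(include=["fast_run"]))
fast_run_rewrites.rewrite(fgraph)

# log(softmax(x)) should have been collapsed into a single LogSoftmax node.
assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)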
-from aesara.tensor.rewriting.basic import (
-    register_specialize,
-)
 from aesara import scalar as aes
-from aesara.tensor.math import true_div, exp, Sum
-from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+from aesara.graph.rewriting.basic import copy_stack_trace, node_rewriter
+from aesara.tensor.elemwise import DimShuffle, Elemwise
+from aesara.tensor.math import Sum, exp
+from aesara.tensor.math import sum as at_sum
+from aesara.tensor.math import true_div
+from aesara.tensor.rewriting.basic import register_specialize
 from aesara.tensor.rewriting.math import local_mul_canonizer
-from aesara.graph.rewriting.basic import node_rewriter, copy_stack_trace
+from aesara.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from aesara.tensor.subtensor import AdvancedIncSubtensor
-from aesara.tensor.elemwise import Elemwise, DimShuffle
+from aesara.tensor.type import values_eq_approx_remove_inf, values_eq_approx_remove_nan

 # This is not registered in stabilize, as it cause some crossentropy
@@ -63,7 +64,7 @@ def local_logsoftmax_grad(fgraph, node):
                 node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
             )
             # the rewrite only applies to legacy SoftmaxGrad
-            and node.op == softmax_grad_legacy
+            and node.op == SoftmaxGrad(axis=-1)
             and node.inputs[0].owner.inputs[1].ndim == 2
         )
     ):
...
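The one functional change in the hunk above replaces the softmax_grad_legacy sentinel with a direct comparison against SoftmaxGrad(axis=-1). That comparison works because Aesara Ops that declare __props__ compare by value rather than by identity, so a freshly constructed instance with the same parameters equals the op already in the graph. A minimal sketch, assuming SoftmaxGrad lists axis in its __props__:

from aesara.tensor.special import SoftmaxGrad

# Equality is decided by the declared props (here, the axis), not by object
# identity, which is exactly what `node.op == SoftmaxGrad(axis=-1)` relies on.
assert SoftmaxGrad(axis=-1) == SoftmaxGrad(axis=-1)
assert SoftmaxGrad(axis=-1) is not SoftmaxGrad(axis=-1)
assert SoftmaxGrad(axis=0) != SoftmaxGrad(axis=-1)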
@@ -2,11 +2,9 @@ import numpy as np
 import pytest

 import aesara
-import aesara.tensor as at
 from aesara import shared
 from aesara.compile import optdb
-from aesara.compile.function import function
-from aesara.compile.mode import OPT_FAST_RUN, Mode, get_mode
+from aesara.compile.mode import get_mode
 from aesara.configdefaults import config
 from aesara.graph.fg import FunctionGraph
 from aesara.graph.rewriting.basic import check_stack_trace
@@ -17,6 +15,10 @@ from aesara.tensor.type import matrix
 from tests import unittest_tools as utt

+_fast_run_rewrites = RewriteDatabaseQuery(include=["fast_run"])
+_fast_run_rewrites = optdb.query(_fast_run_rewrites)
+
+
 class TestLogSoftmaxRewrites:
     @pytest.mark.parametrize("axis", [None, 0, -1])
     def test_local_logsoftmax_rewrite(self, axis):
@@ -29,9 +31,10 @@ class TestLogSoftmaxRewrites:
         x = matrix("x")
         sm = softmax(x, axis=axis)
         logsm = log(sm)
-        f = function([x], logsm)
-        assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
-        assert check_stack_trace(f, ops_to_check=LogSoftmax)
+        fgraph = FunctionGraph([x], [logsm])
+        _fast_run_rewrites.rewrite(fgraph)
+        assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
+        assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)

     @pytest.mark.parametrize("axis", [None, 0, -1])
     def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -58,7 +61,8 @@ class TestLogSoftmaxRewrites:
         # We set step to 0.1 because for big values we need a big epsilon
         utt.verify_grad(myfunc, [a], eps=0.1, mode=m)
         sa = shared(a)
-        f = function([], myfunc(sa))
+        f = FunctionGraph([sa], [myfunc(sa)])
+        _fast_run_rewrites(f)
         assert check_stack_trace(f, ops_to_check="all")

     def test_logsoftmax_grad_true_div_elemwise(self):
@@ -76,38 +80,17 @@ class TestLogSoftmaxRewrites:
         true_div_node = softmax_grad_node.inputs[0].owner
         assert true_div_node.op == true_div

         # We replace the elemwise true_div op by an elemwise add.
         new_g = SoftmaxGrad(axis=-1)(
             add(*true_div_node.inputs), softmax_grad_node.inputs[1]
         )
         fgraph = FunctionGraph([x], [new_g])
-        optdb.query(OPT_FAST_RUN).rewrite(fgraph)
+        _fast_run_rewrites.rewrite(fgraph)

         assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]


-def test_log1mexp_stabilization():
-    mode = Mode("py").including("stabilize")
-
-    x = vector()
-    f = function([x], log(1 - exp(x)), mode=mode)
-
-    nodes = [node.op for node in f.maker.fgraph.toposort()]
-    assert nodes == [at.log1mexp]
-
-    # Check values that would under or overflow without rewriting
-    assert f([-(2.0**-55)]) != -np.inf
-    overflow_value = -500.0 if config.floatX == "float64" else -100.0
-    assert f([overflow_value]) < 0
-
-    # Check values around the switch point np.log(0.5)
-    assert np.allclose(
-        f(np.array([-0.8, -0.6], dtype=config.floatX)),
-        np.log(1 - np.exp([-0.8, -0.6])),
-    )
-
-
 def test_log_softmax_stabilization():
     mode = aesara.compile.mode.get_default_mode()
     mode = mode.including("local_log_softmax", "specialize")
@@ -116,15 +99,17 @@ def test_log_softmax_stabilization():
     y = softmax(x)
     z = log(y)

-    f = aesara.function([x], z, mode=mode)
-    assert check_stack_trace(f, ops_to_check="all")
+    fgraph = FunctionGraph([x], [z])
+    _fast_run_rewrites(fgraph)
+    assert check_stack_trace(fgraph, ops_to_check="all")

     # Check that the softmax has been rewritten
-    for node in f.maker.fgraph.toposort():
-        assert not isinstance(node.op, y.owner.op.__class__)
+    for node in fgraph.toposort():
+        assert not isinstance(node.op, Softmax)

     # Call the function so debug mode can verify the rewritten version matches
     # the un-rewritten version
+    f = aesara.function([x], z, mode=mode)
     rng = np.random.default_rng(utt.fetch_seed())
     f(np.cast[config.floatX](rng.random((2, 3))))
...
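For reference, a self-contained version of the closing check in test_log_softmax_stabilization: the graph is still compiled with aesara.function and evaluated once so that, when run under DebugMode, the rewritten outputs are verified against the un-rewritten ones. This is a sketch only; the import paths for log and softmax are assumptions (they sit outside the hunks shown), the axis argument is made explicit, and the deprecated np.cast used in the diff is replaced by astype.

import numpy as np

import aesara
from aesara.configdefaults import config
from aesara.tensor.math import log
from aesara.tensor.special import softmax
from aesara.tensor.type import matrix

x = matrix("x")
z = log(softmax(x, axis=-1))

# Default mode plus the rewrites exercised by the test.
mode = aesara.compile.mode.get_default_mode().including("local_log_softmax", "specialize")
f = aesara.function([x], z, mode=mode)

# A single evaluation; under DebugMode this cross-checks the rewritten graph
# against the un-rewritten one.
f(np.random.default_rng(23).random((2, 3)).astype(config.floatX))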