Commit ce0b503c authored and committed by Ricardo Vieira

Extend log_softmax rewrite and run it in `stabilize`

Parent 39bda72a
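For readers skimming the diff, here is a minimal sketch of the pattern this commit stabilizes. It is hedged: it assumes a PyTensor build that includes this rewrite, and a default compilation mode (which runs the `stabilize` pass).

```python
# Sketch of the pattern this commit stabilizes (assumes a PyTensor build
# that includes this rewrite; default modes run the "stabilize" pass).
import numpy as np
import pytensor
from pytensor.tensor import log, matrix
from pytensor.tensor.special import softmax

x = matrix("x")
out = log(softmax(x, axis=-1))  # rewritten in-graph to log_softmax(x)
fn = pytensor.function([x], out)

# Naively, exp(999.) overflows and log(0.) underflows to -inf; the
# rewritten graph returns finite values instead.
logits = np.array([[-10.0, -10.0, 999.0]], dtype=x.dtype)
print(fn(logits))  # approx. [[-1009., -1009., 0.]]
```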
-from pytensor import scalar as aes
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
-from pytensor.tensor.elemwise import DimShuffle, Elemwise
+from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import Sum, exp
+from pytensor.tensor.math import Sum, exp, log
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.math import true_div
-from pytensor.tensor.rewriting.basic import register_specialize
+from pytensor.tensor.rewriting.basic import register_stabilize
 from pytensor.tensor.rewriting.math import local_mul_canonizer
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+from pytensor.tensor.special import Softmax, SoftmaxGrad, log_softmax
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    Subtensor,
+)
 from pytensor.tensor.type import (
     values_eq_approx_remove_inf,
     values_eq_approx_remove_nan,
 )
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
-@node_rewriter([Elemwise])
+subtensor_ops = (
+    Subtensor,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+)
+
+
+@register_stabilize
+@node_rewriter([log])
 def local_logsoftmax(fgraph, node):
     """
     Detect Log(Softmax(x)) and replace it with LogSoftmax(x)
+
+    This also lifts Subtensor or Dimshuffle operations that could be in between log and softmax
 
     Note: only forward pass is affected
     """
-    if (
-        isinstance(node.op, Elemwise)
-        and isinstance(node.op.scalar_op, aes.Log)
-        and len(node.inputs) == 1
-        and node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, Softmax)
-    ):
-        inVars = node.inputs[0].owner.inputs[0]
-        new_op = LogSoftmax(axis=node.inputs[0].owner.op.axis)
-        ret = new_op(inVars)
-        ret.tag.values_eq_approx = values_eq_approx_remove_inf
-        copy_stack_trace([node.inputs[0], node.outputs[0]], ret)
-        return [ret]
+
+    def find_softmax_under_lifteable_ops(inp_node, ops_to_lift):
+        if inp_node is None:
+            return
+
+        if isinstance(inp_node.op, Softmax):
+            return inp_node
+
+        if isinstance(inp_node.op, subtensor_ops):
+            ops_to_lift.append((inp_node.op, inp_node.inputs[1:]))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+        if isinstance(inp_node.op, DimShuffle):
+            ops_to_lift.append((inp_node.op, ()))
+            return find_softmax_under_lifteable_ops(
+                inp_node.inputs[0].owner, ops_to_lift
+            )
+
+    ops_to_lift = []
+    softmax_node = find_softmax_under_lifteable_ops(node.inputs[0].owner, ops_to_lift)
+
+    if softmax_node is None:
+        return
+
+    ret = log_softmax(softmax_node.inputs[0], axis=softmax_node.op.axis)
+    ret.tag.values_eq_approx = values_eq_approx_remove_inf
+
+    # Lift ops that used to be between log and softmax
+    for op_to_lift, parameters in reversed(ops_to_lift):
+        ret = op_to_lift(ret, *parameters)
+
+    copy_stack_trace(node.outputs, ret)
+    return [ret]
-# This is not registered in stabilize, as it cause some crossentropy
-# optimization to not be inserted.
-@register_specialize("stabilize", "fast_compile")
+@register_stabilize
 @node_rewriter([SoftmaxGrad])
 def local_logsoftmax_grad(fgraph, node):
     """
@@ -50,9 +81,7 @@ def local_logsoftmax_grad(fgraph, node):
     Note: only grad is affected
     """
     if (
-        isinstance(node.op, SoftmaxGrad)
-        and len(node.inputs) == 2
-        and node.inputs[0].owner is not None
+        node.inputs[0].owner is not None
         and node.inputs[0].owner.op == true_div
        and len(node.inputs[0].owner.inputs) >= 2
        and node.inputs[0].owner.inputs[1].owner is not None
...
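The numerical identity the rewrite relies on can be illustrated in plain NumPy. This is an illustration only, not the `LogSoftmax` Op's actual implementation:

```python
# Plain-NumPy illustration of why log_softmax is stable (not the Op's
# actual implementation): shifting by max(x) keeps log-sum-exp finite.
import numpy as np

def naive_log_softmax(x):
    e = np.exp(x)               # exp(999.) overflows to inf
    return np.log(e / e.sum())  # yields -inf / nan entries

def stable_log_softmax(x):
    z = x - x.max()                     # largest shifted entry is 0
    return z - np.log(np.exp(z).sum())  # well-behaved log-sum-exp

x = np.array([-10.0, -10.0, 999.0])
print(naive_log_softmax(x))   # [-inf -inf  nan]
print(stable_log_softmax(x))  # [-1009. -1009.     0.]
```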
 import numpy as np
 import pytest
+import scipy.special
 
 import pytensor
 from pytensor import shared
@@ -35,6 +36,37 @@ class TestLogSoftmaxRewrites:
         _fast_run_rewrites.rewrite(fgraph)
         assert isinstance(fgraph.outputs[0].owner.op, LogSoftmax)
         assert check_stack_trace(fgraph, ops_to_check=LogSoftmax)
+        assert check_stack_trace(fgraph, ops_to_check="all")
+    @pytest.mark.parametrize("axis", [None, 0, -1])
+    @pytest.mark.parametrize("idx0", [0, slice(1, None), slice(None)])
+    @pytest.mark.parametrize("idx1", [None, [0, 1, 1, -1]])
+    def test_logsoftmax_subtensor_dimshuffle(self, axis, idx0, idx1):
+        """Test that stabilization is introduced even when subtensor or dimshuffle operations
+        are present between log and softmax.
+        """
+        logit_p = matrix("logit_p")
+        p = softmax(logit_p, axis=axis)
+        p_indexed = p[(idx0, idx1)]
+        out = log(p_indexed)
+
+        # Don't waste time with C compilation
+        with config.change_flags(cxx=""):
+            mode = get_mode(None).including("stabilize")
+            fn = pytensor.function([logit_p], out, mode=mode)
+
+        assert not any(
+            isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes
+        )
+
+        # This range would lead to underflow to -inf without the stabilization
+        test_logit_p = np.array(
+            [[-10.0, -10.0, 999.0], [999.0, 990.0, -10.0]], dtype=config.floatX
+        )
+        np.testing.assert_allclose(
+            fn(logit_p=test_logit_p),
+            scipy.special.log_softmax(test_logit_p, axis=axis)[(idx0, idx1)],
+        )
 
     @pytest.mark.parametrize("axis", [None, 0, -1])
     def test_local_logsoftmax_grad_rewrite(self, axis):
@@ -46,7 +78,7 @@ class TestLogSoftmaxRewrites:
         """
         m = config.mode
-        m = get_mode(m)
+        m = get_mode(m).including("stabilize")
         m.check_isfinite = False
         # some inputs that are large to make the gradient explode in the non
         # rewritten case
@@ -91,29 +123,6 @@ class TestLogSoftmaxRewrites:
         assert SoftmaxGrad(axis=-1) in [n.op for n in fgraph.toposort()]
-
-def test_log_softmax_stabilization():
-    mode = pytensor.compile.mode.get_default_mode()
-    mode = mode.including("local_log_softmax", "specialize")
-
-    x = matrix()
-    y = softmax(x, axis=-1)
-    z = log(y)
-
-    fgraph = FunctionGraph([x], [z])
-    _fast_run_rewrites(fgraph)
-
-    assert check_stack_trace(fgraph, ops_to_check="all")
-
-    # Check that the softmax has been rewritten
-    for node in fgraph.toposort():
-        assert not isinstance(node.op, Softmax)
-
-    # Call the function so debug mode can verify the rewritten version matches
-    # the un-rewritten version
-    f = pytensor.function([x], z, mode=mode)
-    rng = np.random.default_rng(utt.fetch_seed())
-    f(np.cast[config.floatX](rng.random((2, 3))))
 
 def test_softmax_graph():
     """Make sure that softmax expressions are turned into
     a softmax Op.
...
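The lifting behavior exercised by the new test can be distilled into a small standalone check. Again, this assumes a build containing this commit; the particular indexing pattern is only an example:

```python
# Standalone sketch of the new lifting behavior (assumes this commit):
# a Subtensor between log and softmax no longer blocks the rewrite.
import pytensor
from pytensor.tensor import log, matrix
from pytensor.tensor.special import Softmax, softmax

x = matrix("x")
out = log(softmax(x, axis=-1)[1:])  # Subtensor sits between log and softmax

mode = pytensor.compile.mode.get_default_mode().including("stabilize")
fn = pytensor.function([x], out, mode=mode)

# After the rewrite, the indexing is applied to log_softmax's output,
# so no Softmax node remains in the compiled graph.
assert not any(
    isinstance(node.op, Softmax) for node in fn.maker.fgraph.apply_nodes
)
```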