Commit 595ed184 authored by Ricardo, committed by Brandon T. Willard

Make softmax_simplifier work with arbitrary axis

Parent 58cb5c30
......@@ -1127,34 +1127,61 @@ def local_softmax_with_bias(fgraph, node):
def softmax_simplifier(numerators, denominators):
    """Collapse ``exp(x) / exp(x).sum(axis)`` divisions into a single ``Softmax``.

    Scans the ``numerators``/``denominators`` lists of a division and, when a
    numerator ``exp(x)`` is divided by the sum of itself over exactly one axis
    (optionally re-broadcast through a ``DimShuffle`` that reintroduces the
    reduced dimension), removes the pair and records a ``Softmax(axis=...)``
    application instead.

    Parameters
    ----------
    numerators : list
        Numerator terms of the division; mutated in place.
    denominators : list
        Denominator terms of the division; mutated in place.
    """
    for numerator in list(numerators):
        # Only floating-point exponentials can form a softmax.
        if not numerator.type.dtype.startswith("float"):
            continue

        if not (numerator.owner and numerator.owner.op == exp):
            continue

        matching_denom = None

        for denominator in denominators:
            # Division with dimshuffle
            if denominator.owner and isinstance(denominator.owner.op, DimShuffle):
                ds_order = denominator.owner.op.new_order
                # Check that at most only one dimension is being reintroduced by
                # a dimshuffle. The cases where all dimensions are reintroduced
                # after a complete sum reduction end up in the else branch
                if ds_order.count("x") != 1:
                    continue
                # Check that dimshuffle does not change order of original dims
                ds_order_without_x = tuple(dim for dim in ds_order if dim != "x")
                if tuple(sorted(ds_order_without_x)) != ds_order_without_x:
                    continue
                new_dim = ds_order.index("x")
                # The thing getting dimshuffled must be a single-axis Sum.
                z = denominator.owner.inputs[0]
                if z.owner and isinstance(z.owner.op, Sum):
                    sum_axis = z.owner.op.axis
                    # Check that reintroduced dim was the one reduced
                    if (
                        (sum_axis is not None)
                        and (len(sum_axis) == 1)
                        and (sum_axis[0] == new_dim)
                    ):
                        if z.owner.inputs[0] is numerator:
                            (sum_axis,) = sum_axis
                            matching_denom = denominator
                            break
            # Division without dimshuffle
            else:
                z = denominator
                if z.owner and isinstance(z.owner.op, Sum):
                    sum_axis = z.owner.op.axis
                    # Filter out partial summations over more than one axis
                    # The cases where all axis of summation are explicitly given
                    # as in `sum(matrix, axis=(0, 1))` are eventually rewritten
                    # to `sum(matrix)` and this branch is not a blocker
                    if sum_axis is not None and len(sum_axis) != 1:
                        continue
                    if z.owner.inputs[0] is numerator:
                        if sum_axis is not None:
                            # Unpack the single reduced axis.
                            (sum_axis,) = sum_axis
                        matching_denom = denominator
                        break

        if matching_denom:
            # NOTE(review): when the full-reduction branch matched, sum_axis is
            # None here — assumes Softmax accepts axis=None; confirm upstream.
            softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0])
            copy_stack_trace(numerator, softmax)
            numerators.remove(numerator)
            denominators.remove(matching_denom)
......
......@@ -1036,37 +1036,48 @@ class TestSoftmaxOpt:
self.mode = aesara.compile.mode.get_default_mode()
self.mode = self.mode.including("canonicalize")
@pytest.mark.parametrize("axis", [None, 0, 1, -1, (0, 1)])
def test_basic(self, axis):
    """The div graph must be rewritten into a single Softmax for any axis."""
    c = matrix()
    # Build exp(c) / exp(c).sum(axis) with the dimshuffle that reintroduces
    # the reduced dimension(s) for broadcasting.
    if axis is None:
        p_y = exp(c) / exp(c).sum(axis=axis).dimshuffle("x", "x")
    elif axis == 0:
        p_y = exp(c) / exp(c).sum(axis=axis).dimshuffle("x", 0)
    elif axis == (0, 1):
        p_y = exp(c) / exp(c).sum(axis=axis).dimshuffle("x", "x")
    else:
        p_y = exp(c) / exp(c).sum(axis=axis).dimshuffle(0, "x")

    # test that function contains softmax and no div.
    f = aesara.function([c], p_y, mode=self.mode)
    assert check_stack_trace(f, ops_to_check=Softmax)
    f_ops = [n.op for n in f.maker.fgraph.toposort()]
    assert len(f_ops) == 1
    assert isinstance(f_ops[0], Softmax)

    # Numerical check against the scipy reference implementation.
    c_val = self.rng.random((3, 4)).astype(config.floatX)
    assert np.allclose(f(c_val), sp.softmax(c_val, axis=axis))
@pytest.mark.parametrize("axis", [None, 0, 1, 2, -1, -2, -3, (0, 1, 2)])
def test_basic_keepdims(self, axis):
    """keepdims sums (no explicit dimshuffle) must also become a Softmax."""
    c = tensor3()
    p_y = exp(c) / exp(c).sum(axis=axis, keepdims=True)

    # test that function contains softmax and no div.
    f = aesara.function([c], p_y, mode=self.mode)
    assert check_stack_trace(f, ops_to_check=Softmax)
    f_ops = [n.op for n in f.maker.fgraph.toposort()]
    assert len(f_ops) == 1
    assert isinstance(f_ops[0], Softmax)

    # Numerical check against the scipy reference implementation.
    c_val = self.rng.random((3, 4, 5)).astype(config.floatX)
    assert np.allclose(f(c_val), sp.softmax(c_val, axis=axis))
@pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_grad(self):
......@@ -1076,39 +1087,83 @@ class TestSoftmaxOpt:
# test that function contains softmax and softmaxgrad
w = matrix()
g = aesara.function([c, w], grad((p_y * w).sum(), c))
g = aesara.function([c, w], grad((p_y * w).sum(), c), mode=self.mode)
g_ops = [n.op for n in g.maker.fgraph.toposort()]
assert len(g_ops) == 2
assert softmax_legacy in g_ops
assert softmax_grad_legacy in g_ops
assert len(g_ops) == 2, g_ops
assert isinstance(g_ops[0], Softmax)
assert isinstance(g_ops[1], SoftmaxGrad)
g(self.rng.random((3, 4)), self.rng.uniform(0.5, 1, (3, 4)))
@pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_transpose_basic(self):
    # this should be a transposed softmax
    c = matrix()
    p_y = exp(c) / exp(c).sum(axis=0)

    # test that function contains softmax and no div.
    f = aesara.function([c], p_y, mode=self.mode)
    f_ops = [n.op for n in f.maker.fgraph.toposort()]
    assert len(f_ops) == 1
    assert isinstance(f_ops[0], Softmax)
@pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_transpose_grad(self):
    # this should be a transposed softmax
    c = matrix()
    p_y = exp(c) / exp(c).sum(axis=0)

    # test that the gradient graph contains softmax and softmax-grad only.
    g = aesara.function([c], grad(p_y.sum(), c), mode=self.mode)
    g_ops = [n.op for n in g.maker.fgraph.toposort()]
    assert len(g_ops) == 2
    assert isinstance(g_ops[0], Softmax)
    assert isinstance(g_ops[1], SoftmaxGrad)
@pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_1d_basic(self):
    # this should be a softmax, but of a one-row matrix
    c = vector()
    p_y = exp(c) / exp(c).sum()

    # test that function contains softmax and no div.
    f = aesara.function([c], p_y, mode=self.mode)
    f_ops = [n.op for n in f.maker.fgraph.toposort()]
    assert len(f_ops) == 1
    assert isinstance(f_ops[0], Softmax)
@pytest.mark.skip(reason="Optimization not enabled for the moment")
def test_1D_grad(self):
    c = vector()
    p_y = exp(c) / exp(c).sum()

    # test that the gradient graph contains softmax and softmax-grad only.
    g = aesara.function([c], grad(p_y.sum(), c), mode=self.mode)
    g_ops = [n.op for n in g.maker.fgraph.toposort()]
    assert len(g_ops) == 2
    assert isinstance(g_ops[0], Softmax)
    assert isinstance(g_ops[1], SoftmaxGrad)
@pytest.mark.parametrize(
    "f",
    [
        lambda c: exp(c) / exp(c).sum(axis=0).dimshuffle(0, 1, "x"),
        lambda c: exp(c) / exp(c).sum(axis=0).dimshuffle("x", 0, 1, "x"),
        lambda c: exp(c) / exp(c).sum(axis=0).dimshuffle("x", 1, 0),
        lambda c: exp(c) / exp(c).sum(axis=(0, 1), keepdims=True),
    ],
)
def test_invalid_softmax_expressions(self, f):
    # Test that graphs are not rewritten into a softmax when a dimshuffle
    # swaps or adds extra dimensions, or when more than one but not all axis
    # are summed over (which is not allowed by the Softmax Op but otherwise
    # valid)
    c = tensor3("c")
    out = f(c)
    # Use a distinct name for the compiled function so it does not shadow
    # the `f` graph-builder parameter.
    fn = aesara.function([c], out, mode=self.mode)
    fn_ops = [n.op for n in fn.maker.fgraph.toposort()]
    assert len(fn_ops) > 1
    assert not any(isinstance(op, Softmax) for op in fn_ops)
def test_softmax_graph():
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment