提交 bafd9638 authored 作者: Ricardo 提交者: Brandon T. Willard

Make local_logsoftmax_grad rewrite work with arbitrary axis

上级 57c388a7
...@@ -1098,8 +1098,7 @@ def local_logsoftmax_grad(fgraph, node): ...@@ -1098,8 +1098,7 @@ def local_logsoftmax_grad(fgraph, node):
and node.inputs[0].owner.op == true_div and node.inputs[0].owner.op == true_div
and len(node.inputs[0].owner.inputs) >= 2 and len(node.inputs[0].owner.inputs) >= 2
and node.inputs[0].owner.inputs[1].owner is not None and node.inputs[0].owner.inputs[1].owner is not None
and node.inputs[0].owner.inputs[1].owner.op == softmax_legacy and isinstance(node.inputs[0].owner.inputs[1].owner.op, Softmax)
and node.inputs[0].owner.inputs[1].ndim == 2
and node.inputs[1] == node.inputs[0].owner.inputs[1] and node.inputs[1] == node.inputs[0].owner.inputs[1]
and not ( and not (
# skip if it will be optimized by # skip if it will be optimized by
...@@ -1109,15 +1108,14 @@ def local_logsoftmax_grad(fgraph, node): ...@@ -1109,15 +1108,14 @@ def local_logsoftmax_grad(fgraph, node):
and isinstance( and isinstance(
node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor node.inputs[0].owner.inputs[0].owner.op, AdvancedIncSubtensor
) )
# the rewrite only applies to legacy SoftmaxGrad
and node.op == softmax_grad_legacy
and node.inputs[0].owner.inputs[1].ndim == 2
) )
): ):
# get parameters from unoptimized op # get parameters from unoptimized op
sm = node.inputs[0].owner.inputs[1] grads, sm = node.inputs[0].owner.inputs
# sm_input = node.inputs[1].owner.inputs[0] ret = grads - aet_sum(grads, axis=sm.owner.op.axis, keepdims=True) * sm
grads = node.inputs[0].owner.inputs[0]
if grads.broadcastable[1] and not sm.broadcastable[1]:
grads = aet.alloc(grads, grads.shape[0], sm.shape[1])
ret = grads - aet_sum(grads, axis=1, keepdims=True) * sm
ret.tag.values_eq_approx = values_eq_approx_remove_nan ret.tag.values_eq_approx = values_eq_approx_remove_nan
copy_stack_trace(node.outputs[0], ret) copy_stack_trace(node.outputs[0], ret)
return [ret] return [ret]
......
...@@ -278,7 +278,8 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -278,7 +278,8 @@ class TestLogSoftmax(utt.InferShapeTester):
assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax) assert isinstance(f.maker.fgraph.outputs[0].owner.op, LogSoftmax)
assert check_stack_trace(f, ops_to_check=LogSoftmax) assert check_stack_trace(f, ops_to_check=LogSoftmax)
def test_local_softmax_grad_optimization_and_big_input(self): @pytest.mark.parametrize("axis", [None, 0, -1])
def test_local_logsoftmax_grad_opt(self, axis):
# Test the Logsoftmax's grad substitution. # Test the Logsoftmax's grad substitution.
# #
# Check that Log(Softmax(x))'s grad is substituted with Logsoftmax(x)'s # Check that Log(Softmax(x))'s grad is substituted with Logsoftmax(x)'s
...@@ -294,7 +295,7 @@ class TestLogSoftmax(utt.InferShapeTester): ...@@ -294,7 +295,7 @@ class TestLogSoftmax(utt.InferShapeTester):
a = np.exp(10 * rng.random((5, 10)).astype(config.floatX)) a = np.exp(10 * rng.random((5, 10)).astype(config.floatX))
def myfunc(x): def myfunc(x):
sm = softmax(x) sm = softmax(x, axis=axis)
logsm = log(sm) logsm = log(sm)
return logsm return logsm
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论