提交 526faff7 authored 作者: AdeB's avatar AdeB

Fix the optim to replace the gradient of log softmax which didnt check true_div explicitly

上级 ecfc65ec
...@@ -773,7 +773,7 @@ def local_logsoftmax_grad(node): ...@@ -773,7 +773,7 @@ def local_logsoftmax_grad(node):
if (isinstance(node.op, SoftmaxGrad) and if (isinstance(node.op, SoftmaxGrad) and
len(node.inputs) == 2 and len(node.inputs) == 2 and
node.inputs[0].owner is not None and node.inputs[0].owner is not None and
isinstance(node.inputs[0].owner.op, tensor.Elemwise) and node.inputs[0].owner.op == tensor.true_div and
len(node.inputs[0].owner.inputs) >= 2 and len(node.inputs[0].owner.inputs) >= 2 and
node.inputs[0].owner.inputs[1].owner is not None and node.inputs[0].owner.inputs[1].owner is not None and
node.inputs[0].owner.inputs[1].owner.op == softmax_op and node.inputs[0].owner.inputs[1].owner.op == softmax_op and
......
...@@ -287,6 +287,30 @@ class T_LogSoftmax(utt.InferShapeTester): ...@@ -287,6 +287,30 @@ class T_LogSoftmax(utt.InferShapeTester):
f = theano.function([], myfunc(sa)) f = theano.function([], myfunc(sa))
self.assertTrue(check_stack_trace(f, ops_to_check='all')) self.assertTrue(check_stack_trace(f, ops_to_check='all'))
def test_logsoftmax_grad_true_div_elemwise(self):
# Checks that the gradient of an expression similar to a log(softmax)
# but with a different elemwise operation than true_div is not
# optimized.
x = T.matrix('x')
y = T.log(T.nnet.softmax(x))
g = T.grad(y.sum(), x)
softmax_grad_node = g.owner
assert softmax_grad_node.op == softmax_grad
true_div_node = softmax_grad_node.inputs[0].owner
assert true_div_node.op == tensor.true_div
# We replace the elemwise true_div op by an elemwise add.
new_g = softmax_grad(tensor.add(*true_div_node.inputs),
softmax_grad_node.inputs[1])
fgraph = gof.FunctionGraph([x], [new_g])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(fgraph)
assert softmax_grad in [n.op for n in fgraph.toposort()]
class T_SoftmaxGrad(utt.InferShapeTester): class T_SoftmaxGrad(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论