Commit e4813551 authored by Francesco Visin

Fix: use crossentropy_onehot_grad instead of logsoftmax

Do not optimize LogSoftmax when local_advanced_indexing_crossentropy_onehot_grad can be applied
Parent 49cf5b43
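
For context, the graph pattern at stake can be sketched as follows. This is a minimal sketch assuming Theano's standard tensor API; the names x, y, nll, and g are illustrative and not part of this commit:

    import theano
    import theano.tensor as T

    x = T.matrix('x')   # class scores, one row per example
    y = T.ivector('y')  # integer class labels

    # Cross-entropy expressed via advanced indexing into the softmax output.
    # Its gradient graph contains a true_div whose numerator comes from an
    # AdvancedIncSubtensor -- the shape the new not(...) guard below detects
    # and skips, leaving the rewrite to
    # local_advanced_indexing_crossentropy_onehot_grad.
    nll = -T.log(T.nnet.softmax(x))[T.arange(y.shape[0]), y].mean()
    g = theano.grad(nll, x)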
@@ -678,7 +678,15 @@ def local_logsoftmax_grad(node):
         len(node.inputs) == 2 and
         isinstance(node.inputs[0].owner.op, tensor.Elemwise) and
         node.inputs[0].owner.inputs[1].owner.op == Softmax() and
-        node.inputs[1] == node.inputs[0].owner.inputs[1]):
+        node.inputs[1] == node.inputs[0].owner.inputs[1]) and not(
+            # skip if it will be optimized by
+            # local_advanced_indexing_crossentropy_onehot_grad
+            node.inputs[1].owner and node.inputs[1].owner.op in
+            (softmax_op, softmax_with_bias) and
+            node.inputs[0].owner.op == tensor.true_div and
+            node.inputs[0].owner.inputs[1] == node.inputs[1] and
+            isinstance(node.inputs[0].owner.inputs[0].owner.op,
+                       subtensor.AdvancedIncSubtensor)):
         # get parameters from unoptimized op
         sm = node.inputs[0].owner.inputs[1]
         # sm_input = node.inputs[1].owner.inputs[0]
@@ -1568,7 +1576,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
             sm = log.owner.inputs[0]
     # Second case: log(softmax(x)[rows, labels])
-    if node.op == tensor.log:
+    elif node.op == tensor.log:
         pre_log = node.inputs[0].owner
         if pre_log and isinstance(pre_log.op, subtensor.AdvancedSubtensor):
             try:
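
The second hunk's if-to-elif change reflects that local_advanced_indexing_crossentropy_onehot matches two distinct graph shapes. A minimal sketch of the two patterns, using the usual Theano idioms (rows is illustrative, reusing x and y from the sketch above):

    rows = T.arange(y.shape[0])
    case_1 = T.log(T.nnet.softmax(x))[rows, y]  # first case: index after log
    case_2 = T.log(T.nnet.softmax(x)[rows, y])  # second case: log after indexing

A node cannot match both patterns at once (it is either an AdvancedSubtensor or a log), so the elif makes the mutual exclusion explicit.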