Commit 8b447002 authored by fvisin, committed by Francesco Visin

Add optimization for grad

Parent f9eb767b
@@ -652,15 +652,43 @@ def local_logsoftmax(node):
     Note: only forward pass is affected
     """
-    if (isinstance(node.op, tensor.Elemwise) and
-            isinstance(node.op.scalar_op, scalar.basic.Log) and
-            len(node.inputs) == 1 and
-            node.inputs[0].owner and
-            isinstance(node.inputs[0].owner.op, Softmax)):
-        # what is --> and len(node.inputs[0].owner.out.clients) == 1):
-        inVars = node.inputs[0].owner.inputs[0]
-        new_op = LogSoftmax()
-        return [new_op(inVars)]
+    try:
+        if (isinstance(node.op, tensor.Elemwise) and
+                isinstance(node.op.scalar_op, scalar.basic.Log) and
+                len(node.inputs) == 1 and
+                node.inputs[0].owner and
+                isinstance(node.inputs[0].owner.op, Softmax)):
+            inVars = node.inputs[0].owner.inputs[0]
+            new_op = LogSoftmax()
+            return [new_op(inVars)]
+    except AttributeError:
+        pass
+
+
+@opt.register_specialize('stabilize')
+@gof.local_optimizer([SoftmaxGrad])
+def local_logsoftmax_grad(node):
+    """
+    Detect Log(Softmax(x))'s grad and replace it with LogSoftmax(x)'s grad
+
+    Note: only grad is affected
+    """
+    try:
+        if (isinstance(node.op, SoftmaxGrad) and
+                len(node.inputs) == 2 and
+                isinstance(node.inputs[0].owner.op, tensor.Elemwise) and
+                node.inputs[0].owner.inputs[1].owner.op == Softmax() and
+                node.inputs[1] == node.inputs[0].owner.inputs[1]):
+            # get parameters from unoptimized op
+            sm = node.inputs[0].owner.inputs[1]
+            # sm_input = node.inputs[1].owner.inputs[0]
+            grads = node.inputs[0].owner.inputs[0]
+            if grads.broadcastable[1] and not sm.broadcastable[1]:
+                grads = tensor.alloc(grads, grads.shape[0], sm.shape[1])
+            return [grads - tensor.sum(grads, axis=1, keepdims=True) * sm]
+    except AttributeError:
+        pass
+
+
 def softmax_graph(c):
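
Both optimizers register under the 'stabilize' tag, so they run during Theano's default graph optimization. Below is a minimal sketch (not part of the commit; variable names are illustrative) of the graphs they target: with the optimizations applied, the forward expression should compile to a LogSoftmax node, and its gradient to the stabilized expression returned by local_logsoftmax_grad instead of the generic SoftmaxGrad graph.

import theano
import theano.tensor as T

x = T.matrix('x')
y = T.log(T.nnet.softmax(x))   # forward pattern matched by local_logsoftmax
gx = theano.grad(y.sum(), x)   # grad pattern matched by local_logsoftmax_grad

# Compile and inspect the optimized graph; the log/softmax and SoftmaxGrad
# nodes should have been rewritten by the two local optimizers above.
f = theano.function([x], [y, gx])
theano.printing.debugprint(f)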
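
For reference, the closed form on the return line of local_logsoftmax_grad follows from y = log(softmax(x)): given an upstream gradient g = dL/dy, the gradient with respect to x is g - softmax(x) * sum_j g_j, with the sum taken row-wise. A small NumPy check of that identity (a sketch with illustrative names, not part of the commit):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(3, 4)
g = rng.randn(3, 4)               # arbitrary upstream gradient dL/dy


def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def loss(z):
    # L = sum(g * log(softmax(z))), so dL/dy is exactly g
    return np.sum(g * np.log(softmax(z)))


# Closed form, matching "grads - tensor.sum(grads, axis=1, keepdims=True) * sm"
analytic = g - softmax(x) * g.sum(axis=1, keepdims=True)

# Central finite differences as an independent check
numeric = np.zeros_like(x)
eps = 1e-6
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        d = np.zeros_like(x)
        d[i, j] = eps
        numeric[i, j] = (loss(x + d) - loss(x - d)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)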