提交 06887a98 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

New 'shape_lift' optimization : shape(softmax(x)) -> shape(x)

上级 7bc4bb39
......@@ -437,6 +437,22 @@ def local_softmax_with_bias(node):
return
return [sm_bias]
@gof.local_optimizer([tensor._shape])
def local_shape_lift_softmax(node):
    """shape(softmax(x)) -> shape(x)
    shape(softmax(x+b)) -> shape(x)

    A softmax output always has the same shape as its input, so the
    (comparatively expensive) softmax never needs to be computed just
    to obtain the shape.

    :param node: an Apply node; the rewrite only fires when its op is
        ``tensor._shape`` applied to a softmax/softmax_with_bias output.
    :returns: a 1-tuple with the replacement ``shape(x)`` variable, or
        ``None`` (implicitly) when the pattern does not match.
    """
    if node.op == tensor._shape:
        sm = node.inputs[0]
        # BUG FIX: compare the Apply node's *op* against the softmax
        # ops.  The original code tested `sm.owner in (softmax,
        # softmax_with_bias)`, but `sm.owner` is an Apply instance and
        # can never equal an Op, so the optimization never triggered.
        if sm.owner and sm.owner.op in (softmax, softmax_with_bias):
            # Try to canonicalize softmax(x + b) -> softmax_with_bias(x, b)
            # so that the pre-bias input x can be recovered.
            sm_w_bias = local_softmax_with_bias.transform(sm.owner)
            if sm_w_bias:
                assert sm_w_bias[0].owner.op == softmax_with_bias
                x_var, b_var = sm_w_bias[0].owner.inputs
            else:
                x_var = sm.owner.inputs[0]
            # Lift the shape computation past the softmax entirely.
            return tensor.shape(x_var),
opt.register_specialize(local_shape_lift_softmax, 'shape_lift')
class CrossentropySoftmaxArgmax1HotWithBias(gof.Op):
"""A special compound L{Op} for the output of neural-net classifiers.
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论