提交 ea627155 authored 作者: James Bergstra's avatar James Bergstra

nnet opt - added local optimizer to merge softmax with crossentropy_softmax

上级 ec9ece68
......@@ -1251,6 +1251,18 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
else:
return
@opt.register_specialize
@gof.local_optimizer([softmax_with_bias])
def graph_merge_softmax_with_crossentropy_softmax(node):
if node.op == softmax_with_bias:
x, b = node.inputs
for x_client in x.clients:
if x_client[0].op == crossentropy_softmax_argmax_1hot_with_bias:
big_client = x_client[0]
if big_client in [b_client[0] for b_client in b.clients]:
xx, bb, ll = big_client.inputs
mergeable_client = big_client.op(x, b, ll)
return [mergeable_client[1]]
def binary_crossentropy(output, target):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论