提交 9f000926 authored 作者: Harm de Vries's avatar Harm de Vries

Replaced softmax with either softmax_op or softmax_graph

上级 c6ccaeeb
...@@ -413,7 +413,7 @@ class Softmax(gof.Op): ...@@ -413,7 +413,7 @@ class Softmax(gof.Op):
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
g_sm, = grads g_sm, = grads
sm = softmax(x) sm = softmax_op(x)
return [softmax_grad(g_sm, sm)] return [softmax_grad(g_sm, sm)]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -578,7 +578,7 @@ def softmax_graph(c): ...@@ -578,7 +578,7 @@ def softmax_graph(c):
def local_softmax_with_bias(node): def local_softmax_with_bias(node):
"""Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias) """Try to turn softmax(sum_of_stuff) -> softmax_w_bias(matrix, bias)
""" """
if node.op == softmax: if node.op == softmax_op:
x, = node.inputs x, = node.inputs
if x.owner and x.owner.op == tensor.add: if x.owner and x.owner.op == tensor.add:
vectors = [] vectors = []
...@@ -1406,7 +1406,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph): ...@@ -1406,7 +1406,7 @@ def crossentropy_to_crossentropy_with_softmax(fgraph):
if node.op == crossentropy_categorical_1hot: if node.op == crossentropy_categorical_1hot:
nll, = node.outputs nll, = node.outputs
sm, one_of_n = node.inputs sm, one_of_n = node.inputs
if sm.owner and sm.owner.op == softmax: if sm.owner and sm.owner.op == softmax_op:
x, = sm.owner.inputs x, = sm.owner.inputs
new_nll, new_sm, new_am = crossentropy_softmax_argmax_1hot_with_bias(x, new_nll, new_sm, new_am = crossentropy_softmax_argmax_1hot_with_bias(x,
tensor.zeros_like(x[0]), one_of_n) tensor.zeros_like(x[0]), one_of_n)
...@@ -1556,7 +1556,7 @@ def local_advanced_indexing_crossentropy_onehot(node): ...@@ -1556,7 +1556,7 @@ def local_advanced_indexing_crossentropy_onehot(node):
except Exception: except Exception:
pass pass
if sm is not None and sm.owner and sm.owner.op in (softmax, if sm is not None and sm.owner and sm.owner.op in (softmax_op,
softmax_with_bias): softmax_with_bias):
sm_w_bias = local_softmax_with_bias.transform(sm.owner) sm_w_bias = local_softmax_with_bias.transform(sm.owner)
if sm_w_bias: if sm_w_bias:
...@@ -1586,7 +1586,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node): ...@@ -1586,7 +1586,7 @@ def local_advanced_indexing_crossentropy_onehot_grad(node):
except Exception: except Exception:
return return
if (sm is not None) and sm.owner and (sm.owner.op in (softmax, if (sm is not None) and sm.owner and (sm.owner.op in (softmax_op,
softmax_with_bias)): softmax_with_bias)):
sm_w_bias = local_softmax_with_bias.transform(sm.owner) sm_w_bias = local_softmax_with_bias.transform(sm.owner)
if sm_w_bias: if sm_w_bias:
...@@ -2056,7 +2056,7 @@ def make_out_pattern(X): ...@@ -2056,7 +2056,7 @@ def make_out_pattern(X):
return out_var return out_var
local_log_softmax = gof.PatternSub(in_pattern=(tensor.log, (softmax, 'x')), local_log_softmax = gof.PatternSub(in_pattern=(tensor.log, (softmax_op, 'x')),
out_pattern=(make_out_pattern, 'x'), out_pattern=(make_out_pattern, 'x'),
allow_multiple_clients=True) allow_multiple_clients=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论