提交 4a2b55eb authored 作者: Harm de Vries

Replace op with graph, added test for testing 2nd derivative

上级 9f000926
...@@ -738,7 +738,7 @@ class T_Scan(unittest.TestCase): ...@@ -738,7 +738,7 @@ class T_Scan(unittest.TestCase):
def forward_scanner(x_t):
    a2_t = tensor.dot(x_t, W)
    y_t = tensor.nnet.softmax_graph(a2_t)
    return y_t

y, _ = theano.scan(fn=forward_scanner, sequences=x,
......
...@@ -570,7 +570,7 @@ class Softmax(gof.Op): ...@@ -570,7 +570,7 @@ class Softmax(gof.Op):
softmax_op = Softmax()

def softmax_graph(c):
    return tensor.exp(c) / tensor.exp(c).sum(axis=-1, keepdims=True)
@opt.register_specialize('fast_compile_gpu')
...@@ -666,7 +666,7 @@ def softmax_simplifier(numerators, denominators): ...@@ -666,7 +666,7 @@ def softmax_simplifier(numerators, denominators):
if matching_denom:
    numerators.remove(numerator)
    denominators.remove(matching_denom)
    numerators.append(softmax_op(x))
return numerators, denominators

opt.local_mul_canonizer.add_simplifier(softmax_simplifier,
                                       'softmax_simplifier')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论