Commit 24f6a4e5 authored by Harm de Vries

fix in local_argmax_pushdown

Parent 83ea3e8f
@@ -86,8 +86,8 @@ class SoftmaxWithBias(gof.Op):
         x_plus_b = x + b[None, :]
         e_x = numpy.exp(x_plus_b - x_plus_b.max(axis=1)[:, None])
-        sm = e_x / e_x.sum(axis=1)[:, None]
-        output_storage[0][0] = sm
+        e_x *= 1.0 / e_x.sum(axis=1)[:, None]
+        output_storage[0][0] = e_x
 
     def grad(self, inp, grads):
         x, b = inp
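(Note: a minimal standalone NumPy sketch, not part of the commit, of what the updated perform() body computes. The change normalizes e_x in place instead of allocating a separate sm array; the row-max shift for numerical stability is unchanged.)

import numpy

def softmax_with_bias(x, b):
    # Shift by the row max so exp() cannot overflow, then normalize
    # e_x in place rather than allocating a second output array.
    x_plus_b = x + b[None, :]
    e_x = numpy.exp(x_plus_b - x_plus_b.max(axis=1)[:, None])
    e_x *= 1.0 / e_x.sum(axis=1)[:, None]
    return e_x

rng = numpy.random.RandomState(0)
out = softmax_with_bias(rng.rand(3, 4), rng.rand(4))
assert numpy.allclose(out.sum(axis=1), 1.0)  # each row sums to 1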
@@ -1469,7 +1469,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(node):
 def local_argmax_pushdown(node):
     if node.op == tensor._max_and_argmax and node.inputs[0].owner and \
             len(node.outputs[0].clients) > 0 and node.inputs[0].owner.op in \
-            (softmax, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid,
+            (softmax_op, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid,
              softmax_with_bias):
         if theano.config.warn.argmax_pushdown_bug:
             logging.getLogger('theano.tensor.nnet.nnet').warn("WARNING: there "
@@ -1485,7 +1485,7 @@ def local_argmax_pushdown(node):
         x_max, x_argmax = node.outputs
         x, axis = node.inputs
         # TODO: Make a list/set of monotonic ops...
-        if x.owner and x.owner.op in (softmax, softplus, tensor.exp,
+        if x.owner and x.owner.op in (softmax_op, softplus, tensor.exp,
                                       tensor.log, tensor.tanh, sigmoid):
             pre_x, = x.owner.inputs
             return tensor._max_and_argmax(pre_x, axis)
......
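(Why the pushdown is safe, as a sketch that is not part of the commit: every op in the tuple preserves the within-row ordering of its input. softplus, exp, log, tanh and sigmoid are monotonically increasing elementwise, and softmax divides each row's exponentials by a positive per-row constant, so argmax(f(x), axis) == argmax(x, axis). The swap from softmax to softmax_op presumably exists because node.op must be compared against the Op instance, not a graph-building helper.)

import numpy

def softmax(x):
    e_x = numpy.exp(x - x.max(axis=1)[:, None])
    return e_x / e_x.sum(axis=1)[:, None]

x = numpy.random.RandomState(42).randn(5, 10)
# Monotonic transforms leave the argmax untouched, which is exactly
# what lets max_and_argmax be moved below them in the graph.
assert (softmax(x).argmax(axis=1) == x.argmax(axis=1)).all()
assert (numpy.tanh(x).argmax(axis=1) == x.argmax(axis=1)).all()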
@@ -1011,7 +1011,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
             try:
                 g = theano.function([x, b, y], T.grad(expr, x), mode=mode)
             finally:
-                config.warn.sum_div_dimshuffle_qbug = backup
+                config.warn.sum_div_dimshuffle_bug = backup
             if verbose:
                 printing.debugprint(g)
@@ -1026,7 +1026,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
                 theano.printing.debugprint(g)
                 raise
 
-    def test_scrossentropy_softmax_1hot_with_bias_dxcale_cost(self):
+    def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
         # TODO: add the optimization in FAST_COMPILE?
         # In the mean time, run it as 'FAST_RUN' instead
         mode = theano.compile.mode.get_default_mode()
@@ -1130,7 +1130,7 @@ def test_argmax_pushdown():
     # test that the max_and_argmax is pushed down if the max is not used
     out = tensor.max_and_argmax(
-        softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
+        softmax_graph(tensor.exp(tensor.tanh(sigmoid(x)))),
         axis=-1)[1]
     fgraph = gof.FunctionGraph(
         [x],
@@ -1147,7 +1147,7 @@ def test_argmax_pushdown():
     x = tensor.matrix()
     # test that the max_and_argmax is not pushed down if the max is used
     out = tensor.max_and_argmax(
-        softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
+        softmax_op(tensor.exp(tensor.tanh(sigmoid(x)))),
         axis=-1)[0]
     fgraph = gof.FunctionGraph(
         [x],
......
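(A note on the two test cases, with a small numeric illustration that is not part of the commit: indexing the max_and_argmax result with [1] keeps only the argmax output, which softmax leaves unchanged, so the pushdown may fire; indexing with [0] keeps the max value, which softmax does change, so the pushdown must not fire there.)

import numpy

x = numpy.array([[0.0, 3.0]])
sm = numpy.exp(x) / numpy.exp(x).sum()
assert sm.argmax() == x.argmax()             # argmax survives: both are 1
assert not numpy.isclose(sm.max(), x.max())  # max does not: ~0.95 vs 3.0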