提交 24f6a4e5 authored 作者: Harm de Vries's avatar Harm de Vries

fix in local_argmax_pushdown

上级 83ea3e8f
......@@ -86,8 +86,8 @@ class SoftmaxWithBias(gof.Op):
x_plus_b = x + b[None, :]
e_x = numpy.exp(x_plus_b - x_plus_b.max(axis=1)[:, None])
sm = e_x / e_x.sum(axis=1)[:, None]
output_storage[0][0] = sm
e_x *= 1.0 / e_x.sum(axis=1)[:, None]
output_storage[0][0] = e_x
def grad(self, inp, grads):
x, b = inp
......@@ -1469,7 +1469,7 @@ def local_softmax_grad_to_crossentropy_with_softmax_grad(node):
def local_argmax_pushdown(node):
if node.op == tensor._max_and_argmax and node.inputs[0].owner and \
len(node.outputs[0].clients) > 0 and node.inputs[0].owner.op in \
(softmax, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid,
(softmax_op, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid,
softmax_with_bias):
if theano.config.warn.argmax_pushdown_bug:
logging.getLogger('theano.tensor.nnet.nnet').warn("WARNING: there "
......@@ -1485,7 +1485,7 @@ def local_argmax_pushdown(node):
x_max, x_argmax = node.outputs
x, axis = node.inputs
# TODO: Make a list/set of monotonic ops...
if x.owner and x.owner.op in (softmax, softplus, tensor.exp,
if x.owner and x.owner.op in (softmax_op, softplus, tensor.exp,
tensor.log, tensor.tanh, sigmoid):
pre_x, = x.owner.inputs
return tensor._max_and_argmax(pre_x, axis)
......
......@@ -1011,7 +1011,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
try:
g = theano.function([x, b, y], T.grad(expr, x), mode=mode)
finally:
config.warn.sum_div_dimshuffle_qbug = backup
config.warn.sum_div_dimshuffle_bug = backup
if verbose:
printing.debugprint(g)
......@@ -1026,7 +1026,7 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
theano.printing.debugprint(g)
raise
def test_scrossentropy_softmax_1hot_with_bias_dxcale_cost(self):
def test_crossentropy_softmax_1hot_with_bias_dxcale_cost(self):
# TODO: add the optimization in FAST_COMPILE?
# In the mean time, run it as 'FAST_RUN' instead
mode = theano.compile.mode.get_default_mode()
......@@ -1130,7 +1130,7 @@ def test_argmax_pushdown():
# test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
softmax_graph(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[1]
fgraph = gof.FunctionGraph(
[x],
......@@ -1147,7 +1147,7 @@ def test_argmax_pushdown():
x = tensor.matrix()
# test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(
softmax(tensor.exp(tensor.tanh(sigmoid(x)))),
softmax_op(tensor.exp(tensor.tanh(sigmoid(x)))),
axis=-1)[0]
fgraph = gof.FunctionGraph(
[x],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论