提交 3e84f926 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix a bug where we pushed down the max to the input of softmax, softplus,…

Fix a bug where we pushed down the max to the input of softmax, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid, and softmax_with_bias. Add a test for this bug and print a warning by default if we encounter the case where the bug was causing problems. The warning can be disabled with the Theano flag warn.argmax_pushdown_bug.
上级 f0d4b93d
...@@ -68,3 +68,12 @@ AddConfigVar('gpuelemwise.sync', ...@@ -68,3 +68,12 @@ AddConfigVar('gpuelemwise.sync',
AddConfigVar('traceback.limit', AddConfigVar('traceback.limit',
"The number of stack to trace. -1 mean all.", "The number of stack to trace. -1 mean all.",
IntParam(5)) IntParam(5))
###
### To disable some warning about old bug that are fixed now.
###
AddConfigVar('warn.argmax_pushdown_bug',
"Warn if in past version of Theano we generated a bug with the optimisation theano.tensor.nnet.nnet.local_argmax_pushdown optimization. Was fixed 27 may 2010",
BoolParam(True))
...@@ -2,8 +2,10 @@ ...@@ -2,8 +2,10 @@
:note: TODO: factor this out into a neural-network toolbox. :note: TODO: factor this out into a neural-network toolbox.
""" """
import logging
import numpy import numpy
import theano
from theano import gof from theano import gof
from theano import printing from theano import printing
from theano.tensor import basic as tensor from theano.tensor import basic as tensor
...@@ -872,7 +874,15 @@ opt.register_specialize(local_crossentropy_to_crossentropy_with_softmax_grad) ...@@ -872,7 +874,15 @@ opt.register_specialize(local_crossentropy_to_crossentropy_with_softmax_grad)
@opt.register_specialize @opt.register_specialize
@gof.local_optimizer([tensor._max_and_argmax]) @gof.local_optimizer([tensor._max_and_argmax])
def local_argmax_pushdown(node): def local_argmax_pushdown(node):
if node.op == tensor._max_and_argmax: if node.op == tensor._max_and_argmax and node.inputs[0].owner and \
len(node.outputs[0].clients)>0 and node.inputs[0].owner.op in \
(softmax, softplus, tensor.exp, tensor.log, tensor.tanh, sigmoid,
softmax_with_bias):
if theano.config.warn.argmax_pushdown_bug:
logging.getLogger('theano.tensor.nnet.nnet').warn("WARNING: their was a bug in Theano fixed the 27 may 2010 in this case. I.E. when we take the max of a softplus, softmax, exp, log, tanh, sigmoid, softmax_with_bias op, we where doing the max of the parent of the input. To remove this warning set the Theano flags 'warn.argmax_pushdown_bug' to False")
if node.op == tensor._max_and_argmax and node.inputs[0].owner and len(node.outputs[0].clients)==0:
x_max, x_argmax = node.outputs x_max, x_argmax = node.outputs
x, axis = node.inputs x, axis = node.inputs
#TODO: Make a list/set of monotonic ops... #TODO: Make a list/set of monotonic ops...
......
...@@ -773,9 +773,11 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase): ...@@ -773,9 +773,11 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
def test_argmax_pushdown(): def test_argmax_pushdown():
x = tensor.dmatrix() x = tensor.dmatrix()
#test that the max_and_argmax is pushed down if the max is not used
out = tensor.max_and_argmax(softmax(tensor.exp(tensor.tanh(sigmoid(x)))))[1]
env = gof.Env( env = gof.Env(
[x], [x],
[tensor.argmax(softmax(tensor.exp(tensor.tanh(sigmoid(x)))))]) [out])
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env) theano.compile.mode.OPT_FAST_RUN).optimize(env)
...@@ -785,27 +787,67 @@ def test_argmax_pushdown(): ...@@ -785,27 +787,67 @@ def test_argmax_pushdown():
#print node.op #print node.op
assert len(env.toposort()) == 2 # an output_guard is second assert len(env.toposort()) == 2 # an output_guard is second
assert env.toposort()[0].op == tensor._max_and_argmax assert env.toposort()[0].op == tensor._max_and_argmax
assert str(env.toposort()[1].op) == 'OutputGuard'
x = tensor.dmatrix()
#test that the max_and_argmax is not pushed down if the max is used
out = tensor.max_and_argmax(softmax(tensor.exp(tensor.tanh(sigmoid(x)))))[0]
env = gof.Env(
[x],
[out])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env)
#print 'AFTER'
#for node in env.toposort():
#print node.op
assert len(env.toposort()) == 4 # an output_guard is second
assert isinstance(env.toposort()[0].op, tensor.Elemwise)
assert isinstance(env.toposort()[1].op, Softmax)
assert isinstance(env.toposort()[2].op, tensor.MaxAndArgmax)
assert str(env.toposort()[3].op) == 'OutputGuard'
def test_argmax_pushdown_bias(): def test_argmax_pushdown_bias():
x = tensor.dmatrix() x = tensor.dmatrix()
b = tensor.dvector() b = tensor.dvector()
out = tensor.argmax(softmax_with_bias(x, b))
env = gof.Env( env = gof.Env(
[x,b], [x,b],
[tensor.argmax(softmax_with_bias(x, b))]) [out])
theano.compile.mode.optdb.query( theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env) theano.compile.mode.OPT_FAST_RUN).optimize(env)
print 'AFTER' #print 'AFTER'
for node in env.toposort(): #for node in env.toposort():
print node.op # print node.op
assert len(env.toposort()) == 4 assert len(env.toposort()) == 4
assert isinstance(env.toposort()[0].op, tensor.DimShuffle) assert isinstance(env.toposort()[0].op, tensor.DimShuffle)
assert isinstance(env.toposort()[1].op, tensor.Elemwise) assert isinstance(env.toposort()[1].op, tensor.Elemwise)
assert isinstance(env.toposort()[2].op, tensor.MaxAndArgmax) assert isinstance(env.toposort()[2].op, tensor.MaxAndArgmax)
assert str(env.toposort()[3].op) == 'OutputGuard' assert str(env.toposort()[3].op) == 'OutputGuard'
x = tensor.dmatrix()
b = tensor.dvector()
out = tensor.max_and_argmax(softmax_with_bias(x, b))[0]
env = gof.Env(
[x,b],
[out])
theano.compile.mode.optdb.query(
theano.compile.mode.OPT_FAST_RUN).optimize(env)
#print 'AFTER'
#for node in env.toposort():
# print node.op
assert len(env.toposort()) == 3
assert isinstance(env.toposort()[0].op, SoftmaxWithBias)
assert isinstance(env.toposort()[1].op, tensor.MaxAndArgmax)
assert str(env.toposort()[2].op) == 'OutputGuard'
def test_asymptotic_32(): def test_asymptotic_32():
""" """
This test makes sure that our functions behave sensibly when huge values are present This test makes sure that our functions behave sensibly when huge values are present
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论