提交 fd54945c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

That test was failing at some point, I do not remember when or why...

上级 c3b3f456
......@@ -516,6 +516,51 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
theano.printing.debugprint(g)
raise
def test_xent_thing_int32(self):
verbose = 0
mode = theano.compile.mode.get_default_mode()
if mode == theano.compile.mode.get_mode('FAST_COMPILE'):
mode = 'FAST_RUN'
rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(3,5)
b_val = rng.randn(5)
y_val = numpy.asarray([2,4,1], dtype='int64')
x = T.dmatrix('x')
b = T.dvector('b')
y = T.lvector('y')
yi = T.cast(y, 'int32')
expressions = [
T.sum(-T.log(softmax(x)[T.arange(yi.shape[0]), yi])),
-T.sum(T.log(softmax(x)[T.arange(yi.shape[0]), yi])),
-T.sum(T.log(softmax(x))[T.arange(yi.shape[0]), yi]),
T.sum(-T.log(softmax(x))[T.arange(yi.shape[0]), yi])
]
for expr in expressions:
# Verify the optimizer worked on the expressions
f = theano.function([x,y], expr, mode=mode)
if verbose:
theano.printing.debugprint(f)
try:
assert len(f.maker.env.toposort()) == 5
f(x_val, y_val)
except:
theano.printing.debugprint(f)
raise
# Also verify the gradient wrt x
g = theano.function([x,y], T.grad(expr, x), mode=mode)
if verbose:
theano.printing.debugprint(g)
try:
assert len(g.maker.env.toposort()) == 5
g(x_val, y_val)
except:
theano.printing.debugprint(g)
raise
## Test that a biased softmax is optimized correctly
bias_expressions = [
......@@ -526,7 +571,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
for expr in bias_expressions:
f = theano.function([x,b,y], expr, mode=mode)
if verbose: print_graph(f)
if verbose:
theano.printing.debugprint(f)
try:
assert len(f.maker.env.toposort()) == 2 # [big_op, sum]
f(x_val, b_val, y_val)
......@@ -535,7 +581,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
raise
g = theano.function([x,b,y], T.grad(expr, x), mode=mode)
if verbose: print_graph(g)
if verbose:
theano.printing.debugprint(g)
try:
assert len(g.maker.env.toposort()) == 4
g(x_val, b_val, y_val)
......@@ -552,7 +599,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
for expr in mean_expressions:
f = theano.function([x,y], expr, mode=mode)
if verbose: print_graph(f)
if verbose:
theano.printing.debugprint(f)
try:
assert len(f.maker.env.toposort()) == 6
f(x_val, y_val)
......@@ -561,7 +609,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
raise
g = theano.function([x,y], T.grad(expr, x), mode=mode)
if verbose: print_graph(g)
if verbose:
theano.printing.debugprint(g)
try:
assert len(g.maker.env.toposort()) in (6,7) #there's an extra dimshuffle in there
# but I can't think of a good rule to get rid of it
......@@ -578,7 +627,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
for expr in mean_bias_expressions:
f = theano.function([x,b,y], expr, mode=mode)
if verbose: print_graph(f)
if verbose:
theano.printing.debugprint(f)
try:
assert len(f.maker.env.toposort()) == 4
except:
......@@ -586,7 +636,8 @@ class T_CrossentropyCategorical1Hot(unittest.TestCase):
raise
g = theano.function([x,b,y], T.grad(expr, x), mode=mode)
if verbose: print_graph(g)
if verbose:
theano.printing.debugprint(g)
try:
assert len(g.maker.env.toposort()) in (6,7)
g(x_val, b_val, y_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论