提交 9465d81a 作者: nouiz

Merge pull request #1039 from lamblin/test_softmaxgrad_flatten

Add test for the "flatten" vector case
...@@ -562,8 +562,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -562,8 +562,6 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
def test_get_rid_of_advanced_indexing_version_of_xent(self): def test_get_rid_of_advanced_indexing_version_of_xent(self):
verbose = 0 verbose = 0
if verbose:
from theano.printing import pprint
# TODO: add the optimization in FAST_COMPILE? # TODO: add the optimization in FAST_COMPILE?
# In the mean time, run it as 'FAST_RUN' instead # In the mean time, run it as 'FAST_RUN' instead
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
...@@ -591,7 +589,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -591,7 +589,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 4 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -602,7 +604,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -602,7 +604,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) == 4 ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops
assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -620,7 +626,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -620,7 +626,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 2 # [big_op, sum] ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 2 # [big_op, sum]
assert crossentropy_softmax_argmax_1hot_with_bias in ops
f(x_val, b_val, y_val) f(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -629,7 +637,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -629,7 +637,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) == 4 ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
g(x_val, b_val, y_val) g(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -647,7 +659,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -647,7 +659,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 6 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 6
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -657,9 +673,13 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -657,9 +673,13 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) in (6, 7) ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) in (6, 7)
#there's an extra dimshuffle in there #there's an extra dimshuffle in there
# but I can't think of a good rule to get rid of it # but I can't think of a good rule to get rid of it
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops
assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -676,7 +696,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -676,7 +696,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 4 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
raise raise
...@@ -684,7 +708,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -684,7 +708,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) in (6, 7) ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) in (6, 7)
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
g(x_val, b_val, y_val) g(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -697,10 +725,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -697,10 +725,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
mode = 'FAST_RUN' mode = 'FAST_RUN'
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(3, 5).astype(config.floatX) x_val = rng.randn(3, 5).astype(config.floatX)
b_val = rng.randn(5).astype(config.floatX)
y_val = numpy.asarray([2, 4, 1], dtype='int64') y_val = numpy.asarray([2, 4, 1], dtype='int64')
x = T.matrix('x') x = T.matrix('x')
b = T.vector('b')
y = T.lvector('y') y = T.lvector('y')
yi = T.cast(y, 'int32') yi = T.cast(y, 'int32')
expressions = [ expressions = [
...@@ -716,7 +742,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -716,7 +742,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 5 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 5
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -727,7 +757,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -727,7 +757,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) == 5 ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 5
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops
assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -762,8 +796,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -762,8 +796,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
print_graph(f) print_graph(f)
try: try:
prev, last = f.maker.fgraph.toposort()[-2:] ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(f.maker.fgraph.toposort()) == 5 assert len(ops) == 5
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -815,9 +852,80 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -815,9 +852,80 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
print_graph(f) print_graph(f)
try: try:
prev, last = f.maker.fgraph.toposort()[-2:] ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(f.maker.fgraph.toposort()) == 3
# [big_op, sum, dim_shuffle] # [big_op, sum, dim_shuffle]
assert len(ops) == 3
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, b_val, y_val)
except Exception:
theano.printing.debugprint(f)
raise
backup = config.warn.sum_div_dimshuffle_bug
config.warn.sum_div_dimshuffle_bug = False
try:
g = theano.function([x, b, y], T.grad(expr, x), mode=mode)
finally:
config.warn.sum_div_dimshuffle_bug = backup
if verbose:
print_graph(g)
try:
ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) <= 6
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
g(x_val, b_val, y_val)
except Exception:
theano.printing.debugprint(g)
raise
def test_optimize_xent_vector3(self):
# Same as test_optimize_xent_vector2, but y is the result of
# a "flatten", and it somehow makes the constant-folding
# of arange(y.shape[0]) happen before the xent optimization
verbose = 0
mode = theano.compile.mode.get_default_mode()
if mode == theano.compile.mode.get_mode('FAST_COMPILE'):
mode = 'FAST_RUN'
rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(5).astype(config.floatX)
b_val = rng.randn(5).astype(config.floatX)
y_val = numpy.asarray([2])
x = T.vector('x')
b = T.vector('b')
y_ = T.lvector('y_')
y = y_.flatten()
# Debug helper (only called when `verbose` is set): prints each node of
# the compiled function's graph, as returned by
# func.maker.fgraph.toposort(), prefixed with its position in that order.
def print_graph(func):
for i, node in enumerate(func.maker.fgraph.toposort()):
print i, node
# Last node should be the output
# NOTE(review): indentation is not visible in this view; presumably the
# two prints below run once after the loop, reusing the final `i` and
# `node` -- confirm. If so, an empty toposort() would leave both names
# unbound here.
print i, printing.pprint(node.outputs[0])
# Bare Python 2 print statement: emits an empty line as a separator.
print
## Test that a biased softmax is optimized correctly
bias_expressions = [
T.sum(-T.log(softmax(x + b)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(b + x)[T.arange(y.shape[0]), y])),
-T.sum(T.log(softmax(x + b))[T.arange(y.shape[0]), y]),
T.sum(-T.log(softmax(b + x))[T.arange(y.shape[0]), y])]
for expr in bias_expressions:
f = theano.function([x, b, y_], expr, mode=mode)
if verbose:
print_graph(f)
try:
ops = [node.op for node in f.maker.fgraph.toposort()]
# [big_op, sum, dim_shuffle, flatten]
assert len(ops) <= 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, b_val, y_val) f(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -851,10 +959,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -851,10 +959,8 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
mode = 'FAST_RUN' mode = 'FAST_RUN'
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.randn(3, 5).astype(config.floatX) x_val = rng.randn(3, 5).astype(config.floatX)
b_val = rng.randn(5).astype(config.floatX)
y_val = numpy.asarray([2, 4, 1]) y_val = numpy.asarray([2, 4, 1])
x = T.matrix('x') x = T.matrix('x')
b = T.vector('b')
y = T.lvector('y') y = T.lvector('y')
a = T.scalar('a') a = T.scalar('a')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论