提交 a885bfb0 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make tests more accurate

上级 b5d38d68
...@@ -591,7 +591,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -591,7 +591,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 4 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -602,7 +606,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -602,7 +606,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) == 4 ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops
assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -620,7 +628,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -620,7 +628,9 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 2 # [big_op, sum] ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 2 # [big_op, sum]
assert crossentropy_softmax_argmax_1hot_with_bias in ops
f(x_val, b_val, y_val) f(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -629,7 +639,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -629,7 +639,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) == 4 ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
g(x_val, b_val, y_val) g(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -647,7 +661,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -647,7 +661,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 6 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 6
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -657,9 +675,13 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -657,9 +675,13 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) in (6, 7) ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) in (6, 7)
#there's an extra dimshuffle in there #there's an extra dimshuffle in there
# but I can't think of a good rule to get rid of it # but I can't think of a good rule to get rid of it
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops
assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -676,7 +698,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -676,7 +698,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 4 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 4
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
raise raise
...@@ -684,7 +710,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -684,7 +710,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) in (6, 7) ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) in (6, 7)
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax_with_bias in ops
assert softmax_grad not in ops
g(x_val, b_val, y_val) g(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -716,7 +746,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -716,7 +746,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(f) theano.printing.debugprint(f)
try: try:
assert len(f.maker.fgraph.toposort()) == 5 ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(ops) == 5
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -727,7 +761,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -727,7 +761,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
theano.printing.debugprint(g) theano.printing.debugprint(g)
try: try:
assert len(g.maker.fgraph.toposort()) == 5 ops = [node.op for node in g.maker.fgraph.toposort()]
assert len(ops) == 5
assert crossentropy_softmax_1hot_with_bias_dx in ops
assert softmax in ops
assert softmax_grad not in ops
g(x_val, y_val) g(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(g) theano.printing.debugprint(g)
...@@ -762,8 +800,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -762,8 +800,11 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
print_graph(f) print_graph(f)
try: try:
prev, last = f.maker.fgraph.toposort()[-2:] ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(f.maker.fgraph.toposort()) == 5 assert len(ops) == 5
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, y_val) f(x_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
...@@ -815,9 +856,12 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester): ...@@ -815,9 +856,12 @@ class T_CrossentropyCategorical1Hot(utt.InferShapeTester):
if verbose: if verbose:
print_graph(f) print_graph(f)
try: try:
prev, last = f.maker.fgraph.toposort()[-2:] ops = [node.op for node in f.maker.fgraph.toposort()]
assert len(f.maker.fgraph.toposort()) == 3
# [big_op, sum, dim_shuffle] # [big_op, sum, dim_shuffle]
assert len(ops) == 3
assert crossentropy_softmax_argmax_1hot_with_bias in ops
assert not [1 for o in ops
if isinstance(o, T.AdvancedSubtensor)]
f(x_val, b_val, y_val) f(x_val, b_val, y_val)
except Exception: except Exception:
theano.printing.debugprint(f) theano.printing.debugprint(f)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论