提交 941a91e4 authored 作者: Frederic's avatar Frederic

small clean up, test more corner case and pep8.

上级 10563359
...@@ -3473,7 +3473,7 @@ class T_local_sum(unittest.TestCase): ...@@ -3473,7 +3473,7 @@ class T_local_sum(unittest.TestCase):
def test_local_sum_all_to_none(self): def test_local_sum_all_to_none(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 3 * 3, dtype=config.floatX).reshape(3, 3, 3) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
f = theano.function([a], a.sum(), mode=self.mode) f = theano.function([a], a.sum(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum()) assert numpy.allclose(f(input), input.sum())
...@@ -3520,24 +3520,23 @@ class T_local_sum(unittest.TestCase): ...@@ -3520,24 +3520,23 @@ class T_local_sum(unittest.TestCase):
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]: for d, dd in dims[:6]:
f = theano.function([a], a.sum(d).sum(dd). f = theano.function([a], a.sum(d).sum(dd).
sum(0), mode=self.mode) sum(0), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0)) assert numpy.allclose(f(input), input.sum(d).sum(dd).sum(0))
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]: for d in [0, 1, 2]:
f = theano.function([a], a.sum(d).sum(None), mode=self.mode) f = theano.function([a], a.sum(d).sum(None), mode=self.mode)
assert numpy.allclose(f(input), input.sum(d).sum()) assert numpy.allclose(f(input), input.sum(d).sum())
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]: f = theano.function([a], a.sum(None).sum(), mode=self.mode)
f = theano.function([a], a.sum(None).sum(), mode=self.mode) assert numpy.allclose(f(input), input.sum())
assert numpy.allclose(f(input), input.sum()) assert len(f.maker.fgraph.apply_nodes) == 1
assert len(f.maker.fgraph.apply_nodes) == 1
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_alloc(self): def test_local_sum_alloc(self):
a = T.dtensor3() a = T.dtensor3()
input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4), input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4),
dtype='float64') dtype='float64')
mode = self.mode.including('specialize').excluding('fusion') mode = self.mode.including('specialize').excluding('fusion')
for t_like,n_like,nb_nodes in [(tensor.zeros_like,numpy.zeros_like,(1,3,3,2)), for t_like,n_like,nb_nodes in [(tensor.zeros_like,numpy.zeros_like,(1,3,3,2)),
...@@ -3571,14 +3570,14 @@ class T_local_sum(unittest.TestCase): ...@@ -3571,14 +3570,14 @@ class T_local_sum(unittest.TestCase):
try: try:
for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]: for d, dd in [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]:
f = theano.function([a], t_like(a). f = theano.function([a], t_like(a).
sum(d).sum(dd), mode=mode) sum(d).sum(dd), mode=mode)
assert numpy.allclose(f(input), assert numpy.allclose(f(input),
n_like(input).sum(d).sum(dd)) n_like(input).sum(d).sum(dd))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3] assert len(f.maker.fgraph.apply_nodes) == nb_nodes[3]
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, assert not any([isinstance(node.op,
T.Sum) for node in topo]) T.Sum) for node in topo])
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论