提交 d7f44005 authored 作者: Kelvin Xu's avatar Kelvin Xu

prod tests

上级 a6aff9ab
...@@ -4003,7 +4003,7 @@ def local_sum_prod_all_to_none(node): ...@@ -4003,7 +4003,7 @@ def local_sum_prod_all_to_none(node):
"""Sum{0,1,...N} -> Sum{} or """Sum{0,1,...N} -> Sum{} or
Prod{0,1,...N} -> Prod{} Prod{0,1,...N} -> Prod{}
""" """
if isinstance(node.op, T.Sum) or isinstance(node.opt, T.elemwise.Prod): if isinstance(node.op, T.Sum) or isinstance(node.op, T.elemwise.Prod):
opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod opt_type = T.Sum if isinstance(node.op, T.Sum) else T.elemwise.Prod
# if all the axes are named, then use None as a shorthand # if all the axes are named, then use None as a shorthand
# this permits more merging # this permits more merging
...@@ -4255,7 +4255,7 @@ def local_opt_alloc(node): ...@@ -4255,7 +4255,7 @@ def local_opt_alloc(node):
to_prod = [shapes[i] for i in xrange(len(shapes)) to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis] if i in node.op.axis]
if to_prod: if to_prod:
if isintance(node.op, T.Sum): if isinstance(node.op, T.Sum):
val *= T.mul(*to_prod) val *= T.mul(*to_prod)
else: else:
val = val ** T.mul(*to_prod) val = val ** T.mul(*to_prod)
......
...@@ -4456,21 +4456,33 @@ class test_local_remove_switch_const_cond(unittest.TestCase): ...@@ -4456,21 +4456,33 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
assert numpy.all(f(vx, vy) == vy) assert numpy.all(f(vx, vy) == vy)
class T_local_sum(unittest.TestCase): class T_local_sum_prod(unittest.TestCase):
"""
Test sum/prod opts in opt.py
"""
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize', self.mode = theano.compile.get_default_mode().including('canonicalize',
'specialize') 'specialize')
def test_local_sum_all_to_none(self): def test_local_sum_prod_all_to_none(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
# test sum
f = theano.function([a], a.sum(), mode=self.mode) f = theano.function([a], a.sum(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum()) assert numpy.allclose(f(input), input.sum())
# test prod
f = theano.function([a], a.prod(), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.prod())
# test sum
f = theano.function([a], a.sum([0, 1, 2]), mode=self.mode) f = theano.function([a], a.sum([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1 assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.sum()) assert numpy.allclose(f(input), input.sum())
# test prod
f = theano.function([a], a.prod([0, 1, 2]), mode=self.mode)
assert len(f.maker.fgraph.apply_nodes) == 1
assert numpy.allclose(f(input), input.prod())
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
...@@ -4481,7 +4493,7 @@ class T_local_sum(unittest.TestCase): ...@@ -4481,7 +4493,7 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_sum(self): def test_local_sum_sum_prod_prod(self):
a = T.tensor3() a = T.tensor3()
input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5) input = numpy.arange(3 * 4 * 5, dtype=config.floatX).reshape(3, 4, 5)
dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), dims = [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1),
...@@ -4491,6 +4503,17 @@ class T_local_sum(unittest.TestCase): ...@@ -4491,6 +4503,17 @@ class T_local_sum(unittest.TestCase):
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
def my_prod(data, d, dd):
# This prod when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple):
return data.prod(d).prod(dd)
if isinstance(d, tuple):
d = sorted(d)
return data.prod(d[1]).prod(d[0]).prod(dd)
else:
dd = sorted(dd)
return data.prod(d).prod(dd[1]).prod(dd[0])
def my_sum(data, d, dd): def my_sum(data, d, dd):
# This sum when d or dd is a tuple of 2 dimensions. # This sum when d or dd is a tuple of 2 dimensions.
if not isinstance(d, tuple) and not isinstance(dd, tuple): if not isinstance(d, tuple) and not isinstance(dd, tuple):
...@@ -4523,7 +4546,27 @@ class T_local_sum(unittest.TestCase): ...@@ -4523,7 +4546,27 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_alloc(self): # test prod
for d, dd in dims:
expected = my_prod(input, d, dd)
f = theano.function([a], a.prod(d).prod(dd), mode=self.mode)
assert numpy.allclose(f(input), expected)
assert len(f.maker.fgraph.apply_nodes) == 1
for d, dd in dims[:6]:
f = theano.function([a], a.prod(d).prod(dd).
prod(0), mode=self.mode)
assert numpy.allclose(f(input), input.prod(d).prod(dd).prod(0))
assert len(f.maker.fgraph.apply_nodes) == 1
for d in [0, 1, 2]:
f = theano.function([a], a.prod(d).prod(None), mode=self.mode)
assert numpy.allclose(f(input), input.prod(d).prod())
assert len(f.maker.fgraph.apply_nodes) == 1
f = theano.function([a], a.prod(None).prod(), mode=self.mode)
assert numpy.allclose(f(input), input.prod())
assert len(f.maker.fgraph.apply_nodes) == 1
def test_local_sum_prod_alloc(self):
a = T.dtensor3() a = T.dtensor3()
input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4), input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4),
dtype='float64') dtype='float64')
...@@ -4532,6 +4575,7 @@ class T_local_sum(unittest.TestCase): ...@@ -4532,6 +4575,7 @@ class T_local_sum(unittest.TestCase):
for t_like, n_like, nb_nodes in [(tensor.zeros_like, numpy.zeros_like, (1, 3, 3, 2)), for t_like, n_like, nb_nodes in [(tensor.zeros_like, numpy.zeros_like, (1, 3, 3, 2)),
(tensor.ones_like, numpy.ones_like, (5, 5, 5, 6))]: (tensor.ones_like, numpy.ones_like, (5, 5, 5, 6))]:
# test sum
f = theano.function([a], t_like(a).sum(None), mode=mode) f = theano.function([a], t_like(a).sum(None), mode=mode)
assert numpy.allclose(f(input), n_like(input).sum()) assert numpy.allclose(f(input), n_like(input).sum())
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0] assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
...@@ -4555,6 +4599,30 @@ class T_local_sum(unittest.TestCase): ...@@ -4555,6 +4599,30 @@ class T_local_sum(unittest.TestCase):
assert topo[-1].op == T.alloc assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.Sum) for node in topo]) assert not any([isinstance(node.op, T.Sum) for node in topo])
# test prod
f = theano.function([a], t_like(a).prod(None), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod())
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
f = theano.function([a], t_like(a).prod([0, 1, 2]), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod())
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[0]
for d in range(3):
f = theano.function([a], t_like(a).prod(d), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod(d))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[1]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.elemwise.Prod) for node in topo])
for i in range(3):
f = theano.function([a], t_like(a).prod(i), mode=mode)
assert numpy.allclose(f(input), n_like(input).prod(i))
assert len(f.maker.fgraph.apply_nodes) == nb_nodes[2]
topo = f.maker.fgraph.toposort()
assert topo[-1].op == T.alloc
assert not any([isinstance(node.op, T.elemwise.Prod) for node in topo])
backup = config.warn.sum_sum_bug backup = config.warn.sum_sum_bug
config.warn.sum_sum_bug = False config.warn.sum_sum_bug = False
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论