提交 0686893f authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

Logical tests added

上级 303cc80d
...@@ -4992,6 +4992,7 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase): ...@@ -4992,6 +4992,7 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase):
a = T.matrix('a') a = T.matrix('a')
b = T.vector('b') b = T.vector('b')
c = T.tensor3('c') c = T.tensor3('c')
e = T.matrix('e')
d = T.scalar('d') d = T.scalar('d')
prod = T.prod prod = T.prod
prods = [ prods = [
...@@ -5044,13 +5045,15 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase): ...@@ -5044,13 +5045,15 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase):
mode_without_opt._optimizer = mode_without_opt._optimizer.excluding( mode_without_opt._optimizer = mode_without_opt._optimizer.excluding(
'local_sum_prod_div_dimshuffle') 'local_sum_prod_div_dimshuffle')
# Numerical tests
for i, s in enumerate(prods): for i, s in enumerate(prods):
f = theano.function([a, b, c, d], s, f = theano.function([a, b, c, d], s,
on_unused_input='ignore', on_unused_input='ignore',
mode=mode_without_opt) mode=Mode(optimizer=None))
g = theano.function([a, b, c, d], s, g = theano.function([a, b, c, d], s,
on_unused_input='ignore', on_unused_input='ignore',
mode=mode_with_opt) mode=mode_with_opt)
# g = f.maker.fgraph.toposort() # g = f.maker.fgraph.toposort()
# assert isinstance(g[-1].op.scalar_op, # assert isinstance(g[-1].op.scalar_op,
# theano.scalar.basic.TrueDiv) # theano.scalar.basic.TrueDiv)
...@@ -5058,6 +5061,33 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase): ...@@ -5058,6 +5061,33 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase):
utt.assert_allclose(f(a_val, b_val, c_val, d_val), utt.assert_allclose(f(a_val, b_val, c_val, d_val),
g(a_val, b_val, c_val, d_val)) g(a_val, b_val, c_val, d_val))
# Logical tests
prods = [
prod(a / e),
prod(a / d),
prod(a / d.dimshuffle('x', 'x')),
prod(c / d.dimshuffle('x', 'x', 'x'), axis=1),
prod(a.dimshuffle(1, 0) / b.dimshuffle(0, 'x'), axis=1),
prod(c / b.dimshuffle(0, 'x', 'x'), axis=(1, 2)),
prod(prod(c, axis=1) / b, axis=0)]
expected_outer_operator = [theano.scalar.basic.Mul,
theano.scalar.basic.Composite,
theano.scalar.basic.Composite,
theano.scalar.basic.TrueDiv,
theano.scalar.basic.Composite,
theano.scalar.basic.Composite,
theano.scalar.basic.Composite]
for i, s in enumerate(prods):
f = theano.function([a, b, c, d, e], s,
on_unused_input='ignore',
mode=Mode(optimizer=None))
g = theano.function([a, b, c, d, e], s,
on_unused_input='ignore',
mode=mode_with_opt)
assert isinstance(g.maker.fgraph.toposort()[-1].op.scalar_op,
expected_outer_operator[i])
# TODO: # TODO:
# test_local_sum_prod_dimshuffle (a * b * c) # test_local_sum_prod_dimshuffle (a * b * c)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论