提交 59b8455c authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

more tests + comments

上级 0686893f
......@@ -5038,51 +5038,45 @@ class T_local_sum_prod_dimshuffle(unittest.TestCase):
c_val = rng.randn(2, 2, 2).astype(config.floatX)
d_val = numpy.asarray(rng.randn(), config.floatX)
mode_with_opt = copy.copy(theano.compile.mode.get_default_mode())
mode_without_opt = copy.copy(theano.compile.mode.get_default_mode())
mode_with_opt._optimizer = mode_with_opt._optimizer.including(
'local_sum_prod_div_dimshuffle')
mode_without_opt._optimizer = mode_without_opt._optimizer.excluding(
'local_sum_prod_div_dimshuffle')
# Numerical tests
default_mode = theano.compile.mode.get_default_mode()
mode_with_opt = default_mode.including('local_sum_prod_div_dimshuffle')
mode_without_opt = default_mode.excluding('local_sum_prod_div_dimshuffle')
# Numerical tests: tests whether the numerical values with and without
# optimizer are equal or not.
for i, s in enumerate(prods):
f = theano.function([a, b, c, d], s,
on_unused_input='ignore',
mode=Mode(optimizer=None))
mode=mode_without_opt)
g = theano.function([a, b, c, d], s,
on_unused_input='ignore',
mode=mode_with_opt)
# g = f.maker.fgraph.toposort()
# assert isinstance(g[-1].op.scalar_op,
# theano.scalar.basic.TrueDiv)
utt.assert_allclose(f(a_val, b_val, c_val, d_val),
g(a_val, b_val, c_val, d_val))
# Logical tests
# Logical tests: tests whether the optimizer has been appplied or not
# by checking graph structure.
prods = [
prod(a / e),
prod(a / d),
prod(a / d.dimshuffle('x', 'x')),
prod(c / d.dimshuffle('x', 'x', 'x'), axis=1),
prod(a.dimshuffle(1, 0) / b.dimshuffle(0, 'x'), axis=1),
prod(c / b.dimshuffle(0, 'x', 'x'), axis=(1, 2)),
prod(prod(c, axis=1) / b, axis=0)]
prod(c / b.dimshuffle(0, 'x', 'x'), axis=(1, 0)),
prod(prod(c, axis=1) / b, axis=0),
prod(prod(c, axis=(1, 2)) / b, axis=0)]
expected_outer_operator = [theano.scalar.basic.Mul,
theano.scalar.basic.Composite,
theano.scalar.basic.Composite,
theano.scalar.basic.TrueDiv,
theano.scalar.basic.Composite,
theano.scalar.basic.Mul,
theano.scalar.basic.Composite,
theano.scalar.basic.Composite]
theano.scalar.basic.Mul]
for i, s in enumerate(prods):
f = theano.function([a, b, c, d, e], s,
on_unused_input='ignore',
mode=Mode(optimizer=None))
g = theano.function([a, b, c, d, e], s,
on_unused_input='ignore',
mode=mode_with_opt)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论