提交 2242b05a authored 作者: Frederic's avatar Frederic

Fix Usmm test in fast_compile mode.

上级 f4f40bbd
......@@ -675,15 +675,20 @@ class UsmmTests(unittest.TestCase):
assert abs(f_a_out - f_b_out).max() < 1e-4
topo = f_a.maker.env.toposort()
up = theano.scalar.upcast(dtype1, dtype2, dtype3, dtype4)
if y.type.dtype == up and format1 == 'csc' and format2 == 'dense':
fast_compile = theano.config.mode == "FAST_COMPILE"
if (y.type.dtype == up and format1 == 'csc' and format2 == 'dense'
and not fast_compile) and up in ('float32', 'float64'):
# The op UsmmCscDense should be inserted
assert (sum([isinstance(node.op, tensor.Elemwise) and
isinstance(node.op.scalar_op,
theano.scalar.basic.Cast)
for node in topo]) == len(topo) - 5)
new_topo = []
for node in topo:
if not isinstance(node.op, tensor.Elemwise) and \
isinstance(node.op.scalar_op, theano.scalar.basic.Cast):
if not (isinstance(node.op, tensor.Elemwise) and \
isinstance(node.op.scalar_op, theano.scalar.basic.Cast)):
new_topo.append(node)
topo = new_topo
assert len(topo) == 5, topo
......@@ -698,7 +703,8 @@ class UsmmTests(unittest.TestCase):
assert isinstance(topo[4].op, theano.sparse.UsmmCscDense)
if inplace:
assert topo[4].op.inplace
else:
elif not fast_compile:
# The op Usmm should be inserted
assert len(topo)==3, topo
assert isinstance(topo[0].op, theano.tensor.DimShuffle)
assert topo[1].op == theano.tensor.neg
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论