提交 0bce24e0 作者: Pascal Lamblin

Do not insert Usmm when there is no BLAS

上级 10b4834b
...@@ -675,7 +675,8 @@ local_usmm = gof.opt.PatternSub( ...@@ -675,7 +675,8 @@ local_usmm = gof.opt.PatternSub(
(theano.tensor.sub, 'z', (theano.tensor.sub, 'z',
(theano.tensor.mul, (theano.tensor.mul,
{'pattern': 'alpha', {'pattern': 'alpha',
'constraint': lambda expr: numpy.all(expr.type.broadcastable)}, 'constraint': lambda expr: (numpy.all(expr.type.broadcastable) and
theano.config.blas.ldflags)},
(sparse._dot, 'x', 'y'))), (sparse._dot, 'x', 'y'))),
(usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z')) (usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z'))
register_specialize(local_usmm, name="local_usmm") register_specialize(local_usmm, name="local_usmm")
......
...@@ -1204,7 +1204,16 @@ class UsmmTests(unittest.TestCase): ...@@ -1204,7 +1204,16 @@ class UsmmTests(unittest.TestCase):
fast_compile = theano.config.mode == "FAST_COMPILE" fast_compile = theano.config.mode == "FAST_COMPILE"
if (y.type.dtype == up and format1 == 'csc' and format2 == 'dense' if not theano.config.blas.ldflags:
# Usmm should not be inserted, because it relies on BLAS
assert len(topo) == 4, topo
assert isinstance(topo[0].op, theano.sparse.Dot)
assert isinstance(topo[1].op, theano.tensor.DimShuffle)
assert (isinstance(topo[2].op, theano.tensor.Elemwise) and
isinstance(topo[2].op.scalar_op, theano.scalar.Mul))
assert (isinstance(topo[3].op, theano.tensor.Elemwise) and
isinstance(topo[3].op.scalar_op, theano.scalar.Sub))
elif (y.type.dtype == up and format1 == 'csc' and format2 == 'dense'
and not fast_compile and theano.config.cxx and and not fast_compile and theano.config.cxx and
up in ('float32', 'float64')): up in ('float32', 'float64')):
# The op UsmmCscDense should be inserted # The op UsmmCscDense should be inserted
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论