提交 6f98d16a authored 作者: nouiz's avatar nouiz

Merge pull request #963 from lamblin/fix_cxx_noblas

Fix tests when cxx is available, but not blas
......@@ -675,7 +675,8 @@ local_usmm = gof.opt.PatternSub(
(theano.tensor.sub, 'z',
(theano.tensor.mul,
{'pattern': 'alpha',
'constraint': lambda expr: numpy.all(expr.type.broadcastable)},
'constraint': lambda expr: (numpy.all(expr.type.broadcastable) and
theano.config.blas.ldflags)},
(sparse._dot, 'x', 'y'))),
(usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z'))
register_specialize(local_usmm, name="local_usmm")
......@@ -1646,6 +1647,9 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr
@gof.local_optimizer([sparse.sampling_dot])
def local_sampling_dot_csr(node):
if not theano.config.blas.ldflags:
# The C implementation of SamplingDotCsr relies on BLAS routines
return
if node.op == sparse.sampling_dot:
x, y, p = node.inputs
if p.type.format == 'csr':
......@@ -1656,6 +1660,7 @@ def local_sampling_dot_csr(node):
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False
sparse.register_specialize(local_sampling_dot_csr,
'cxx_only',
name='local_sampling_dot_csr')
......@@ -1204,7 +1204,16 @@ class UsmmTests(unittest.TestCase):
fast_compile = theano.config.mode == "FAST_COMPILE"
if (y.type.dtype == up and format1 == 'csc' and format2 == 'dense'
if not theano.config.blas.ldflags:
# Usmm should not be inserted, because it relies on BLAS
assert len(topo) == 4, topo
assert isinstance(topo[0].op, theano.sparse.Dot)
assert isinstance(topo[1].op, theano.tensor.DimShuffle)
assert (isinstance(topo[2].op, theano.tensor.Elemwise) and
isinstance(topo[2].op.scalar_op, theano.scalar.Mul))
assert (isinstance(topo[3].op, theano.tensor.Elemwise) and
isinstance(topo[3].op.scalar_op, theano.scalar.Sub))
elif (y.type.dtype == up and format1 == 'csc' and format2 == 'dense'
and not fast_compile and theano.config.cxx and
up in ('float32', 'float64')):
# The op UsmmCscDense should be inserted
......
......@@ -130,5 +130,11 @@ def test_local_sampling_dot_csr():
sparse.sampling_dot(*inputs),
mode=mode)
assert not any(isinstance(node.op, sparse.SamplingDot) for node
if theano.config.blas.ldflags:
assert not any(isinstance(node.op, sparse.SamplingDot) for node
in f.maker.fgraph.toposort())
else:
# SamplingDotCSR's C implementation needs blas, so it should not
# be inserted
assert not any(isinstance(node.op, sparse.opt.SamplingDotCSR) for node
in f.maker.fgraph.toposort())
......@@ -51,7 +51,7 @@ class Conv3D(theano.Op):
return "Conv3D"
def c_code_cache_version(self):
return (2,)
return (3,)
def make_node(self, V, W, b, d):
......@@ -338,7 +338,8 @@ class Conv3D(theano.Op):
#if the data types are not mixed, we can insert special case optimizations based on BLAS
VV, WV, bv, dv = node.inputs
HV = node.outputs[0]
if VV.dtype == WV.dtype and HV.dtype == VV.dtype:
if (theano.config.blas.ldflags and
VV.dtype == WV.dtype and HV.dtype == VV.dtype):
if VV.dtype == 'float64':
gemv = 'dgemv_'
elif VV.dtype == 'float32':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论