提交 6f98d16a 作者: nouiz

Merge pull request #963 from lamblin/fix_cxx_noblas

Fix tests when cxx is available, but not blas
...@@ -675,7 +675,8 @@ local_usmm = gof.opt.PatternSub( ...@@ -675,7 +675,8 @@ local_usmm = gof.opt.PatternSub(
(theano.tensor.sub, 'z', (theano.tensor.sub, 'z',
(theano.tensor.mul, (theano.tensor.mul,
{'pattern': 'alpha', {'pattern': 'alpha',
'constraint': lambda expr: numpy.all(expr.type.broadcastable)}, 'constraint': lambda expr: (numpy.all(expr.type.broadcastable) and
theano.config.blas.ldflags)},
(sparse._dot, 'x', 'y'))), (sparse._dot, 'x', 'y'))),
(usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z')) (usmm, (theano.tensor.neg, 'alpha'), 'x', 'y', 'z'))
register_specialize(local_usmm, name="local_usmm") register_specialize(local_usmm, name="local_usmm")
...@@ -1646,6 +1647,9 @@ sampling_dot_csr = SamplingDotCSR() ...@@ -1646,6 +1647,9 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr # register a specialization to replace SamplingDot -> SamplingDotCsr
@gof.local_optimizer([sparse.sampling_dot]) @gof.local_optimizer([sparse.sampling_dot])
def local_sampling_dot_csr(node): def local_sampling_dot_csr(node):
if not theano.config.blas.ldflags:
# The C implementation of SamplingDotCsr relies on BLAS routines
return
if node.op == sparse.sampling_dot: if node.op == sparse.sampling_dot:
x, y, p = node.inputs x, y, p = node.inputs
if p.type.format == 'csr': if p.type.format == 'csr':
...@@ -1656,6 +1660,7 @@ def local_sampling_dot_csr(node): ...@@ -1656,6 +1660,7 @@ def local_sampling_dot_csr(node):
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)] return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False return False
sparse.register_specialize(local_sampling_dot_csr, sparse.register_specialize(local_sampling_dot_csr,
'cxx_only', 'cxx_only',
name='local_sampling_dot_csr') name='local_sampling_dot_csr')
...@@ -1204,7 +1204,16 @@ class UsmmTests(unittest.TestCase): ...@@ -1204,7 +1204,16 @@ class UsmmTests(unittest.TestCase):
fast_compile = theano.config.mode == "FAST_COMPILE" fast_compile = theano.config.mode == "FAST_COMPILE"
if (y.type.dtype == up and format1 == 'csc' and format2 == 'dense' if not theano.config.blas.ldflags:
# Usmm should not be inserted, because it relies on BLAS
assert len(topo) == 4, topo
assert isinstance(topo[0].op, theano.sparse.Dot)
assert isinstance(topo[1].op, theano.tensor.DimShuffle)
assert (isinstance(topo[2].op, theano.tensor.Elemwise) and
isinstance(topo[2].op.scalar_op, theano.scalar.Mul))
assert (isinstance(topo[3].op, theano.tensor.Elemwise) and
isinstance(topo[3].op.scalar_op, theano.scalar.Sub))
elif (y.type.dtype == up and format1 == 'csc' and format2 == 'dense'
and not fast_compile and theano.config.cxx and and not fast_compile and theano.config.cxx and
up in ('float32', 'float64')): up in ('float32', 'float64')):
# The op UsmmCscDense should be inserted # The op UsmmCscDense should be inserted
......
...@@ -130,5 +130,11 @@ def test_local_sampling_dot_csr(): ...@@ -130,5 +130,11 @@ def test_local_sampling_dot_csr():
sparse.sampling_dot(*inputs), sparse.sampling_dot(*inputs),
mode=mode) mode=mode)
if theano.config.blas.ldflags:
assert not any(isinstance(node.op, sparse.SamplingDot) for node assert not any(isinstance(node.op, sparse.SamplingDot) for node
in f.maker.fgraph.toposort()) in f.maker.fgraph.toposort())
else:
# SamplingDotCSR's C implementation needs blas, so it should not
# be inserted
assert not any(isinstance(node.op, sparse.opt.SamplingDotCSR) for node
in f.maker.fgraph.toposort())
...@@ -51,7 +51,7 @@ class Conv3D(theano.Op): ...@@ -51,7 +51,7 @@ class Conv3D(theano.Op):
return "Conv3D" return "Conv3D"
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
def make_node(self, V, W, b, d): def make_node(self, V, W, b, d):
...@@ -338,7 +338,8 @@ class Conv3D(theano.Op): ...@@ -338,7 +338,8 @@ class Conv3D(theano.Op):
#if the data types are not mixed, we can insert special case optimizations based on BLAS #if the data types are not mixed, we can insert special case optimizations based on BLAS
VV, WV, bv, dv = node.inputs VV, WV, bv, dv = node.inputs
HV = node.outputs[0] HV = node.outputs[0]
if VV.dtype == WV.dtype and HV.dtype == VV.dtype: if (theano.config.blas.ldflags and
VV.dtype == WV.dtype and HV.dtype == VV.dtype):
if VV.dtype == 'float64': if VV.dtype == 'float64':
gemv = 'dgemv_' gemv = 'dgemv_'
elif VV.dtype == 'float32': elif VV.dtype == 'float32':
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论