提交 10b4834b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not insert SamplingDotCsr when there is no blas

上级 2880cb32
......@@ -1646,6 +1646,9 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr
@gof.local_optimizer([sparse.sampling_dot])
def local_sampling_dot_csr(node):
if not theano.config.blas.ldflags:
# The C implementation of SamplingDotCsr relies on BLAS routines
return
if node.op == sparse.sampling_dot:
x, y, p = node.inputs
if p.type.format == 'csr':
......@@ -1656,6 +1659,7 @@ def local_sampling_dot_csr(node):
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False
sparse.register_specialize(local_sampling_dot_csr,
'cxx_only',
name='local_sampling_dot_csr')
......@@ -130,5 +130,11 @@ def test_local_sampling_dot_csr():
sparse.sampling_dot(*inputs),
mode=mode)
assert not any(isinstance(node.op, sparse.SamplingDot) for node
if theano.config.blas.ldflags:
assert not any(isinstance(node.op, sparse.SamplingDot) for node
in f.maker.fgraph.toposort())
else:
# SamplingDotCSR's C implementation needs blas, so it should not
# be inserted
assert not any(isinstance(node.op, sparse.opt.SamplingDotCSR) for node
in f.maker.fgraph.toposort())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论