提交 10b4834b authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Do not insert SamplingDotCsr when there is no blas

上级 2880cb32
...@@ -1646,6 +1646,9 @@ sampling_dot_csr = SamplingDotCSR() ...@@ -1646,6 +1646,9 @@ sampling_dot_csr = SamplingDotCSR()
# register a specialization to replace SamplingDot -> SamplingDotCsr # register a specialization to replace SamplingDot -> SamplingDotCsr
@gof.local_optimizer([sparse.sampling_dot]) @gof.local_optimizer([sparse.sampling_dot])
def local_sampling_dot_csr(node): def local_sampling_dot_csr(node):
if not theano.config.blas.ldflags:
# The C implementation of SamplingDotCsr relies on BLAS routines
return
if node.op == sparse.sampling_dot: if node.op == sparse.sampling_dot:
x, y, p = node.inputs x, y, p = node.inputs
if p.type.format == 'csr': if p.type.format == 'csr':
...@@ -1656,6 +1659,7 @@ def local_sampling_dot_csr(node): ...@@ -1656,6 +1659,7 @@ def local_sampling_dot_csr(node):
return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)] return [sparse.CSR(z_data, z_ind, z_ptr, p_shape)]
return False return False
sparse.register_specialize(local_sampling_dot_csr, sparse.register_specialize(local_sampling_dot_csr,
'cxx_only', 'cxx_only',
name='local_sampling_dot_csr') name='local_sampling_dot_csr')
...@@ -130,5 +130,11 @@ def test_local_sampling_dot_csr(): ...@@ -130,5 +130,11 @@ def test_local_sampling_dot_csr():
sparse.sampling_dot(*inputs), sparse.sampling_dot(*inputs),
mode=mode) mode=mode)
assert not any(isinstance(node.op, sparse.SamplingDot) for node if theano.config.blas.ldflags:
assert not any(isinstance(node.op, sparse.SamplingDot) for node
in f.maker.fgraph.toposort())
else:
# SamplingDotCSR's C implementation needs blas, so it should not
# be inserted
assert not any(isinstance(node.op, sparse.opt.SamplingDotCSR) for node
in f.maker.fgraph.toposort()) in f.maker.fgraph.toposort())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论