提交 aedba364 authored 作者: f0k's avatar f0k

Insert GpuSolve only when CULA is available

上级 ee4c4e21
...@@ -48,7 +48,7 @@ from theano.sandbox.cuda.blas import ( ...@@ -48,7 +48,7 @@ from theano.sandbox.cuda.blas import (
GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights) GpuCorr3dMM, GpuCorr3dMM_gradInputs, GpuCorr3dMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace from theano.sandbox.cuda.blas import gpu_gemv_inplace
from theano.sandbox.cuda.cula import gpu_solve from theano.sandbox.cuda.cula import gpu_solve, cula_available
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace from theano.sandbox.cuda.blas import gpu_gemv_no_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace from theano.sandbox.cuda.blas import gpu_ger_inplace
...@@ -702,6 +702,8 @@ def local_gpu_solve(node): ...@@ -702,6 +702,8 @@ def local_gpu_solve(node):
CpuSolve(host_from_gpu) -> host_from_gpu(GpuSolve) CpuSolve(host_from_gpu) -> host_from_gpu(GpuSolve)
""" """
if not cula_available:
return
if node.outputs[0].dtype != 'float32': if node.outputs[0].dtype != 'float32':
return return
if isinstance(node.op, GpuFromHost): if isinstance(node.op, GpuFromHost):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论