Commit 9e43ad87 authored by Frederic

Raise a better error if scikits.cuda isn't installed

Parent 44470332
@@ -48,6 +48,12 @@ class ScikitsCudaOp(GpuOp):
         return theano.Apply(self, [inp], [self.output_type(inp)()])
 
+    def make_thunk(self, node, storage_map, _, _2):
+        if not scikits_cuda_available:
+            raise RuntimeError(
+                "scikits.cuda is needed for all GPU fft implementations,"
+                " including fftconv.")
+
 
 class CuFFTOp(ScikitsCudaOp):
     def output_type(self, inp):
@@ -56,6 +62,8 @@ class CuFFTOp(ScikitsCudaOp):
             broadcastable=[False] * (inp.type.ndim + 1))
 
     def make_thunk(self, node, storage_map, _, _2):
+        super(CuFFTOp, self).make_thunk(node, storage_map, _, _2)
+
         from theano.misc.pycuda_utils import to_gpuarray
         inputs = [storage_map[v] for v in node.inputs]
         outputs = [storage_map[v] for v in node.outputs]
@@ -111,6 +119,8 @@ class CuIFFTOp(ScikitsCudaOp):
             broadcastable=[False] * (inp.type.ndim - 1))
 
     def make_thunk(self, node, storage_map, _, _2):
+        super(CuIFFTOp, self).make_thunk(node, storage_map, _, _2)
+
         from theano.misc.pycuda_utils import to_gpuarray
         inputs = [storage_map[v] for v in node.inputs]
         outputs = [storage_map[v] for v in node.outputs]
@@ -300,6 +310,8 @@ class BatchedComplexDotOp(ScikitsCudaOp):
         return CudaNdarrayType(broadcastable=[False] * inp.type.ndim)
 
     def make_thunk(self, node, storage_map, _, _2):
+        super(BatchedComplexDotOp, self).make_thunk(node, storage_map, _, _2)
+
         inputs = [storage_map[v] for v in node.inputs]
         outputs = [storage_map[v] for v in node.outputs]
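
The scikits_cuda_available flag that the new guard reads is defined elsewhere in the module and is not part of this diff. Flags like this are conventionally set by a guarded import at module load time; the following is only a sketch of that convention, an assumption about the surrounding file, since the diff shows the flag being read but never defined:

# Sketch (assumption, not shown in this commit): how a flag such as
# scikits_cuda_available is conventionally set at module load time.
scikits_cuda_available = False
try:
    import scikits.cuda.fft  # optional dependency; may be absent
    scikits_cuda_available = True
except ImportError:
    # Leave the flag False; a descriptive RuntimeError is raised later,
    # in make_thunk, once the dependency is actually needed.
    pass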
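The shape of the fix is easy to miss in diff form: the availability check lives once in the base class's make_thunk, and each subclass calls super(...).make_thunk(...) first, purely to trigger that check, before building its real thunk. Below is a minimal, self-contained sketch of the same pattern in plain Python; the class names and the no-op thunk are illustrative stand-ins, not Theano's API:

class OptionalDepOp(object):
    # Stand-in for ScikitsCudaOp: the base class owns the guard.
    dep_available = False  # stand-in for scikits_cuda_available

    def make_thunk(self, node, storage_map, _, _2):
        if not self.dep_available:
            raise RuntimeError(
                "scikits.cuda is needed for all GPU fft implementations,"
                " including fftconv.")


class FFTLikeOp(OptionalDepOp):
    # Stand-in for CuFFTOp, CuIFFTOp, or BatchedComplexDotOp.
    def make_thunk(self, node, storage_map, _, _2):
        # Delegate first: this call exists only to run the guard.
        super(FFTLikeOp, self).make_thunk(node, storage_map, _, _2)
        # ... real thunk construction would follow here ...
        return lambda: None


if __name__ == "__main__":
    try:
        FFTLikeOp().make_thunk(None, {}, None, None)
    except RuntimeError as e:
        print("raised as intended: %s" % e)

Raising in make_thunk rather than at import time is what makes the error "better": the module still imports, and graphs that use these ops can still be constructed; the RuntimeError fires only when Theano compiles a thunk, the first point at which scikits.cuda is genuinely required, and it names the missing package instead of failing obscurely.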