提交 7fbb1e3b authored 作者: Frederic Bastien's avatar Frederic Bastien

Use a whitelist of the dtypes supported on the GPU instead of a blacklist.

This means that we no longer support complex* dtypes on the GPU, as they were not previously blacklisted.
上级 c007e003
...@@ -99,12 +99,16 @@ gpu_cut_copies.register('cut_gpu_constant_transfers', tensor.opt.constant_foldin ...@@ -99,12 +99,16 @@ gpu_cut_copies.register('cut_gpu_constant_transfers', tensor.opt.constant_foldin
#botering with this useless pattern. #botering with this useless pattern.
compile.optdb['canonicalize'].register('local_cut_gpu_host_gpu', local_cut_gpu_host_gpu, 'fast_run') compile.optdb['canonicalize'].register('local_cut_gpu_host_gpu', local_cut_gpu_host_gpu, 'fast_run')
def float64_in_elemwise(op): #'float64', 'complex128' and 'complex64' are not supported in elemwise on the gpu.
# Whitelist of the dtypes that an Elemwise op may use on the GPU.
# dtype_in_elemwise_supported() rejects any dtype that is not listed here
# (e.g. 'float64', 'complex64', 'complex128').
# NOTE: keep a comma after every entry -- adjacent string literals are
# silently concatenated by Python, which would merge two dtypes into one
# bogus name and drop both from the whitelist.
elemwise_cuda_dtype_supported = ['float32', 'uint8', 'int8', 'uint16', 'int16',
                                 'uint32', 'int32', 'uint64', 'int64']
def dtype_in_elemwise_supported(op):
""" """
Return True of the Elemwise op have float64 in it. Return True of the Elemwise op is supported on the gpu.
Return False otherwise. Return False otherwise.
:note: This can happen with the Composite Op. :note: We need to check inside the Composite op.
""" """
def get_all_basic_scalar(composite_op): def get_all_basic_scalar(composite_op):
l=[] l=[]
...@@ -118,9 +122,10 @@ def float64_in_elemwise(op): ...@@ -118,9 +122,10 @@ def float64_in_elemwise(op):
if isinstance(op.scalar_op, theano.scalar.Composite): if isinstance(op.scalar_op, theano.scalar.Composite):
scals = get_all_basic_scalar(op.scalar_op) scals = get_all_basic_scalar(op.scalar_op)
for s in scals: for s in scals:
if any([i.type.dtype=='float64' for i in s.inputs+s.outputs]): if any([i.type.dtype not in elemwise_cuda_dtype_supported
return True for i in s.inputs+s.outputs]):
return False return False
return True
...@@ -131,7 +136,7 @@ def local_gpu_elemwise_0(node): ...@@ -131,7 +136,7 @@ def local_gpu_elemwise_0(node):
"""elemwise(..., host_from_gpu, ...) """elemwise(..., host_from_gpu, ...)
-> host_from_gpu(elemwise(gpu_from_host, ..., gpu_from_host) -> host_from_gpu(elemwise(gpu_from_host, ..., gpu_from_host)
""" """
if isinstance(node.op, tensor.Elemwise) and not float64_in_elemwise(node.op): if isinstance(node.op, tensor.Elemwise) and dtype_in_elemwise_supported(node.op):
if numpy.any([i.owner and isinstance(i.owner.op, HostFromGpu) for i in node.inputs]): if numpy.any([i.owner and isinstance(i.owner.op, HostFromGpu) for i in node.inputs]):
if numpy.all([o.type.dtype == 'float32' for o in node.outputs]): if numpy.all([o.type.dtype == 'float32' for o in node.outputs]):
#don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later #don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later
...@@ -174,7 +179,7 @@ def local_gpu_elemwise_1(node): ...@@ -174,7 +179,7 @@ def local_gpu_elemwise_1(node):
if (host_i.owner and if (host_i.owner and
isinstance(host_i.owner.op, tensor.Elemwise) and isinstance(host_i.owner.op, tensor.Elemwise) and
len(host_i.clients)==1 and len(host_i.clients)==1 and
not float64_in_elemwise(node.op)): dtype_in_elemwise_supported(node.op)):
elemwise_node = host_i.owner elemwise_node = host_i.owner
#don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later #don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论