提交 7fbb1e3b authored 作者: Frederic Bastien's avatar Frederic Bastien

make a white list of dtype supported on the gpu and not a blacklist.

This make that we now don't support complex* on the gpu as they where not blacklisted.
上级 c007e003
......@@ -99,12 +99,16 @@ gpu_cut_copies.register('cut_gpu_constant_transfers', tensor.opt.constant_foldin
#botering with this useless pattern.
compile.optdb['canonicalize'].register('local_cut_gpu_host_gpu', local_cut_gpu_host_gpu, 'fast_run')
def float64_in_elemwise(op):
#'float64', 'complex128' and 'complex64' are not supported in elemwise on the gpu.
elemwise_cuda_dtype_supported=['float32','uint8','int8','uint16','int16',
'uint32','int32''uint64','int64']
def dtype_in_elemwise_supported(op):
"""
Return True of the Elemwise op have float64 in it.
Return True of the Elemwise op is supported on the gpu.
Return False otherwise.
:note: This can happen with the Composite Op.
:note: We need to check inside the Composite op.
"""
def get_all_basic_scalar(composite_op):
l=[]
......@@ -118,9 +122,10 @@ def float64_in_elemwise(op):
if isinstance(op.scalar_op, theano.scalar.Composite):
scals = get_all_basic_scalar(op.scalar_op)
for s in scals:
if any([i.type.dtype=='float64' for i in s.inputs+s.outputs]):
return True
return False
if any([i.type.dtype not in elemwise_cuda_dtype_supported
for i in s.inputs+s.outputs]):
return False
return True
......@@ -131,7 +136,7 @@ def local_gpu_elemwise_0(node):
"""elemwise(..., host_from_gpu, ...)
-> host_from_gpu(elemwise(gpu_from_host, ..., gpu_from_host)
"""
if isinstance(node.op, tensor.Elemwise) and not float64_in_elemwise(node.op):
if isinstance(node.op, tensor.Elemwise) and dtype_in_elemwise_supported(node.op):
if numpy.any([i.owner and isinstance(i.owner.op, HostFromGpu) for i in node.inputs]):
if numpy.all([o.type.dtype == 'float32' for o in node.outputs]):
#don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later
......@@ -174,7 +179,7 @@ def local_gpu_elemwise_1(node):
if (host_i.owner and
isinstance(host_i.owner.op, tensor.Elemwise) and
len(host_i.clients)==1 and
not float64_in_elemwise(node.op)):
dtype_in_elemwise_supported(node.op)):
elemwise_node = host_i.owner
#don't set any inplace pattern. gpu_insert_inplace_optimizer will do it later
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论