提交 a51c01ea authored 作者: Frederic Bastien's avatar Frederic Bastien

Don't put scalar in shared var on the GPU by default.

上级 29575d6e
......@@ -29,7 +29,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
from theano.tests.breakpoint import PdbBreakpoint
from .type import (GpuArrayType, GpuArrayConstant, get_context,
ContextNotDefined)
ContextNotDefined, move_to_gpu)
from .basic_ops import (as_gpuarray_variable, infer_context_name,
host_from_gpu, GpuToGpu,
HostFromGpu, GpuFromHost,
......@@ -242,9 +242,8 @@ class InputToGpuOptimizer(Optimizer):
target = getattr(input.tag, 'target', None)
if target == 'cpu':
continue
# Do not move *int* scalar to the GPU.
if (isinstance(input.type, tensor.TensorType) and
input.ndim == 0 and 'int' in input.dtype):
not move_to_gpu(input)):
continue
try:
......@@ -301,9 +300,7 @@ class GraphToGPU(Optimizer):
target = infer_context_name(*fgraph.inputs)
for i in fgraph.inputs:
# Do not move *int* scalar to the GPU.
if (isinstance(i.type, tensor.TensorType) and
(i.ndim > 0 or 'int' not in i.dtype) and
"complex" not in i.dtype):
if isinstance(i.type, tensor.TensorType) and move_to_gpu(i):
mapping[i] = i.transfer(getattr(i.tag, 'target', target))
else:
mapping[i] = i
......
......@@ -22,6 +22,26 @@ except ImportError:
_context_reg = {}
def move_to_gpu(data):
"""
Do we want to move this computation to the GPU?
Currently, we don't move complex and scalar int.
Parameters
----------
data : numpy.ndarray or TensorVariable
(it must have dtype and ndim parameter)
"""
# We don't support complex on the GPU
if str(data.dtype) in tensor.basic.complex_dtypes:
return False
# We don't want scalar int on the GPU.
if data.ndim == 0 and str(data.dtype) in tensor.basic.discrete_dtypes:
return False
return True
class ContextNotDefined(ValueError):
pass
......@@ -572,6 +592,8 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
See :func:`theano.shared`.
"""
if target is None and not move_to_gpu(value):
raise TypeError('We do not move that data by deault to the GPU')
if target == 'gpu' or target == 'cpu':
raise TypeError('not for me')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论