提交 5ae8bc0c authored 作者: James Bergstra's avatar James Bergstra

fixes to make the behaviour of the cuda shared constructor more strict on…

fixes to make the behaviour of the cuda shared constructor more strict on accepting only float32 ndarrays
上级 91aa8472
...@@ -94,7 +94,7 @@ if enable_cuda: ...@@ -94,7 +94,7 @@ if enable_cuda:
from theano.sandbox.cuda.var import (CudaNdarrayVariable, from theano.sandbox.cuda.var import (CudaNdarrayVariable,
CudaNdarrayConstant, CudaNdarrayConstant,
CudaNdarraySharedVariable, CudaNdarraySharedVariable,
shared_constructor) float32_shared_constructor)
import basic_ops import basic_ops
from basic_ops import (GpuFromHost, HostFromGpu, GpuElemwise, from basic_ops import (GpuFromHost, HostFromGpu, GpuElemwise,
...@@ -139,7 +139,7 @@ def handle_shared_float32(tf): ...@@ -139,7 +139,7 @@ def handle_shared_float32(tf):
""" """
if tf: if tf:
import theano.compile import theano.compile
theano.compile.shared_constructor(shared_constructor) theano.compile.shared_constructor(float32_shared_constructor)
else: else:
raise NotImplementedError('removing our handler') raise NotImplementedError('removing our handler')
......
...@@ -65,7 +65,7 @@ class CudaNdarraySharedVariable(SharedVariable, _operators): ...@@ -65,7 +65,7 @@ class CudaNdarraySharedVariable(SharedVariable, _operators):
CudaNdarrayType.SharedVariable = CudaNdarraySharedVariable CudaNdarrayType.SharedVariable = CudaNdarraySharedVariable
def shared_constructor(value, name, strict=False, broadcastable=None): def cuda_shared_constructor(value, name, strict=False, broadcastable=None):
"""SharedVariable Constructor for TensorType""" """SharedVariable Constructor for TensorType"""
#TODO: what should strict mean in this context, since we always have to make a copy? #TODO: what should strict mean in this context, since we always have to make a copy?
...@@ -82,17 +82,31 @@ def shared_constructor(value, name, strict=False, broadcastable=None): ...@@ -82,17 +82,31 @@ def shared_constructor(value, name, strict=False, broadcastable=None):
if broadcastable is None: if broadcastable is None:
broadcastable = (False,) * len(value.shape) broadcastable = (False,) * len(value.shape)
type = CudaNdarrayType(broadcastable=broadcastable) type = CudaNdarrayType(broadcastable=broadcastable)
return CudaNdarraySharedVariable(type=type, value=_value, name=name, strict=strict) print "trying to return?"
try:
rval = CudaNdarraySharedVariable(type=type, value=_value, name=name, strict=strict)
except Exception, e:
print "ERROR", e
raise
return rval
def float32_shared_constructor(value, name, strict=False, broadcastable=None):
"""SharedVariable Constructor for TensorType"""
def unset_shared_for_numpy(): # if value isn't a float32 ndarray, then raise
raise NotImplementedError() if not isinstance(value, numpy.ndarray):
raise TypeError('ndarray required')
if value.dtype.num != CudaNdarrayType.typenum:
raise TypeError('float32 ndarray required')
def set_shared_for_numpy(): if broadcastable is None:
""" broadcastable = (False,) * len(value.shape)
Set the gpu_tensor_constructor as the handler for ndarray type = CudaNdarrayType(broadcastable=broadcastable)
""" deviceval = type_support_filter(value, broadcastable, False)
shared_constructor(shared_constructor) try:
rval = CudaNdarraySharedVariable(type=type, value=deviceval, name=name, strict=strict)
except Exception, e:
print "ERROR", e
raise
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论