提交 3d6cd434 authored 作者: Frederic Bastien's avatar Frederic Bastien

allow to create a CudaNdarraySharedVariable from a CudaNdarray.

上级 b6f39866
......@@ -819,6 +819,11 @@ def test_shared_float32():
# Unregister
del theano.shared.constructors[-1]
def test_shared_cudandarray():
'''Test that we can create a CudaNdarraySharedVariable from a CudaNdarray'''
a = cuda.shared_constructor(cuda.CudaNdarray.zeros((2,3)))
assert isinstance(a.type, tcn.CudaNdarrayType)
import theano.tensor.tests.test_sharedvar
test_shared_options = theano.tensor.tests.test_sharedvar.makeSharedTester(
......
......@@ -84,7 +84,7 @@ CudaNdarrayType.SharedVariable = CudaNdarraySharedVariable
def cuda_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False, broadcastable=None):
"""SharedVariable Constructor for TensorType"""
"""SharedVariable Constructor for CudaNdarrayType"""
# THIS CONSTRUCTOR TRIES TO CAST VALUE TO A FLOAT32, WHICH THEN GOES ONTO THE CARD
# SO INT shared vars, float64 shared vars, etc. all end up on the card.
......@@ -115,22 +115,29 @@ def cuda_shared_constructor(value, name=None, strict=False,
def float32_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False, broadcastable=None):
"""SharedVariable Constructor for TensorType"""
"""SharedVariable Constructor for CudaNdarrayType from numpy.ndarray or CudaNdarray"""
# if value isn't a float32 ndarray, then raise
if not isinstance(value, numpy.ndarray):
raise TypeError('ndarray required')
if value.dtype.num != CudaNdarrayType.typenum:
# if value isn't a float32 ndarray, or a CudaNdarray then raise
if not isinstance(value, (numpy.ndarray, theano.sandbox.cuda.CudaNdarray)):
raise TypeError('ndarray or CudaNdarray required')
if isinstance(value, numpy.ndarray) and value.dtype.num != CudaNdarrayType.typenum:
raise TypeError('float32 ndarray required')
if broadcastable is None:
broadcastable = (False,) * len(value.shape)
type = CudaNdarrayType(broadcastable=broadcastable)
deviceval = type_support_filter(value, broadcastable, False, None)
if isinstance(value, theano.sandbox.cuda.CudaNdarray):
if borrow:
deviceval = value
else:
deviceval = value.copy()
else:
deviceval = type_support_filter(value, broadcastable, False, None)
try:
rval = CudaNdarraySharedVariable(type=type, value=deviceval, name=name, strict=strict)
except Exception, e:
print "ERROR", e
raise
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论