提交 6f636a8d authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix a test where the gpuarray back-end didn't downcasted python float to float32.

上级 50572e99
...@@ -10,7 +10,7 @@ from theano.misc.pkl_utils import CompatUnpickler ...@@ -10,7 +10,7 @@ from theano.misc.pkl_utils import CompatUnpickler
from .config import test_ctx_name from .config import test_ctx_name
from .test_basic_ops import rand_gpuarray from .test_basic_ops import rand_gpuarray
from ..type import GpuArrayType from ..type import GpuArrayType, gpuarray_shared_constructor
import pygpu import pygpu
...@@ -47,8 +47,13 @@ def test_specify_shape(): ...@@ -47,8 +47,13 @@ def test_specify_shape():
def test_filter_float(): def test_filter_float():
s = theano.shared(numpy.array(0.0, dtype='float32'), target=test_ctx_name) theano.compile.shared_constructor(gpuarray_shared_constructor)
theano.function([], updates=[(s, 0.0)]) try:
s = theano.shared(numpy.array(0.0, dtype='float32'),
target=test_ctx_name)
theano.function([], updates=[(s, 0.0)])
finally:
del theano.compile.sharedvalue.shared.constructors[-1]
def test_unpickle_gpuarray_as_numpy_ndarray_flag0(): def test_unpickle_gpuarray_as_numpy_ndarray_flag0():
......
...@@ -205,7 +205,13 @@ class GpuArrayType(Type): ...@@ -205,7 +205,13 @@ class GpuArrayType(Type):
# (like lists). We anticipate that the type below # (like lists). We anticipate that the type below
# will match and we pass copy=False so it won't make a # will match and we pass copy=False so it won't make a
# second object on the GPU. # second object on the GPU.
data = gpuarray.array(data, copy=False, context=self.context) converted_data = gpuarray.array(data, dtype=self.dtype, context=self.context)
g_data = gpuarray.array(data, copy=False, context=self.context)
if self.values_eq(g_data,
converted_data,
force_same_dtype=False):
data = converted_data
up_dtype = scalar.upcast(self.dtype, data.dtype) up_dtype = scalar.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype: if up_dtype == self.dtype:
...@@ -263,10 +269,10 @@ class GpuArrayType(Type): ...@@ -263,10 +269,10 @@ class GpuArrayType(Type):
return GpuFromHost(self.context_name)(other) return GpuFromHost(self.context_name)(other)
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b, force_same_dtype=True):
if a.shape != b.shape: if a.shape != b.shape:
return False return False
if a.typecode != b.typecode: if force_same_dtype and a.typecode != b.typecode:
return False return False
a_eq_b = numpy.asarray(compare(a, '==', b)) a_eq_b = numpy.asarray(compare(a, '==', b))
if a_eq_b.all(): if a_eq_b.all():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论