提交 184ea129 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the comparison on the CPU.

上级 31354829
...@@ -4,6 +4,7 @@ from six import iteritems ...@@ -4,6 +4,7 @@ from six import iteritems
import warnings import warnings
import theano import theano
from theano.tensor.type import TensorType
from theano.tensor.var import _tensor_py_operators from theano.tensor.var import _tensor_py_operators
from theano import Type, Variable, Constant, tensor, config, scalar from theano import Type, Variable, Constant, tensor, config, scalar
from theano.compile import SharedVariable from theano.compile import SharedVariable
...@@ -201,17 +202,14 @@ class GpuArrayType(Type): ...@@ -201,17 +202,14 @@ class GpuArrayType(Type):
context=self.context) context=self.context)
else: else:
if not hasattr(data, 'dtype'): if not hasattr(data, 'dtype'):
# This is to convert objects that don't have a dtype converted_data = theano._asarray(data, self.dtype)
# (like lists). We anticipate that the type below # We use the `values_eq` static function from TensorType
# will match and we pass copy=False so it won't make a # to handle NaN values.
# second object on the GPU. if TensorType.values_eq(numpy.asarray(data),
converted_data = gpuarray.array(data, dtype=self.dtype, context=self.context) converted_data,
g_data = gpuarray.array(data, copy=False, context=self.context) force_same_dtype=False):
if self.values_eq(g_data,
converted_data,
force_same_dtype=False):
data = converted_data data = converted_data
data = gpuarray.array(data, context=self.context)
up_dtype = scalar.upcast(self.dtype, data.dtype) up_dtype = scalar.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype: if up_dtype == self.dtype:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论