提交 b2f62a1e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix type conversion in TensorType and expand it to other similar types.

上级 682385f7
......@@ -277,6 +277,11 @@ class PureType(object):
# a Constant of the appropriate Type.
other = self.Constant(type=self, data=other)
if other.type != self:
other2 = self.convert_variable(other)
if other2 is not None:
return other2
if other.type != self:
raise TypeError(
'Cannot convert Type %(othertype)s '
......
......@@ -146,10 +146,14 @@ class CudaNdarrayType(Type):
raise TypeError('Incompatible number of dimensions.'
' Expected %d, got %d.' % (self.ndim, other.ndim))
if other.type.broadcastable != self.broadcastable:
raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
type2 = other.type.clone(broadcastable=self.broadcastable)
other2 = type2.convert_variable(other)
if other2 is None:
raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
other = other2
return theano.sandbox.cuda.basic_ops.GpuFromHost()(other)
......
......@@ -108,10 +108,14 @@ class GpuArrayType(Type):
raise TypeError('Incompatible number of dimensions.'
' Expected %d, got %d.' % (self.ndim, other.ndim))
if other.type.broadcastable != self.broadcastable:
raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
type2 = other.type.clone(broadcastable=self.broadcastable)
other2 = type2.convert_variable(other)
if other2 is None:
raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
other = other2
return theano.sandbox.gpuarray.basic_ops.gpu_from_host(other)
......
......@@ -209,9 +209,9 @@ class TensorType(Type):
return other
# Attempt safe broadcast conversion.
other = self.convert_variable(other)
if other.type == self:
return other
other2 = self.convert_variable(other)
if other2 is not None and other2.type == self:
return other2
raise TypeError(
'Cannot convert Type %(othertype)s '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论