提交 b2f62a1e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix type conversion in TensorType and expand it to other similar types.

上级 682385f7
...@@ -277,6 +277,11 @@ class PureType(object): ...@@ -277,6 +277,11 @@ class PureType(object):
# a Constant of the appropriate Type. # a Constant of the appropriate Type.
other = self.Constant(type=self, data=other) other = self.Constant(type=self, data=other)
if other.type != self:
other2 = self.convert_variable(other)
if other2 is not None:
return other2
if other.type != self: if other.type != self:
raise TypeError( raise TypeError(
'Cannot convert Type %(othertype)s ' 'Cannot convert Type %(othertype)s '
......
...@@ -146,10 +146,14 @@ class CudaNdarrayType(Type): ...@@ -146,10 +146,14 @@ class CudaNdarrayType(Type):
raise TypeError('Incompatible number of dimensions.' raise TypeError('Incompatible number of dimensions.'
' Expected %d, got %d.' % (self.ndim, other.ndim)) ' Expected %d, got %d.' % (self.ndim, other.ndim))
if other.type.broadcastable != self.broadcastable: if other.type.broadcastable != self.broadcastable:
raise TypeError('Incompatible broadcastable dimensions.' type2 = other.type.clone(broadcastable=self.broadcastable)
' Expected %s, got %s.' % other2 = type2.convert_variable(other)
(str(other.type.broadcastable), if other2 is None:
str(self.broadcastable))) raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
other = other2
return theano.sandbox.cuda.basic_ops.GpuFromHost()(other) return theano.sandbox.cuda.basic_ops.GpuFromHost()(other)
......
...@@ -108,10 +108,14 @@ class GpuArrayType(Type): ...@@ -108,10 +108,14 @@ class GpuArrayType(Type):
raise TypeError('Incompatible number of dimensions.' raise TypeError('Incompatible number of dimensions.'
' Expected %d, got %d.' % (self.ndim, other.ndim)) ' Expected %d, got %d.' % (self.ndim, other.ndim))
if other.type.broadcastable != self.broadcastable: if other.type.broadcastable != self.broadcastable:
raise TypeError('Incompatible broadcastable dimensions.' type2 = other.type.clone(broadcastable=self.broadcastable)
' Expected %s, got %s.' % other2 = type2.convert_variable(other)
(str(other.type.broadcastable), if other2 is None:
str(self.broadcastable))) raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
other = other2
return theano.sandbox.gpuarray.basic_ops.gpu_from_host(other) return theano.sandbox.gpuarray.basic_ops.gpu_from_host(other)
......
...@@ -209,9 +209,9 @@ class TensorType(Type): ...@@ -209,9 +209,9 @@ class TensorType(Type):
return other return other
# Attempt safe broadcast conversion. # Attempt safe broadcast conversion.
other = self.convert_variable(other) other2 = self.convert_variable(other)
if other.type == self: if other2 is not None and other2.type == self:
return other return other2
raise TypeError( raise TypeError(
'Cannot convert Type %(othertype)s ' 'Cannot convert Type %(othertype)s '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论