提交 54e3a06e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add filter_variable (which was not present for unknown reasons).

上级 c13192c0
...@@ -63,6 +63,32 @@ class GpuArrayType(Type): ...@@ -63,6 +63,32 @@ class GpuArrayType(Type):
" dimension.", shp, self.broadcastable) " dimension.", shp, self.broadcastable)
return data return data
def filter_variable(self, other):
if hasattr(other, '_as_GpuArrayVariable'):
other = other._as_GpuArrayVariable()
if not isinstance(other, Variable):
other = self.Constant(type=self, data=other)
if other.type == self:
return other
if not isinstance(other.type, tensor.TensorType):
raise TypeError('Incompatible type', (self, other.type))
if (other.type.dtype != self.dtype):
raise TypeError('Incompatible dtype', (self.dtype,
other.type.dtype))
if other.type.ndim != self.ndim:
raise TypeError('Incompatible number of dimensions.'
' Expected %d, got %d.' % (self.ndim, other.ndim))
if other.type.broadcastable != self.broadcastable:
raise TypeError('Incompatible broadcastable dimensions.'
' Expected %s, got %s.' %
(str(other.type.broadcastable),
str(self.broadcastable)))
return theano.sandbox.gpuarray.basic_ops.gpu_from_host(other)
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
if a.shape != b.shape: if a.shape != b.shape:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论