提交 f13ab8b7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add a "filter_variable" mechanism in Type.

It should work similarly to filter(), but on symbolic variables, returning equivalent variables of the current Type if they are compatible.
上级 7dd0276b
...@@ -228,9 +228,35 @@ class PureType(object): ...@@ -228,9 +228,35 @@ class PureType(object):
# filter() This is to allow reusing the old allocated memory. As # filter() This is to allow reusing the old allocated memory. As
# of this writing this is used only when we transfer new data to a # of this writing this is used only when we transfer new data to a
# shared variable on the gpu. # shared variable on the gpu.
#def filter_inplace(value, storage, strict=False, allow_downcast=None) #def filter_inplace(value, storage, strict=False, allow_downcast=None)
def filter_variable(self, other):
"""Convert a symbolic variable into this Type, if compatible.
For the moment, the only Types compatible with one another are
TensorType and CudaNdarrayType, provided they have the same
number of dimensions, same broadcasting pattern, and same dtype.
If Types are not compatible, a TypeError should be raised.
"""
if not isinstance(other, graph.Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
other = self.Constant(type=self, data=other)
if other.type != self:
raise TypeError(
'Cannot convert Type %(othertype)s '
'(of Variable %(other)s) into Type %(self)s. '
'You can try to manually convert %(other)s into a %(self)s.'
% dict(
othertype=other.type,
other=other,
self=self)
)
return other
def is_valid_value(self, a): def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a legal value for a Variable of this Type""" """Required: Return True for any python object `a` that would be a legal value for a Variable of this Type"""
try: try:
......
...@@ -97,6 +97,32 @@ class CudaNdarrayType(Type): ...@@ -97,6 +97,32 @@ class CudaNdarrayType(Type):
% (self, self.dtype, data, converted_data, self.dtype), % (self, self.dtype, data, converted_data, self.dtype),
data) data)
def filter_variable(self, other):
"""Convert a Variable into a CudaNdarrayType, if compatible.
This Variable should either already be a CudaNdarrayType, or be
a TensorType. It has to have the right number of dimensions,
broadcastable pattern, and dtype.
"""
if hasattr(other, '_as_CudaNdarrayVariable'):
other = other._as_CudaNdarrayVariable()
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
other = self.Constant(type=self, data=other)
if other.type == self:
return other
if not isinstance(other.type, tensor.TensorType):
raise TypeError('Incompatible type', (self, other.type))
if (other.type.dtype != self.dtype):
raise TypeError('Incompatible dtype', (self.dtype, other.type.dtype))
if (other.type.broadcastable != self.broadcastable):
raise TypeError('Incompatible broadcastable', (self.broadcastable,
other.type.broadcastable))
return theano.sandbox.cuda.basic_ops.GpuFromHost()(other)
@staticmethod @staticmethod
def bound(a): def bound(a):
......
...@@ -633,6 +633,35 @@ class TensorType(Type): ...@@ -633,6 +633,35 @@ class TensorType(Type):
raise ValueError("non-finite elements not allowed") raise ValueError("non-finite elements not allowed")
return data return data
def filter_variable(self, other):
"""Convert a symbolic Variable into a TensorType, if compatible.
For the moment, only a TensorType or CudaNdarrayType will be
converted, provided they have the same number of dimensions,
broadcastable pattern, and dtype.
"""
if hasattr(other, '_as_TensorVariable'):
other = other._as_TensorVariable()
if not isinstance(other, Variable):
# The value is not a Variable: we cast it into
# a Constant of the appropriate Type.
other = self.Constant(type=self, data=other)
if other.type == self:
return other
raise TypeError(
'Cannot convert Type %(othertype)s '
'(of Variable %(other)s) into Type %(self)s. '
'You can try to manually convert %(other)s into a %(self)s.'
% dict(
othertype=other.type,
other=other,
self=self)
)
def value_validity_msg(self, a): def value_validity_msg(self, a):
try: try:
self.filter(a, strict=True) self.filter(a, strict=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论