提交 2583dd15 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a way for replacements between almost-equal types to be patched over automatically.

上级 ab1d2d48
...@@ -462,12 +462,22 @@ class FunctionGraph(utils.object2): ...@@ -462,12 +462,22 @@ class FunctionGraph(utils.object2):
if verbose: if verbose:
print reason, r, new_r print reason, r, new_r
if r.fgraph is not self: if r.fgraph is not self:
raise Exception("Cannot replace %s because it does not belong to this FunctionGraph" % r, str(reason)) raise Exception("Cannot replace %s because it does not belong "
if not r.type == new_r.type: "to this FunctionGraph" % r, str(reason))
raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r, r.type, new_r.type, str(reason)) if r.type != new_r.type:
new_r2 = r.type.convert_variable(new_r)
# We still make sure that the type converts correctly
if new_r2 is None or new_r2.type != r.type:
raise TypeError("The type of the replacement must be "
"compatible with the type of the original "
"Variable.", r, new_r, r.type, new_r.type,
str(reason))
new_r = new_r2
if r not in self.variables: if r not in self.variables:
# this variable isn't in the graph... don't raise an exception here, just return silently # this variable isn't in the graph... don't raise an
# because it makes it easier to implement some optimizations for multiple-output ops # exception here, just return silently because it makes it
# easier to implement some optimizations for
# multiple-output ops
return return
if theano.config.compute_test_value != 'off': if theano.config.compute_test_value != 'off':
......
...@@ -394,6 +394,23 @@ class Type(object2, PureType, CLinkerType): ...@@ -394,6 +394,23 @@ class Type(object2, PureType, CLinkerType):
types. Type references are also useful to do type-checking in pattern-based optimizations. types. Type references are also useful to do type-checking in pattern-based optimizations.
""" """
def convert_variable(self, var):
"""Patch variable so that its type will match self, if possible.
If the variable can't be converted, this should return None.
The conversion can only happen if the following implication is
true for all possible `val`.
self.is_valid_value(val) => var.type.is_valid_value(val)
For the majority of types this means that you can only have
non-broadcastable dimensions become broadcastable and not the
inverse.
The default is to not convert anything which is always safe.
"""
return None
class SingletonType(Type): class SingletonType(Type):
......
...@@ -232,6 +232,13 @@ class CudaNdarrayType(Type): ...@@ -232,6 +232,13 @@ class CudaNdarrayType(Type):
return (type(self) == type(other) and return (type(self) == type(other) and
other.broadcastable == self.broadcastable) other.broadcastable == self.broadcastable)
def convert_variable(self, var):
if (type(self) == type(var.type) and
self.ndim == var.type.ndim and
all(sb == ob or ob for sb, ob in zip(self.broadcastable,
var.type.broadcastable))):
return theano.tensor.patternbroadcast(var, self.broadcastable)
def __hash__(self): def __hash__(self):
"""Hash equal for same kinds of CudaNdarrayType""" """Hash equal for same kinds of CudaNdarrayType"""
return hash(type(self)) ^ hash(self.broadcastable) return hash(type(self)) ^ hash(self.broadcastable)
......
...@@ -148,6 +148,14 @@ class GpuArrayType(Type): ...@@ -148,6 +148,14 @@ class GpuArrayType(Type):
self.typecode == other.typecode and self.typecode == other.typecode and
self.broadcastable == other.broadcastable) self.broadcastable == other.broadcastable)
def convert_variable(self, var):
if (type(self) == type(var.type) and
self.typecode == var.type.typecode and
self.ndim == var.type.ndim and
all(sb == ob or ob for sb, ob in zip(self.broadcastable,
var.type.broadcastable))):
return theano.tensor.patternbroadcast(var, self.broadcastable)
def __hash__(self): def __hash__(self):
return (hash(self.typecode) ^ hash(self.broadcastable)) return (hash(self.typecode) ^ hash(self.broadcastable))
......
...@@ -260,6 +260,14 @@ class TensorType(Type): ...@@ -260,6 +260,14 @@ class TensorType(Type):
return type(self) == type(other) and other.dtype == self.dtype \ return type(self) == type(other) and other.dtype == self.dtype \
and other.broadcastable == self.broadcastable and other.broadcastable == self.broadcastable
def convert_variable(self, var):
if (type(self) == type(var.type) and
self.dtype == var.type.dtype and
self.ndim == var.type.ndim and
all(sb == ob or ob for sb, ob in zip(self.broadcastable,
var.type.broadcastable))):
return theano.tensor.patternbroadcast(var, self.broadcastable)
@staticmethod @staticmethod
def may_share_memory(a, b): def may_share_memory(a, b):
# This is a method of TensorType, so both a and b should be ndarrays # This is a method of TensorType, so both a and b should be ndarrays
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论