提交 451b841d authored 作者: Frederic's avatar Frederic

Fix import problem with OutputGuard.

I made a registration system that allow all Theano Type to tell if they allow OutputGuard to generate c code for them. This fix the dependency on sandbox code in theano/compile/mode.py file. This make it a better fix then what was done for SparseType.
上级 3360e333
...@@ -100,9 +100,14 @@ def register_optimizer(name, opt): ...@@ -100,9 +100,14 @@ def register_optimizer(name, opt):
raise ValueError('Optimizer name already taken: %s' % name) raise ValueError('Optimizer name already taken: %s' % name)
predefined_optimizers[name] = opt predefined_optimizers[name] = opt
def register_OutputGuard_c_code(type):
OutputGuard.c_code_types.append(type)
class OutputGuard(gof.Op): class OutputGuard(gof.Op):
destroy_map = {0:[0]} destroy_map = {0:[0]}
view_map = {0:[0]} view_map = {0:[0]}
c_code_types = []
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def __eq__(self, other): def __eq__(self, other):
...@@ -124,12 +129,7 @@ class OutputGuard(gof.Op): ...@@ -124,12 +129,7 @@ class OutputGuard(gof.Op):
return """ return """
%(z)s = %(x)s; %(z)s = %(x)s;
""" % locals() """ % locals()
elif (isinstance(node.inputs[0].type, elif (isinstance(node.inputs[0].type, tuple(self.c_code_types))):
(theano.tensor.TensorType,
theano.sandbox.cuda.CudaNdarrayType,
theano.tensor.raw_random.RandomStateType)) or
node.inputs[0].type.__class__.__name__ == 'SparseType'
):
# These are Python object types # These are Python object types
return """ return """
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
......
...@@ -351,6 +351,12 @@ class CudaNdarrayType(Type): ...@@ -351,6 +351,12 @@ class CudaNdarrayType(Type):
ret.append('-use_fast_math') ret.append('-use_fast_math')
return ret return ret
# Register CudaNdarrayType to the OutputGuard list of know type
# to have OutputGuard generate c code for this type
theano.compile.mode.register_OutputGuard_c_code(CudaNdarrayType)
# THIS WORKS # THIS WORKS
# But CudaNdarray instances don't compare equal to one another, and what about __hash__ ? # But CudaNdarray instances don't compare equal to one another, and what about __hash__ ?
# So the unpickled version doesn't equal the pickled version, and the cmodule cache is not # So the unpickled version doesn't equal the pickled version, and the cmodule cache is not
......
...@@ -326,6 +326,11 @@ class SparseType(gof.Type): ...@@ -326,6 +326,11 @@ class SparseType(gof.Type):
def is_valid_value(self, a): def is_valid_value(self, a):
return scipy.sparse.issparse(a) and (a.format == self.format) return scipy.sparse.issparse(a) and (a.format == self.format)
# Register CudaNdarrayType to the OutputGuard list of know type
# to have OutputGuard generate c code for this type
theano.compile.mode.register_OutputGuard_c_code(SparseType)
# for more dtypes, call SparseType(format, dtype) # for more dtypes, call SparseType(format, dtype)
def matrix(format, name=None, dtype=None): def matrix(format, name=None, dtype=None):
if dtype is None: if dtype is None:
......
...@@ -910,6 +910,10 @@ class TensorType(Type): ...@@ -910,6 +910,10 @@ class TensorType(Type):
else: else:
return () return ()
# Register CudaNdarrayType to the OutputGuard list of know type
# to have OutputGuard generate c code for this type
theano.compile.mode.register_OutputGuard_c_code(TensorType)
# Easy constructors # Easy constructors
def tensor(*args, **kwargs): def tensor(*args, **kwargs):
......
...@@ -53,6 +53,10 @@ class RandomStateType(gof.Type): ...@@ -53,6 +53,10 @@ class RandomStateType(gof.Type):
return False return False
return True return True
# Register CudaNdarrayType to the OutputGuard list of know type
# to have OutputGuard generate c code for this type
theano.compile.mode.register_OutputGuard_c_code(RandomStateType)
random_state_type = RandomStateType() random_state_type = RandomStateType()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论