提交 5f179854 authored 作者: James Bergstra's avatar James Bergstra

added eq, hash, c_code to OutputGuard Op

上级 f8493348
"""WRITEME
"""
import os, logging
import numpy import numpy
import os
import scipy.sparse as sp import scipy.sparse as sp
from theano import gof from theano import gof
_logger = logging.getLogger('theano.compile.mode')
def check_equal(x, y): def check_equal(x, y):
""" """
Returns True iff x[0] and y[0] are equal (checks the dtype and Returns True iff x[0] and y[0] are equal (checks the dtype and
...@@ -79,10 +84,23 @@ class OutputGuard(gof.Op): ...@@ -79,10 +84,23 @@ class OutputGuard(gof.Op):
view_map = {0:[0]} view_map = {0:[0]}
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def perform(self, node, (x,), (z,)): def perform(self, node, (x,), (z,)):
z[0] = x z[0] = x
def __str__(self): def __str__(self):
return '%s' % self.__class__.__name__ return '%s' % self.__class__.__name__
def c_code(self, node, nodename, (x,), (z,), sub):
return """
Py_XDECREF(%(z)s);
%(z)s = %(x)s;
Py_XINCREF(%(z)s);
""" %locals()
def c_code_cache_version(self):
return (1,)
_output_guard = OutputGuard()
class AddDestroyHandler(gof.Optimizer): class AddDestroyHandler(gof.Optimizer):
"""This optimizer performs two important functions: """This optimizer performs two important functions:
...@@ -97,11 +115,10 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -97,11 +115,10 @@ class AddDestroyHandler(gof.Optimizer):
not be possible to destroy outputs. not be possible to destroy outputs.
""" """
def apply(self, env): def apply(self, env):
output_guard = OutputGuard()
for o in env.outputs: for o in env.outputs:
try: try:
env.replace_validate(o, output_guard(o), reason='output_guard') env.replace_validate(o, _output_guard(o), reason='output_guard')
warning("Output variable %s required output_guard," _logger.warning("Output variable %s required output_guard,"
" how was this output left unprotected against destructive operations?" " how was this output left unprotected against destructive operations?"
% o) % o)
except gof.InconsistencyError: except gof.InconsistencyError:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论