提交 5f179854 authored 作者: James Bergstra's avatar James Bergstra

added eq, hash, c_code to OutputGuard Op

上级 f8493348
"""WRITEME
"""
import os, logging
import numpy
import os
import scipy.sparse as sp
from theano import gof
_logger = logging.getLogger('theano.compile.mode')
def check_equal(x, y):
"""
Returns True iff x[0] and y[0] are equal (checks the dtype and
......@@ -79,10 +84,23 @@ class OutputGuard(gof.Op):
view_map = {0:[0]}
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def perform(self, node, (x,), (z,)):
z[0] = x
def __str__(self):
return '%s' % self.__class__.__name__
def c_code(self, node, nodename, (x,), (z,), sub):
return """
Py_XDECREF(%(z)s);
%(z)s = %(x)s;
Py_XINCREF(%(z)s);
""" %locals()
def c_code_cache_version(self):
return (1,)
_output_guard = OutputGuard()
class AddDestroyHandler(gof.Optimizer):
"""This optimizer performs two important functions:
......@@ -97,11 +115,10 @@ class AddDestroyHandler(gof.Optimizer):
not be possible to destroy outputs.
"""
def apply(self, env):
output_guard = OutputGuard()
for o in env.outputs:
try:
env.replace_validate(o, output_guard(o), reason='output_guard')
warning("Output variable %s required output_guard,"
env.replace_validate(o, _output_guard(o), reason='output_guard')
_logger.warning("Output variable %s required output_guard,"
" how was this output left unprotected against destructive operations?"
% o)
except gof.InconsistencyError:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论