提交 9a17df56 authored 作者: James Bergstra's avatar James Bergstra

factored adding traceback tag out of op and type. Added cls.Variable and cls.Constant to type

上级 98f7218c
...@@ -9,9 +9,10 @@ compatible with `gof`'s :doc:`graph` routines. ...@@ -9,9 +9,10 @@ compatible with `gof`'s :doc:`graph` routines.
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
import utils import utils
import traceback
from theano import config from theano import config
class CLinkerObject(object): class CLinkerObject(object):
"""Standard elements of an Op or Type used with the CLinker """Standard elements of an Op or Type used with the CLinker
""" """
...@@ -320,9 +321,7 @@ class PureOp(object): ...@@ -320,9 +321,7 @@ class PureOp(object):
""" """
node = self.make_node(*inputs, **kwargs) node = self.make_node(*inputs, **kwargs)
limit = config.traceback.limit self.add_tag_trace(node)
if limit == -1: limit = None
node.tag.trace = traceback.extract_stack(limit=limit)[:-1]
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] return node.outputs[self.default_output]
else: else:
...@@ -331,6 +330,9 @@ class PureOp(object): ...@@ -331,6 +330,9 @@ class PureOp(object):
else: else:
return node.outputs return node.outputs
# Convenience so that subclass implementers don't have to import utils
# just to self.add_tag_trace
add_tag_trace = staticmethod(utils.add_tag_trace)
######################### #########################
# Python implementation # # Python implementation #
......
...@@ -5,8 +5,7 @@ __docformat__ = "restructuredtext en" ...@@ -5,8 +5,7 @@ __docformat__ = "restructuredtext en"
import copy import copy
import utils import utils
from utils import MethodNotDefined, object2 from utils import MethodNotDefined, object2
from graph import Variable import graph
import traceback
from theano import config from theano import config
######## ########
...@@ -202,6 +201,9 @@ class PureType(object): ...@@ -202,6 +201,9 @@ class PureType(object):
""" """
Variable = graph.Variable #the type that will be created by call to make_variable.
Constant = graph.Constant #the type that will be created by call to make_constant
def filter(self, data, strict = False): def filter(self, data, strict = False):
"""Required: Return data or an appropriately wrapped/converted data. """Required: Return data or an appropriately wrapped/converted data.
...@@ -233,8 +235,11 @@ class PureType(object): ...@@ -233,8 +235,11 @@ class PureType(object):
A pretty string for printing and debugging. A pretty string for printing and debugging.
""" """
r = Variable(self, name = name) return self.Variable(self, name = name)
return r
def make_constant(self, value, name=None):
return self.Constant(type=self, data=value, name=name)
def __call__(self, name = None): def __call__(self, name = None):
"""Return a new `Variable` instance of Type `self`. """Return a new `Variable` instance of Type `self`.
...@@ -244,11 +249,7 @@ class PureType(object): ...@@ -244,11 +249,7 @@ class PureType(object):
A pretty string for printing and debugging. A pretty string for printing and debugging.
""" """
r = self.make_variable(name) return utils.add_tag_trace(self.make_variable(name))
limit = config.traceback.limit
if limit == -1: limit = None
r.tag.trace = traceback.extract_stack(limit=limit)[:-1]
return r
def values_eq(self, a, b): def values_eq(self, a, b):
""" """
...@@ -319,9 +320,11 @@ class Type(object2, PureType, CLinkerType): ...@@ -319,9 +320,11 @@ class Type(object2, PureType, CLinkerType):
""" """
## DELETEME ##
class SingletonType(Type): class SingletonType(Type):
"""WRITEME""" """Convenient Base class for a Type subclass with no attributes
It saves having to implement __eq__ and __hash__
"""
__instance = None __instance = None
def __new__(cls): def __new__(cls):
if cls.__instance is None: if cls.__instance is None:
...@@ -378,6 +381,5 @@ class Generic(SingletonType): ...@@ -378,6 +381,5 @@ class Generic(SingletonType):
Py_INCREF(py_%(name)s); Py_INCREF(py_%(name)s);
""" % locals() """ % locals()
generic = Generic() generic = Generic()
...@@ -2,7 +2,18 @@ ...@@ -2,7 +2,18 @@
# import op # import op
# import variable # import variable
import re, os from theano import config
import re, os, traceback
def add_tag_trace(thing):
"""Add tag.trace to an node or variable.
The argument is returned after being affected (inplace).
"""
limit = config.traceback.limit
if limit == -1: limit = None
thing.tag.trace = traceback.extract_stack(limit=limit)[:-1]
return thing
def hashgen(): def hashgen():
hashgen.next += 1 hashgen.next += 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论