提交 39269c13 authored 作者: Tim Cooijmans's avatar Tim Cooijmans 提交者: Reyhane Askari

add context manager to trace constructed Variables, inherit stack traces

上级 9bc05a38
...@@ -4,6 +4,7 @@ Node classes (`Apply`, `Variable`) and expression graph algorithms. ...@@ -4,6 +4,7 @@ Node classes (`Apply`, `Variable`) and expression graph algorithms.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import deque from collections import deque
import contextlib
from copy import copy from copy import copy
from itertools import count from itertools import count
...@@ -390,6 +391,8 @@ class Variable(Node): ...@@ -390,6 +391,8 @@ class Variable(Node):
self.name = name self.name = name
self.auto_name = 'auto_' + str(next(self.__count__)) self.auto_name = 'auto_' + str(next(self.__count__))
Variable.notify_construction_observers(self)
def __str__(self): def __str__(self):
"""Return a str representation of the Variable. """Return a str representation of the Variable.
...@@ -536,6 +539,21 @@ class Variable(Node): ...@@ -536,6 +539,21 @@ class Variable(Node):
d["tag"] = t d["tag"] = t
return d return d
construction_observers = []
@classmethod
def append_construction_observer(cls, observer):
cls.construction_observers.append(observer)
@classmethod
def remove_construction_observer(cls, observer):
cls.construction_observers.remove(observer)
@classmethod
def notify_construction_observers(cls, instance):
for observer in cls.construction_observers:
observer(instance)
class Constant(Variable): class Constant(Variable):
""" """
...@@ -1426,3 +1444,13 @@ def is_in_ancestors(l_node, f_node): ...@@ -1426,3 +1444,13 @@ def is_in_ancestors(l_node, f_node):
todo.append(cur) todo.append(cur)
todo.extend(i.owner for i in cur.inputs if i.owner) todo.extend(i.owner for i in cur.inputs if i.owner)
return False return False
@contextlib.contextmanager
def nodes_constructed():
new_nodes = []
def observer(node):
new_nodes.append(node)
Variable.append_construction_observer(observer)
yield new_nodes
Variable.remove_construction_observer(observer)
...@@ -6,6 +6,7 @@ amount of useful generic optimization tools. ...@@ -6,6 +6,7 @@ amount of useful generic optimization tools.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from collections import deque, defaultdict, OrderedDict from collections import deque, defaultdict, OrderedDict
import contextlib
import copy import copy
import inspect import inspect
import logging import logging
...@@ -2976,6 +2977,24 @@ def with_stack_trace(from_var, to_var): ...@@ -2976,6 +2977,24 @@ def with_stack_trace(from_var, to_var):
copy_stack_trace(from_var, to_var) copy_stack_trace(from_var, to_var)
return to_var return to_var
@contextlib.contextmanager
def inherit_stack_trace(from_var):
"""
Contextmanager that copies the stack trace from one or more tensor variables to all tensor
variables constructed in the body.
Parameters
----------
from_var
Tensor variable or list of tensor variables to copy stack traces from.
"""
with graph.nodes_constructed() as new_nodes:
yield
copy_stack_trace(from_var, new_nodes)
def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'): def check_stack_trace(f_or_fgraph, ops_to_check='last', bug_print='raise'):
""" """
This function checks if the outputs of specific ops of a compiled graph This function checks if the outputs of specific ops of a compiled graph
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论