提交 6b3003fe authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make imports from theano.gof.utils explicit

上级 70a12a17
...@@ -13,8 +13,15 @@ from itertools import count ...@@ -13,8 +13,15 @@ from itertools import count
from six import string_types, integer_types from six import string_types, integer_types
from theano import config from theano import config
from theano.gof import utils from theano.gof.utils import (
from theano.gof.utils import TestValueError TestValueError,
object2,
MethodNotDefined,
Scratchpad,
ValidatingScratchpad,
get_variable_trace_string,
add_tag_trace,
)
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
...@@ -22,7 +29,7 @@ __docformat__ = "restructuredtext en" ...@@ -22,7 +29,7 @@ __docformat__ = "restructuredtext en"
NoParams = object() NoParams = object()
class Node(utils.object2): class Node(object2):
"""A `Node` in a Theano graph. """A `Node` in a Theano graph.
Currently, graphs contain two kinds of `Nodes`: `Variable`s and `Apply`s. Currently, graphs contain two kinds of `Nodes`: `Variable`s and `Apply`s.
...@@ -90,7 +97,7 @@ class Apply(Node): ...@@ -90,7 +97,7 @@ class Apply(Node):
def __init__(self, op, inputs, outputs): def __init__(self, op, inputs, outputs):
self.op = op self.op = op
self.inputs = [] self.inputs = []
self.tag = utils.Scratchpad() self.tag = Scratchpad()
if not isinstance(inputs, (list, tuple)): if not isinstance(inputs, (list, tuple)):
raise TypeError("The inputs of an Apply must be a list or tuple") raise TypeError("The inputs of an Apply must be a list or tuple")
...@@ -132,7 +139,7 @@ class Apply(Node): ...@@ -132,7 +139,7 @@ class Apply(Node):
""" """
try: try:
return self.op.get_params(self) return self.op.get_params(self)
except theano.gof.utils.MethodNotDefined: except MethodNotDefined:
return NoParams return NoParams
def __getstate__(self): def __getstate__(self):
...@@ -381,7 +388,7 @@ class Variable(Node): ...@@ -381,7 +388,7 @@ class Variable(Node):
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
super().__init__() super().__init__()
self.tag = utils.ValidatingScratchpad("test_value", type.filter) self.tag = ValidatingScratchpad("test_value", type.filter)
self.type = type self.type = type
...@@ -410,7 +417,7 @@ class Variable(Node): ...@@ -410,7 +417,7 @@ class Variable(Node):
""" """
if not hasattr(self.tag, "test_value"): if not hasattr(self.tag, "test_value"):
detailed_err_msg = utils.get_variable_trace_string(self) detailed_err_msg = get_variable_trace_string(self)
raise TestValueError( raise TestValueError(
"{} has no test value {}".format(self, detailed_err_msg) "{} has no test value {}".format(self, detailed_err_msg)
) )
...@@ -602,7 +609,7 @@ class Constant(Variable): ...@@ -602,7 +609,7 @@ class Constant(Variable):
def __init__(self, type, data, name=None): def __init__(self, type, data, name=None):
super().__init__(type, None, None, name) super().__init__(type, None, None, name)
self.data = type.filter(data) self.data = type.filter(data)
utils.add_tag_trace(self) add_tag_trace(self)
def get_test_value(self): def get_test_value(self):
return self.data return self.data
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论