提交 68433621 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

add autoname

上级 9618fc79
...@@ -10,6 +10,8 @@ __docformat__ = "restructuredtext en" ...@@ -10,6 +10,8 @@ __docformat__ = "restructuredtext en"
from copy import copy from copy import copy
from itertools import count
import theano import theano
import warnings import warnings
...@@ -32,6 +34,11 @@ class Node(utils.object2): ...@@ -32,6 +34,11 @@ class Node(utils.object2):
via Variable.clients / Apply.outputs. via Variable.clients / Apply.outputs.
""" """
_count = count(0)
def __init__(self):
self.auto_name = 'auto_' + str(self._count.next())
def get_parents(self): def get_parents(self):
""" Return a list of the parents of this node. """ Return a list of the parents of this node.
Should return a copy--i.e., modifying the return Should return a copy--i.e., modifying the return
...@@ -86,6 +93,8 @@ class Apply(Node): ...@@ -86,6 +93,8 @@ class Apply(Node):
exception will be raised. exception will be raised.
""" """
super(Apply, self).__init__()
self.op = op self.op = op
self.inputs = [] self.inputs = []
self.tag = utils.scratchpad() self.tag = utils.scratchpad()
...@@ -334,6 +343,8 @@ class Variable(Node): ...@@ -334,6 +343,8 @@ class Variable(Node):
:param name: a string for pretty-printing and debugging :param name: a string for pretty-printing and debugging
""" """
super(Variable, self).__init__()
self.tag = utils.scratchpad() self.tag = utils.scratchpad()
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
......
...@@ -3,10 +3,12 @@ import unittest ...@@ -3,10 +3,12 @@ import unittest
from theano import tensor from theano import tensor
from theano.gof.graph import ( from theano.gof.graph import (
Apply, as_string, clone, general_toposort, inputs, io_toposort, Node, Apply, as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable) is_same_graph, Variable)
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from itertools import count
def as_variable(x): def as_variable(x):
...@@ -48,7 +50,6 @@ class MyOp(Op): ...@@ -48,7 +50,6 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
########## ##########
# inputs # # inputs #
########## ##########
...@@ -311,3 +312,19 @@ class TestEval(unittest.TestCase): ...@@ -311,3 +312,19 @@ class TestEval(unittest.TestCase):
"variable must have cache after eval") "variable must have cache after eval")
self.assertFalse(hasattr(pickle.loads(pickle.dumps(self.w)), '_fn_cache'), self.assertFalse(hasattr(pickle.loads(pickle.dumps(self.w)), '_fn_cache'),
"temporary functions must not be serialized") "temporary functions must not be serialized")
################
# autoname #
################
class TestAutoName:
def test_auto_name(self):
## Re-init counter
Node._count = count(0)
r1, r2 = MyVariable(1), MyVariable(2)
node = MyOp.make_node(r1, r2)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert node.auto_name == "auto_3"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论