提交 68433621 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

add autoname

上级 9618fc79
......@@ -10,6 +10,8 @@ __docformat__ = "restructuredtext en"
from copy import copy
from itertools import count
import theano
import warnings
......@@ -32,6 +34,11 @@ class Node(utils.object2):
via Variable.clients / Apply.outputs.
"""
_count = count(0)
def __init__(self):
self.auto_name = 'auto_' + str(self._count.next())
def get_parents(self):
""" Return a list of the parents of this node.
Should return a copy--i.e., modifying the return
......@@ -86,6 +93,8 @@ class Apply(Node):
exception will be raised.
"""
super(Apply, self).__init__()
self.op = op
self.inputs = []
self.tag = utils.scratchpad()
......@@ -334,6 +343,8 @@ class Variable(Node):
:param name: a string for pretty-printing and debugging
"""
super(Variable, self).__init__()
self.tag = utils.scratchpad()
self.type = type
if owner is not None and not isinstance(owner, Apply):
......
......@@ -3,10 +3,12 @@ import unittest
from theano import tensor
from theano.gof.graph import (
Apply, as_string, clone, general_toposort, inputs, io_toposort,
Node, Apply, as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable)
from theano.gof.op import Op
from theano.gof.type import Type
from itertools import count
def as_variable(x):
......@@ -48,7 +50,6 @@ class MyOp(Op):
MyOp = MyOp()
##########
# inputs #
##########
......@@ -311,3 +312,19 @@ class TestEval(unittest.TestCase):
"variable must have cache after eval")
self.assertFalse(hasattr(pickle.loads(pickle.dumps(self.w)), '_fn_cache'),
"temporary functions must not be serialized")
################
# autoname #
################
class TestAutoName:
def test_auto_name(self):
## Re-init counter
Node._count = count(0)
r1, r2 = MyVariable(1), MyVariable(2)
node = MyOp.make_node(r1, r2)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert node.auto_name == "auto_3"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论