提交 6761f0d1 authored 作者: carriepl's avatar carriepl

Merge pull request #2398 from ballasn/autoname

Add autoname
...@@ -10,6 +10,8 @@ __docformat__ = "restructuredtext en" ...@@ -10,6 +10,8 @@ __docformat__ = "restructuredtext en"
from copy import copy from copy import copy
from itertools import count
import theano import theano
import warnings import warnings
...@@ -31,7 +33,6 @@ class Node(utils.object2): ...@@ -31,7 +33,6 @@ class Node(utils.object2):
Variable.owner / Apply.inputs and its children Variable.owner / Apply.inputs and its children
via Variable.clients / Apply.outputs. via Variable.clients / Apply.outputs.
""" """
def get_parents(self): def get_parents(self):
""" Return a list of the parents of this node. """ Return a list of the parents of this node.
Should return a copy--i.e., modifying the return Should return a copy--i.e., modifying the return
...@@ -314,9 +315,10 @@ class Variable(Node): ...@@ -314,9 +315,10 @@ class Variable(Node):
`compile.function` uses each `Apply` instance's `inputs` attribute `compile.function` uses each `Apply` instance's `inputs` attribute
together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs. together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs.
""" """
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0)
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
"""Initialize type, owner, index, name. """Initialize type, owner, index, name.
...@@ -334,6 +336,8 @@ class Variable(Node): ...@@ -334,6 +336,8 @@ class Variable(Node):
:param name: a string for pretty-printing and debugging :param name: a string for pretty-printing and debugging
""" """
super(Variable, self).__init__()
self.tag = utils.scratchpad() self.tag = utils.scratchpad()
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
...@@ -345,6 +349,7 @@ class Variable(Node): ...@@ -345,6 +349,7 @@ class Variable(Node):
if name is not None and not isinstance(name, basestring): if name is not None and not isinstance(name, basestring):
raise TypeError("name must be a string", name) raise TypeError("name must be a string", name)
self.name = name self.name = name
self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self): def __str__(self):
"""WRITEME""" """WRITEME"""
......
import pickle import pickle
import unittest import unittest
import numpy
from itertools import count
from theano import tensor
from theano import (
clone, sparse,
shared, tensor)
from theano.gof.graph import ( from theano.gof.graph import (
Apply, as_string, clone, general_toposort, inputs, io_toposort, Node, Apply, Constant,
as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable) is_same_graph, Variable)
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.tensor.var import TensorVariable
from theano.sandbox.cuda.var import (
CudaNdarrayVariable, CudaNdarrayConstant, CudaNdarraySharedVariable)
def as_variable(x): def as_variable(x):
...@@ -48,7 +60,6 @@ class MyOp(Op): ...@@ -48,7 +60,6 @@ class MyOp(Op):
MyOp = MyOp() MyOp = MyOp()
########## ##########
# inputs # # inputs #
########## ##########
...@@ -311,3 +322,87 @@ class TestEval(unittest.TestCase): ...@@ -311,3 +322,87 @@ class TestEval(unittest.TestCase):
"variable must have cache after eval") "variable must have cache after eval")
self.assertFalse(hasattr(pickle.loads(pickle.dumps(self.w)), '_fn_cache'), self.assertFalse(hasattr(pickle.loads(pickle.dumps(self.w)), '_fn_cache'),
"temporary functions must not be serialized") "temporary functions must not be serialized")
################
# autoname #
################
class TestAutoName:
def test_auto_name(self):
## Re-init counter
Variable.__count__ = count(0)
r1, r2 = MyVariable(1), MyVariable(2)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
def test_constant(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = tensor.constant(1.5)
r2 = tensor.constant(1.5)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
def test_tensorvariable(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = tensor.TensorType(dtype='int32', broadcastable=())('myvar')
r2 = tensor.TensorVariable(tensor.TensorType(dtype='int32',
broadcastable=()))
r3 = shared(numpy.random.randn(3,4))
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert r3.auto_name == "auto_2"
def test_sparsevariable(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = sparse.csc_matrix(name='x', dtype='float32')
r2 = sparse.dense_from_sparse(r1)
r3 = sparse.csc_from_dense(r2)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert r3.auto_name == "auto_2"
def test_cudandarrayvariable(self):
## Re-init counter
Variable.__count__ = count(0)
mytype = tensor.TensorType(dtype='int32', broadcastable=())
r1 = CudaNdarrayVariable(type='int32')
r2 = CudaNdarrayVariable(type='int32')
r3 = CudaNdarrayConstant(type=mytype,
data=1)
r4 = CudaNdarraySharedVariable(name='x', type=mytype,
value=1, strict=False)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert r3.auto_name == "auto_2"
assert r4.auto_name == "auto_3"
def test_randomvariable(self):
## Re-init counter
Variable.__count__ = count(0)
mytype = tensor.TensorType(dtype='int32', broadcastable=())
r1 = tensor.shared_randomstreams.RandomStateSharedVariable(name='x',
type=mytype,
value=1,
strict=False)
r2 = tensor.shared_randomstreams.RandomStateSharedVariable(name='x',
type=mytype,
value=1,
strict=False)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
def test_clone(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = MyVariable(1)
r2 = r1.clone()
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论