提交 d3d234f2 authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Updates according to PR comments

上级 68433621
...@@ -33,12 +33,6 @@ class Node(utils.object2): ...@@ -33,12 +33,6 @@ class Node(utils.object2):
Variable.owner / Apply.inputs and its children Variable.owner / Apply.inputs and its children
via Variable.clients / Apply.outputs. via Variable.clients / Apply.outputs.
""" """
_count = count(0)
def __init__(self):
self.auto_name = 'auto_' + str(self._count.next())
def get_parents(self): def get_parents(self):
""" Return a list of the parents of this node. """ Return a list of the parents of this node.
Should return a copy--i.e., modifying the return Should return a copy--i.e., modifying the return
...@@ -93,8 +87,6 @@ class Apply(Node): ...@@ -93,8 +87,6 @@ class Apply(Node):
exception will be raised. exception will be raised.
""" """
super(Apply, self).__init__()
self.op = op self.op = op
self.inputs = [] self.inputs = []
self.tag = utils.scratchpad() self.tag = utils.scratchpad()
...@@ -323,9 +315,10 @@ class Variable(Node): ...@@ -323,9 +315,10 @@ class Variable(Node):
`compile.function` uses each `Apply` instance's `inputs` attribute `compile.function` uses each `Apply` instance's `inputs` attribute
together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs. together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs.
""" """
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0)
def __init__(self, type, owner=None, index=None, name=None): def __init__(self, type, owner=None, index=None, name=None):
"""Initialize type, owner, index, name. """Initialize type, owner, index, name.
...@@ -356,6 +349,7 @@ class Variable(Node): ...@@ -356,6 +349,7 @@ class Variable(Node):
if name is not None and not isinstance(name, basestring): if name is not None and not isinstance(name, basestring):
raise TypeError("name must be a string", name) raise TypeError("name must be a string", name)
self.name = name self.name = name
self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self): def __str__(self):
"""WRITEME""" """WRITEME"""
......
import pickle import pickle
import unittest import unittest
import numpy
from itertools import count
from theano import tensor from theano import (
clone, sparse,
shared, tensor)
from theano.gof.graph import ( from theano.gof.graph import (
Node, Apply, as_string, clone, general_toposort, inputs, io_toposort, Node, Apply, Constant,
as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable) is_same_graph, Variable)
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from itertools import count from theano.tensor.var import TensorVariable
from theano.sandbox.cuda.var import (
CudaNdarrayVariable, CudaNdarrayConstant, CudaNdarraySharedVariable)
...@@ -321,10 +331,78 @@ class TestAutoName: ...@@ -321,10 +331,78 @@ class TestAutoName:
def test_auto_name(self): def test_auto_name(self):
## Re-init counter ## Re-init counter
Node._count = count(0) Variable.__count__ = count(0)
r1, r2 = MyVariable(1), MyVariable(2) r1, r2 = MyVariable(1), MyVariable(2)
node = MyOp.make_node(r1, r2)
assert r1.auto_name == "auto_0" assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1" assert r2.auto_name == "auto_1"
assert node.auto_name == "auto_3"
def test_constant(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = tensor.constant(1.5)
r2 = tensor.constant(1.5)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
def test_tensorvariable(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = tensor.TensorType(dtype='int32', broadcastable=())('myvar')
r2 = tensor.TensorVariable(tensor.TensorType(dtype='int32',
broadcastable=()))
r3 = shared(numpy.random.randn(3,4))
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert r3.auto_name == "auto_2"
def test_sparsevariable(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = sparse.csc_matrix(name='x', dtype='float32')
r2 = sparse.dense_from_sparse(r1)
r3 = sparse.csc_from_dense(r2)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert r3.auto_name == "auto_2"
def test_cudandarrayvariable(self):
## Re-init counter
Variable.__count__ = count(0)
mytype = tensor.TensorType(dtype='int32', broadcastable=())
r1 = CudaNdarrayVariable(type='int32')
r2 = CudaNdarrayVariable(type='int32')
r3 = CudaNdarrayConstant(type=mytype,
data=1)
r4 = CudaNdarraySharedVariable(name='x', type=mytype,
value=1, strict=False)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
assert r3.auto_name == "auto_2"
assert r4.auto_name == "auto_3"
def test_cudandarrayvariable(self):
## Re-init counter
Variable.__count__ = count(0)
mytype = tensor.TensorType(dtype='int32', broadcastable=())
r1 = tensor.shared_randomstreams.RandomStateSharedVariable(name='x',
type=mytype,
value=1,
strict=False)
r2 = tensor.shared_randomstreams.RandomStateSharedVariable(name='x',
type=mytype,
value=1,
strict=False)
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
def test_clone(self):
## Re-init counter
Variable.__count__ = count(0)
r1 = MyVariable(1)
r2 = r1.clone()
assert r1.auto_name == "auto_0"
assert r2.auto_name == "auto_1"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论