提交 09bd9be8 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge

......@@ -214,6 +214,9 @@ class Scan(Op):
self.n_shared_outs )
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
self._cmodule_key = gof.CLinker.cmodule_key_(self.fn.maker.env,[])
self._hash_inner_graph = hash(self._cmodule_key)
def make_node(self, *inputs):
assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
......@@ -321,19 +324,27 @@ class Scan(Op):
return apply_node
def __eq__(self, other):
# Check if we are dealing with same type of objects
if not type(self) == type(other):
return False
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
elif not len(self.inputs) == len(other.inputs):
return False
elif not len(self.outputs) == len(other.outputs):
return False
else:
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for x,y in zip(self.inputs, other.inputs):
if not scan_utils.equal_computations(x,y):
return False
for x,y in zip(self.outputs, other.outputs):
if not scan_utils.equal_computations(x,y):
return False
# If they do, then they need to match in other small details
# like name, mode, etc.
return self.info == other.info
def __str__(self):
......@@ -353,10 +364,10 @@ class Scan(Op):
def __hash__(self):
# any two objects that *might* compare equal must hash equal.
# so we don't check the self.inputs and self.outputs here and it's
# ok.
return ( hash(type(self)) ^
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph ^
scan_utils.hash_listsDictsTuples(self.info) )
......
......@@ -1999,6 +1999,16 @@ class T_Scan(unittest.TestCase):
conv = theano.tensor.signal.conv.conv2d(m1, m2)
def test_hash(self):
x = theano.tensor.vector()
y = theano.tensor.vector()
scan1,updates = theano.scan(lambda _x:_x+1, x )
scan2,updates = theano.scan(lambda _x:_x+1, y )
assert scan1.owner.op == scan2.owner.op
assert hash(scan1.owner.op) == hash(scan2.owner.op)
if __name__ == '__main__':
#'''
print ' Use nosetests to run these tests '
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论