提交 8724e6e3 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed the hash function such that two ops have same hash when they compare

equal and the other way around.
上级 845ab744
...@@ -214,6 +214,9 @@ class Scan(Op): ...@@ -214,6 +214,9 @@ class Scan(Op):
self.n_shared_outs ) self.n_shared_outs )
self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot self.n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
self.n_tap_outs = self.n_mit_mot + self.n_mit_sot self.n_tap_outs = self.n_mit_mot + self.n_mit_sot
self._cmodule_key = gof.CLinker.cmodule_key_(self.fn.maker.env,[])
self._hash_inner_graph = hash(self._cmodule_key)
def make_node(self, *inputs): def make_node(self, *inputs):
assert numpy.all(isinstance(i, gof.Variable) for i in inputs) assert numpy.all(isinstance(i, gof.Variable) for i in inputs)
...@@ -321,19 +324,27 @@ class Scan(Op): ...@@ -321,19 +324,27 @@ class Scan(Op):
return apply_node return apply_node
def __eq__(self, other): def __eq__(self, other):
# Check if we are dealing with same type of objects
if not type(self) == type(other): if not type(self) == type(other):
return False return False
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
elif not len(self.inputs) == len(other.inputs): elif not len(self.inputs) == len(other.inputs):
return False return False
elif not len(self.outputs) == len(other.outputs): elif not len(self.outputs) == len(other.outputs):
return False return False
else: else:
# If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same
# computations
for x,y in zip(self.inputs, other.inputs): for x,y in zip(self.inputs, other.inputs):
if not scan_utils.equal_computations(x,y): if not scan_utils.equal_computations(x,y):
return False return False
for x,y in zip(self.outputs, other.outputs): for x,y in zip(self.outputs, other.outputs):
if not scan_utils.equal_computations(x,y): if not scan_utils.equal_computations(x,y):
return False return False
# If they do, then they need to match in other small details
# like name, mode, etc.
return self.info == other.info return self.info == other.info
def __str__(self): def __str__(self):
...@@ -353,10 +364,10 @@ class Scan(Op): ...@@ -353,10 +364,10 @@ class Scan(Op):
def __hash__(self): def __hash__(self):
# any two objects that *might* compare equal must hash equal.
# so we don't check the self.inputs and self.outputs here and it's
# ok.
return ( hash(type(self)) ^ return ( hash(type(self)) ^
# and a hash representing the inner graph using the
# CLinker.cmodule_key_
self._hash_inner_graph ^
scan_utils.hash_listsDictsTuples(self.info) ) scan_utils.hash_listsDictsTuples(self.info) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论