提交 30170732 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fix the equal comparison between scan ops

上级 bcba6436
...@@ -370,31 +370,41 @@ class Scan(PureOp): ...@@ -370,31 +370,41 @@ class Scan(PureOp):
# Check if we are dealing with same type of objects # Check if we are dealing with same type of objects
if not type(self) == type(other): if not type(self) == type(other):
return False return False
if not 'destroy_map' in self.info:
self.info['destroy_map'] = {}
if not 'destroy_map' in other.info:
other.info['destroy_map'] = {}
keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'name',
'as_while', 'n_mit_sot', 'destroy_map',
'n_nit_sot', 'n_shared_outs',
'n_sit_sot', 'gpu', 'n_mit_mot_outs',
'n_mit_mot', 'mit_mot_out_slices']
# This are some safety checks ( namely that the inner graph has the # This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs ) # same number of inputs and same number of outputs )
elif not len(self.inputs) == len(other.inputs): if not len(self.inputs) == len(other.inputs):
return False return False
elif not len(self.outputs) == len(other.outputs): elif not len(self.outputs) == len(other.outputs):
return False return False
elif self.info != other.info: for key in keys_to_check:
return False if self.info[key] != other.info[key]:
else: return False
# If everything went OK up to here, there is still one thing to # If everything went OK up to here, there is still one thing to
# check. Namely, do the internal graph represent same # check. Namely, do the internal graph represent same
# computations # computations
for self_in, other_in in izip(self.inputs, other.inputs): for self_in, other_in in izip(self.inputs, other.inputs):
if self_in.type != other_in.type: if self_in.type != other_in.type:
return False
if not scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs):
return False return False
# If they do, then they need to match in other small details if not scan_utils.equal_computations(self.outputs,
# like name, mode, etc. other.outputs,
return True self.inputs,
other.inputs):
return False
# If they do, then they need to match in other small details
# like name, mode, etc.
return True
def __str__(self): def __str__(self):
if self.gpu: if self.gpu:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论