提交 1e739294 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed infinite canonizer loop with NaN constants

上级 84ac684c
...@@ -2761,19 +2761,29 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2761,19 +2761,29 @@ class Canonizer(gof.LocalOptimizer):
# Wrapping ct in a Constant with the right dtype # Wrapping ct in a Constant with the right dtype
ct = [T.constant(c, dtype=out_type.dtype) for c in ct] ct = [T.constant(c, dtype=out_type.dtype) for c in ct]
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and\ if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
N.all([c.data for c in ct] == self.get_constant(orig_num[0])): # In that case we should only have one constant in `ct`.
# this is an important trick :( if it so happens that: assert len(ct) == 1
first_num_ct = self.get_constant(orig_num[0])
if first_num_ct is not None and ct[0].type.values_eq(ct[0].data,
first_num_ct):
# This is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on # * there's exactly one constant on the numerator and none on
# the denominator # the denominator
# * it's not the neutral element (ct is an empty list in that case) # * it's not the neutral element (ct is an empty list in that
# * the constant is the same as the first argument in the numerator # case)
# Then we return very exactly the original num/denum # * the constant is the same as the first argument in the
# numerator (we only check the first argument because the
# canonizer puts the computed constants first)
# -> then we return very exactly the original num/denum.
# If we don't do that the optimizer will just loop # If we don't do that the optimizer will just loop
# infinitely because it will not catch on that there are # infinitely because it will not catch on that there are
# no changes to be made and everytime it will want to # no changes to be made and everytime it will want to
# replace something by the same thing... # replace something by the same thing...
# Note that it is important to use `values_eq` instead of
# the == operator, to handle NaN values correctly.
return orig_num, orig_denum return orig_num, orig_denum
return ct + num, denum return ct + num, denum
def transform(self, node): def transform(self, node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论