提交 9ca91ae5 authored 作者: James Bergstra's avatar James Bergstra

Fixed check_x_over_absX to not replace complex arguments

上级 7397cdf4
...@@ -1991,9 +1991,16 @@ def check_for_x_over_absX(numerators, denominators): ...@@ -1991,9 +1991,16 @@ def check_for_x_over_absX(numerators, denominators):
# This won't catch a dimshuffled absolute value # This won't catch a dimshuffled absolute value
for den in list(denominators): for den in list(denominators):
if den.owner and den.owner.op == T.abs_ and den.owner.inputs[0] in numerators: if den.owner and den.owner.op == T.abs_ and den.owner.inputs[0] in numerators:
denominators.remove(den) if den.owner.inputs[0].type.dtype.startswith('complex'):
numerators.remove(den.owner.inputs[0]) #TODO: Make an Op that projects a complex number to have unit length
numerators.append(T.sgn(den.owner.inputs[0])) # but projects 0 to 0. That would be a weird Op, but consistent with the
# special case below. I heard there's some convention in Matlab that is
# similar to this... but not sure.
pass
else:
denominators.remove(den)
numerators.remove(den.owner.inputs[0])
numerators.append(T.sgn(den.owner.inputs[0]))
return numerators, denominators return numerators, denominators
local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'teststest') local_mul_canonizer.add_simplifier(check_for_x_over_absX, 'teststest')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论