提交 63cf9ea0 authored 作者: Michael I Mandel's avatar Michael I Mandel

added tests for x over abs(x) canonization

上级 3b18fef4
...@@ -1987,6 +1987,7 @@ register_specialize(local_add_specialize) ...@@ -1987,6 +1987,7 @@ register_specialize(local_add_specialize)
mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink)) mul_canonizer = in2out(gof.LocalOptGroup(local_mul_canonizer, local_fill_cut, local_fill_sink))
def check_for_x_over_absX(numerators, denominators): def check_for_x_over_absX(numerators, denominators):
"""Convert x/abs(x) into sign(x). """
# TODO: this function should dig/search through dimshuffles # TODO: this function should dig/search through dimshuffles
# This won't catch a dimshuffled absolute value # This won't catch a dimshuffled absolute value
for den in list(denominators): for den in list(denominators):
......
...@@ -466,6 +466,23 @@ class test_canonize(unittest.TestCase): ...@@ -466,6 +466,23 @@ class test_canonize(unittest.TestCase):
topo=f.maker.env.toposort() topo=f.maker.env.toposort()
assert len(topo)==0 assert len(topo)==0
assert(out_dtype==out.dtype) assert(out_dtype==out.dtype)
#test x / abs(x) -> sign(x)
for id,(g, sym_inputs, val_inputs, out_dtype) in enumerate([
(dx/abs(dx),[dx],[0.5-dxv],'float64'),
(fx/abs(fx),[fx],[0.5-fxv],'float32'),
(dx/abs(dx),[dx],[0.0*dxv],'float64'),
(fx/abs(fx),[fx],[0.0*fxv],'float32'),
(dv/abs(dv),[dv],[0.5-dvv],'float64'),
(fv/abs(fv),[fv],[0.5-fvv],'float32'),
]):
f = compile.function(list(sym_inputs), g,
mode=mode)
out = f(*val_inputs)
print out
assert numpy.all(numpy.isfinite(out))
assert numpy.allclose(out,numpy.sign(val_inputs[0]))
assert(out_dtype==out.dtype)
finally: finally:
mode._optimizer = old_optimizer mode._optimizer = old_optimizer
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论