提交 357f8a08 authored 作者: Cory Lorenz's avatar Cory Lorenz

Inv Func Opt: Fix Fast Compile Mode & More Tests

In fast compile mode, non-inverse functions can be split into two operations, which broke one of the asserts. Also added exhaustive tests for all inverse function pairs. Abstracted out the test assertions to enhance readability.
上级 8f2a43d4
......@@ -3818,48 +3818,63 @@ class T_func_inverse(unittest.TestCase):
mode = theano.compile.get_default_mode()
self.mode = mode.including('local_func_inv')
def test(self):
def assert_func_pair_optimized(self, func1, func2, data,
should_copy=True, is_complex=False):
"""
test that consecutive ops that are functional inverses are removed
Check that a pair of funcs is optimized properly
"""
x = T.fmatrix()
o = T.deg2rad(T.rad2deg(x))
x = T.cmatrix() if is_complex else T.fmatrix()
o = func2(func1(x))
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype("float32")
delta = f(dx) - dx
delta = f(data) - data
topo = f.maker.fgraph.toposort()
self.assertEqual(len(topo), 1)
self.assertTrue(numpy.all(delta == 0))
self.assertTrue(isinstance(topo[0].op, DeepCopyOp),
"Inverse functions not removed!")
if should_copy:
acceptable_topo_lens = [1]
else:
# The 2 funcs can be split apart if they are not inverses
acceptable_topo_lens = [1, 2]
# Test that the other ordering of functions works
x = T.fmatrix()
o = T.rad2deg(T.deg2rad(x))
f = theano.function([x], o, mode=self.mode)
if should_copy:
delta_condition = numpy.all(delta == 0)
else:
delta_condition = numpy.all(delta != 0)
self.assertTrue(len(topo) in acceptable_topo_lens)
self.assertTrue(delta_condition)
self.assertEqual(isinstance(topo[0].op, DeepCopyOp), should_copy,
"Inverse functions not removed!")
def test(self):
"""
test optimization for consecutive functional inverses
"""
dx = numpy.random.rand(5, 4).astype("float32")
self.assert_func_pair_optimized(T.deg2rad, T.rad2deg, dx)
dx = numpy.random.rand(5, 4).astype("float32")*180
delta = f(dx) - dx
topo = f.maker.fgraph.toposort()
self.assert_func_pair_optimized(T.rad2deg, T.deg2rad, dx)
self.assertEqual(len(topo), 1)
self.assertTrue(numpy.all(delta == 0))
self.assertTrue(isinstance(topo[0].op, DeepCopyOp),
"Inverse functions not removed!")
# Test the other functional inverses
dx = numpy.random.rand(5, 4).astype("float32")
self.assert_func_pair_optimized(T.cosh, T.arccosh, dx)
self.assert_func_pair_optimized(T.arcsinh, T.sinh, dx)
self.assert_func_pair_optimized(T.arctanh, T.tanh, dx)
self.assert_func_pair_optimized(T.inv, T.inv, dx)
self.assert_func_pair_optimized(T.neg, T.neg, dx)
cx = dx + complex(0, 1)*(dx + 0.01)
self.assert_func_pair_optimized(T.conj, T.conj, cx, is_complex=True)
# Test that non-inverse functions are ran normally
x = T.fmatrix()
o = T.rad2deg(T.rad2deg(x))
f = theano.function([x], o, mode=self.mode)
self.assert_func_pair_optimized(T.conj, T.neg, cx,
should_copy=False, is_complex=True)
dx = numpy.random.rand(5, 4).astype("float32")+0.01
delta = f(dx) - dx
topo = f.maker.fgraph.toposort()
self.assertEqual(len(topo), 1)
self.assertTrue(numpy.all(delta != 0))
self.assertFalse(isinstance(topo[0].op, DeepCopyOp),
"Non-inverse functions removed!")
self.assert_func_pair_optimized(T.rad2deg, T.rad2deg, dx,
should_copy=False)
self.assert_func_pair_optimized(T.rad2deg, T.cosh, dx,
should_copy=False)
def test_constant_folding():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论