提交 357f8a08 authored 作者: Cory Lorenz's avatar Cory Lorenz

Inv Func Opt: Fix Fast Compile Mode & More Tests

In fast compile mode, non-inverse functions can be split into two operations, which broke one of the asserts. Also added exhaustive tests for all inverse function pairs. Abstracted out the test assertions to enhance readability.
上级 8f2a43d4
...@@ -3818,48 +3818,63 @@ class T_func_inverse(unittest.TestCase): ...@@ -3818,48 +3818,63 @@ class T_func_inverse(unittest.TestCase):
mode = theano.compile.get_default_mode() mode = theano.compile.get_default_mode()
self.mode = mode.including('local_func_inv') self.mode = mode.including('local_func_inv')
def test(self):
def assert_func_pair_optimized(self, func1, func2, data,
should_copy=True, is_complex=False):
""" """
test that consecutive ops that are functional inverses are removed Check that a pair of funcs is optimized properly
""" """
x = T.fmatrix() x = T.cmatrix() if is_complex else T.fmatrix()
o = T.deg2rad(T.rad2deg(x)) o = func2(func1(x))
f = theano.function([x], o, mode=self.mode) f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype("float32") delta = f(data) - data
delta = f(dx) - dx
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
self.assertEqual(len(topo), 1) if should_copy:
self.assertTrue(numpy.all(delta == 0)) acceptable_topo_lens = [1]
self.assertTrue(isinstance(topo[0].op, DeepCopyOp), else:
# The 2 funcs can be split apart if they are not inverses
acceptable_topo_lens = [1, 2]
if should_copy:
delta_condition = numpy.all(delta == 0)
else:
delta_condition = numpy.all(delta != 0)
self.assertTrue(len(topo) in acceptable_topo_lens)
self.assertTrue(delta_condition)
self.assertEqual(isinstance(topo[0].op, DeepCopyOp), should_copy,
"Inverse functions not removed!") "Inverse functions not removed!")
# Test that the other ordering of functions works def test(self):
x = T.fmatrix() """
o = T.rad2deg(T.deg2rad(x)) test optimization for consecutive functional inverses
f = theano.function([x], o, mode=self.mode) """
dx = numpy.random.rand(5, 4).astype("float32")
self.assert_func_pair_optimized(T.deg2rad, T.rad2deg, dx)
dx = numpy.random.rand(5, 4).astype("float32")*180 dx = numpy.random.rand(5, 4).astype("float32")*180
delta = f(dx) - dx self.assert_func_pair_optimized(T.rad2deg, T.deg2rad, dx)
topo = f.maker.fgraph.toposort()
self.assertEqual(len(topo), 1) # Test the other functional inverses
self.assertTrue(numpy.all(delta == 0)) dx = numpy.random.rand(5, 4).astype("float32")
self.assertTrue(isinstance(topo[0].op, DeepCopyOp), self.assert_func_pair_optimized(T.cosh, T.arccosh, dx)
"Inverse functions not removed!") self.assert_func_pair_optimized(T.arcsinh, T.sinh, dx)
self.assert_func_pair_optimized(T.arctanh, T.tanh, dx)
self.assert_func_pair_optimized(T.inv, T.inv, dx)
self.assert_func_pair_optimized(T.neg, T.neg, dx)
cx = dx + complex(0, 1)*(dx + 0.01)
self.assert_func_pair_optimized(T.conj, T.conj, cx, is_complex=True)
# Test that non-inverse functions are ran normally # Test that non-inverse functions are ran normally
x = T.fmatrix() self.assert_func_pair_optimized(T.conj, T.neg, cx,
o = T.rad2deg(T.rad2deg(x)) should_copy=False, is_complex=True)
f = theano.function([x], o, mode=self.mode)
dx = numpy.random.rand(5, 4).astype("float32")+0.01 dx = numpy.random.rand(5, 4).astype("float32")+0.01
delta = f(dx) - dx self.assert_func_pair_optimized(T.rad2deg, T.rad2deg, dx,
topo = f.maker.fgraph.toposort() should_copy=False)
self.assert_func_pair_optimized(T.rad2deg, T.cosh, dx,
self.assertEqual(len(topo), 1) should_copy=False)
self.assertTrue(numpy.all(delta != 0))
self.assertFalse(isinstance(topo[0].op, DeepCopyOp),
"Non-inverse functions removed!")
def test_constant_folding(): def test_constant_folding():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论