提交 4ee7198d authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2760 from clorenz7/optimization/deg2rad

Remove consecutive functional inverses to optimize
...@@ -34,6 +34,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -34,6 +34,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
AdvancedSubtensor1, AdvancedSubtensor1,
advanced_inc_subtensor1) advanced_inc_subtensor1)
from theano import scalar from theano import scalar
from theano.scalar import basic
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
from theano.compile.ops import Shape_i from theano.compile.ops import Shape_i
...@@ -1489,6 +1490,55 @@ def local_cast_cast(node): ...@@ -1489,6 +1490,55 @@ def local_cast_cast(node):
return [x] return [x]
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Elemwise])
def local_func_inv(node):
"""
Check for two consecutive operations that are functional inverses
and remove them from the function graph
"""
inv_pairs = (
(basic.Deg2Rad, basic.Rad2Deg),
(basic.Cosh, basic.ArcCosh),
(basic.Tanh, basic.ArcTanh),
(basic.Sinh, basic.ArcSinh),
(basic.Conj, basic.Conj),
(basic.Neg, basic.Neg),
(basic.Inv, basic.Inv),
)
x = node.inputs[0]
if not isinstance(node.op, T.Elemwise):
return
if (not x.owner or not isinstance(x.owner.op, T.Elemwise)):
return
prev_op = x.owner.op.scalar_op
node_op = node.op.scalar_op
for inv_pair in inv_pairs:
if is_inverse_pair(node_op, prev_op, inv_pair):
return x.owner.inputs
return
def is_inverse_pair(node_op, prev_op, inv_pair):
"""
Given two consecutive operations, check if they are the
provided pair of inverse functions
"""
node_is_op0 = isinstance(node_op, inv_pair[0])
node_is_op1 = isinstance(node_op, inv_pair[1])
prev_is_op0 = isinstance(prev_op, inv_pair[0])
prev_is_op1 = isinstance(prev_op, inv_pair[1])
return (node_is_op0 and prev_is_op1) or (node_is_op1 and prev_is_op0)
class Assert(T.Op): class Assert(T.Op):
""" """
Implements assertion in a computational graph. Implements assertion in a computational graph.
......
...@@ -1591,7 +1591,7 @@ def test_local_useless_slice(): ...@@ -1591,7 +1591,7 @@ def test_local_useless_slice():
apply_node = f_opt.maker.fgraph.toposort()[0] apply_node = f_opt.maker.fgraph.toposort()[0]
subtens = apply_node.op subtens = apply_node.op
assert not any(isinstance(idx, slice) for idx in subtens.idx_list), "Slice should be gone" assert not any(isinstance(idx, slice) for idx in subtens.idx_list), "Slice should be gone"
# test a 4d tensor # test a 4d tensor
z = tensor.tensor4('z') z = tensor.tensor4('z')
o2 = z[1, :, :, 1] o2 = z[1, :, :, 1]
...@@ -3812,6 +3812,71 @@ class T_cast_cast(unittest.TestCase): ...@@ -3812,6 +3812,71 @@ class T_cast_cast(unittest.TestCase):
assert isinstance(topo[0].op, T.Elemwise) assert isinstance(topo[0].op, T.Elemwise)
class T_func_inverse(unittest.TestCase):
def setUp(self):
mode = theano.compile.get_default_mode()
self.mode = mode.including('local_func_inv')
def assert_func_pair_optimized(self, func1, func2, data,
should_copy=True, is_complex=False):
"""
Check that a pair of funcs is optimized properly
"""
x = T.cmatrix() if is_complex else T.fmatrix()
o = func2(func1(x))
f = theano.function([x], o, mode=self.mode)
delta = f(data) - data
topo = f.maker.fgraph.toposort()
if should_copy:
acceptable_topo_lens = [1]
else:
# The 2 funcs can be split apart if they are not inverses
acceptable_topo_lens = [1, 2]
if should_copy:
delta_condition = numpy.all(delta == 0)
else:
delta_condition = numpy.all(delta != 0)
self.assertTrue(len(topo) in acceptable_topo_lens)
self.assertTrue(delta_condition)
self.assertEqual(isinstance(topo[0].op, DeepCopyOp), should_copy,
"Inverse functions not removed!")
def test(self):
"""
test optimization for consecutive functional inverses
"""
dx = numpy.random.rand(5, 4).astype("float32")
self.assert_func_pair_optimized(T.deg2rad, T.rad2deg, dx)
dx = numpy.random.rand(5, 4).astype("float32")*180
self.assert_func_pair_optimized(T.rad2deg, T.deg2rad, dx)
# Test the other functional inverses
dx = numpy.random.rand(5, 4).astype("float32")
self.assert_func_pair_optimized(T.cosh, T.arccosh, dx)
self.assert_func_pair_optimized(T.arcsinh, T.sinh, dx)
self.assert_func_pair_optimized(T.arctanh, T.tanh, dx)
self.assert_func_pair_optimized(T.inv, T.inv, dx)
self.assert_func_pair_optimized(T.neg, T.neg, dx)
cx = dx + complex(0, 1)*(dx + 0.01)
self.assert_func_pair_optimized(T.conj, T.conj, cx, is_complex=True)
# Test that non-inverse functions are ran normally
self.assert_func_pair_optimized(T.conj, T.neg, cx,
should_copy=False, is_complex=True)
dx = numpy.random.rand(5, 4).astype("float32")+0.01
self.assert_func_pair_optimized(T.rad2deg, T.rad2deg, dx,
should_copy=False)
self.assert_func_pair_optimized(T.rad2deg, T.cosh, dx,
should_copy=False)
def test_constant_folding(): def test_constant_folding():
""" Test that constant folding get registered at fast_compile """ Test that constant folding get registered at fast_compile
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论