Commit 8f2a43d4, authored by Cory Lorenz

Remove consecutive functional inverses to optimize

Parent commit: a6f7beb7
...@@ -34,6 +34,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -34,6 +34,7 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
                                     AdvancedSubtensor1,
                                     advanced_inc_subtensor1)
from theano import scalar
from theano.scalar import basic
from theano.tensor import basic as T
from theano import compile  # to register the optimizer built by this file
from theano.compile.ops import Shape_i
...@@ -1489,6 +1490,55 @@ def local_cast_cast(node): ...@@ -1489,6 +1490,55 @@ def local_cast_cast(node):
        return [x]
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Elemwise])
def local_func_inv(node):
    """
    Remove two consecutive elemwise operations that are functional
    inverses of each other, e.g. ``deg2rad(rad2deg(x)) -> x``.

    Parameters
    ----------
    node : Apply
        The node currently being considered by the optimizer.

    Returns
    -------
    list of Variable or None
        The inner node's inputs (which replace the redundant pair), or
        None when the optimization does not apply.
    """
    # Scalar-op class pairs that cancel each other out.  The symmetric
    # entries (Conj, Neg, Inv) are their own inverses.
    # NOTE(review): the hyperbolic pairs (Cosh/ArcCosh, ...) are true
    # inverses only on restricted domains — this substitution assumes
    # inputs lie in those domains; confirm before relying on it.
    inv_pairs = (
        (basic.Deg2Rad, basic.Rad2Deg),
        (basic.Cosh, basic.ArcCosh),
        (basic.Tanh, basic.ArcTanh),
        (basic.Sinh, basic.ArcSinh),
        (basic.Conj, basic.Conj),
        (basic.Neg, basic.Neg),
        (basic.Inv, basic.Inv),
    )

    # Check the op type before touching node.inputs so a node with no
    # inputs cannot raise IndexError (the original indexed inputs[0]
    # before this guard).
    if not isinstance(node.op, T.Elemwise):
        return
    x = node.inputs[0]
    if not x.owner or not isinstance(x.owner.op, T.Elemwise):
        return

    prev_op = x.owner.op.scalar_op
    node_op = node.op.scalar_op

    for inv_pair in inv_pairs:
        if is_inverse_pair(node_op, prev_op, inv_pair):
            # The two ops cancel: forward the inner node's inputs.
            return x.owner.inputs

    return
def is_inverse_pair(node_op, prev_op, inv_pair):
    """
    Return True when the two consecutive scalar ops form the given
    pair of inverse functions, in either order.
    """
    op_a, op_b = inv_pair
    forward = isinstance(node_op, op_a) and isinstance(prev_op, op_b)
    backward = isinstance(node_op, op_b) and isinstance(prev_op, op_a)
    return forward or backward
class Assert(T.Op):
    """
    Implements assertion in a computational graph.
......
...@@ -3812,6 +3812,56 @@ class T_cast_cast(unittest.TestCase): ...@@ -3812,6 +3812,56 @@ class T_cast_cast(unittest.TestCase):
        assert isinstance(topo[0].op, T.Elemwise)
class T_func_inverse(unittest.TestCase):
    """Tests for the `local_func_inv` optimization."""

    def setUp(self):
        mode = theano.compile.get_default_mode()
        self.mode = mode.including('local_func_inv')

    def _assert_func_pair_optimized(self, inner_func, outer_func, data,
                                    should_copy=True):
        """
        Compile ``outer_func(inner_func(x))`` and check whether the pair
        collapsed to a plain copy of the input (``should_copy=True``) or
        was left as a real computation (``should_copy=False``).
        """
        x = T.fmatrix()
        o = outer_func(inner_func(x))
        f = theano.function([x], o, mode=self.mode)
        delta = f(data) - data
        topo = f.maker.fgraph.toposort()
        self.assertEqual(len(topo), 1)
        if should_copy:
            # The graph should reduce to the identity: a single DeepCopyOp.
            self.assertTrue(numpy.all(delta == 0))
            self.assertTrue(isinstance(topo[0].op, DeepCopyOp),
                            "Inverse functions not removed!")
        else:
            # Non-inverse pair: the computation must survive intact.
            self.assertTrue(numpy.all(delta != 0))
            self.assertFalse(isinstance(topo[0].op, DeepCopyOp),
                             "Non-inverse functions removed!")

    def test(self):
        """
        test that consecutive ops that are functional inverses are removed
        """
        # deg2rad(rad2deg(x)) collapses to x.
        self._assert_func_pair_optimized(
            T.rad2deg, T.deg2rad,
            numpy.random.rand(5, 4).astype("float32"))
        # Test that the other ordering of functions works.
        self._assert_func_pair_optimized(
            T.deg2rad, T.rad2deg,
            numpy.random.rand(5, 4).astype("float32") * 180)
        # Test that non-inverse functions are run normally
        # (+0.01 keeps every entry nonzero so delta != 0 everywhere).
        self._assert_func_pair_optimized(
            T.rad2deg, T.rad2deg,
            numpy.random.rand(5, 4).astype("float32") + 0.01,
            should_copy=False)
def test_constant_folding():
    """ Test that constant folding get registered at fast_compile
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论