Commit 7ba1082f authored by Frederic Bastien, committed by sentient07

refactor local_useless_alloc and local_canonicalize_alloc opt

Parent a6994e93
...@@ -1749,12 +1749,33 @@ def local_useless_fill(node): ...@@ -1749,12 +1749,33 @@ def local_useless_fill(node):
return [v] return [v]
@register_useless
@gof.local_optimizer([T.alloc])
def local_useless_alloc(node):
    """Remove an ``Alloc`` that does not change its input's type.

    If the input variable already has the same type as the ``Alloc``
    output (same dtype and same broadcastable pattern), the ``Alloc`` is
    a plain copy and can be replaced by its input directly.

    Parameters
    ----------
    node : Apply
        The node this local optimizer is applied to.

    Returns
    -------
    list or bool
        ``[input]`` when the ``Alloc`` is a no-op copy, ``False`` when the
        node is not an ``Alloc`` or the types differ (no replacement).
    """
    if not isinstance(node.op, Alloc):
        return False
    # Avoid shadowing the builtins `input`; `inp`/`out` are the Alloc's
    # sole data input and its single output.
    inp = node.inputs[0]
    out = node.outputs[0]
    # Same dtype and broadcastable pattern => the Alloc changes nothing.
    if inp.type == out.type:
        # No Alloc node survives, so there is no stack trace to copy over.
        return [inp]
    # Explicit no-match (the original fell through returning None; local
    # optimizers treat False and None identically, so this is equivalent
    # but consistent with the early-exit above).
    return False
@register_specialize @register_specialize
@register_stabilize @register_stabilize
@register_canonicalize @register_canonicalize
@register_useless
@gof.local_optimizer([T.alloc]) @gof.local_optimizer([T.alloc])
def local_useless_alloc(node): def local_canonicalize_alloc(node):
""" """
If the input type is the same as the output type (dtype and broadcast) If the input type is the same as the output type (dtype and broadcast)
there is no change in the shape of the input. So this is just a simple copy there is no change in the shape of the input. So this is just a simple copy
...@@ -1778,8 +1799,8 @@ def local_useless_alloc(node): ...@@ -1778,8 +1799,8 @@ def local_useless_alloc(node):
for client, i in clients: for client, i in clients:
if client != "output" and isinstance(client.op, Alloc): if client != "output" and isinstance(client.op, Alloc):
return return
# Check if alloc adds a broadcastable dimension with shape 1. # Check if alloc adds a broadcastable dimension with shape 1.
output_shape = node.inputs[1:] output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0 num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - input.ndim): for i in range(len(output_shape) - input.ndim):
......
...@@ -3174,7 +3174,7 @@ class Test_local_elemwise_alloc(unittest.TestCase): ...@@ -3174,7 +3174,7 @@ class Test_local_elemwise_alloc(unittest.TestCase):
# Exclude local_useless_alloc, since it does not introduce # Exclude local_useless_alloc, since it does not introduce
# assert in all the same cases. # assert in all the same cases.
self.fast_run_mode = self.fast_run_mode.excluding( self.fast_run_mode = self.fast_run_mode.excluding(
'local_useless_alloc') 'local_useless_alloc', 'local_canonicalize_alloc')
# No optimization on alloc # No optimization on alloc
func = function( func = function(
[self.vec, self.mat], [self.vec, self.mat],
...@@ -3669,7 +3669,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3669,7 +3669,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
class Test_local_useless_alloc(unittest.TestCase): class Test_local_canonicalize_alloc(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed()) self.rng = numpy.random.RandomState(utt.fetch_seed())
...@@ -3691,11 +3691,11 @@ class Test_local_useless_alloc(unittest.TestCase): ...@@ -3691,11 +3691,11 @@ class Test_local_useless_alloc(unittest.TestCase):
self.assertRaises(ValueError, f) self.assertRaises(ValueError, f)
# No need to check_stack_trace as the optimization # No need to check_stack_trace as the optimization
# local_useless_alloc only removes nodes. # local_canonicalize_alloc only removes nodes.
def test1(self): def test1(self):
# Test that alloc never gets instantiated during optimization # Test that alloc never gets instantiated during optimization
mode = mode_opt.excluding('local_useless_alloc') mode = mode_opt.excluding('local_canonicalize_alloc')
x = tensor.matrix('x') x = tensor.matrix('x')
xx = tensor.fill(x, x) xx = tensor.fill(x, x)
...@@ -3707,11 +3707,11 @@ class Test_local_useless_alloc(unittest.TestCase): ...@@ -3707,11 +3707,11 @@ class Test_local_useless_alloc(unittest.TestCase):
assert tensor.Alloc not in op_classes assert tensor.Alloc not in op_classes
# No need to check_stack_trace as the optimization # No need to check_stack_trace as the optimization
# local_useless_alloc only removes nodes. # local_canonicalize_alloc only removes nodes.
def test2(self): def test2(self):
# Test that alloc never gets instantiated during optimization # Test that alloc never gets instantiated during optimization
mode = mode_opt.excluding('local_useless_alloc') mode = mode_opt.excluding('local_canonicalize_alloc')
x = tensor.matrix('x') x = tensor.matrix('x')
y = tensor.tile(x, (1,)*2) y = tensor.tile(x, (1,)*2)
...@@ -3729,7 +3729,7 @@ class Test_local_useless_alloc(unittest.TestCase): ...@@ -3729,7 +3729,7 @@ class Test_local_useless_alloc(unittest.TestCase):
# The correct opt removes nodes, no need for check_stack_trace # The correct opt removes nodes, no need for check_stack_trace
def test_useless_alloc_with_shape_one(self): def test_useless_alloc_with_shape_one(self):
alloc_lift = out2in(local_useless_alloc) alloc_lift = out2in(local_canonicalize_alloc)
x = shared(self.rng.randn(2,)) x = shared(self.rng.randn(2,))
y = shared(self.rng.randn()) y = shared(self.rng.randn())
z = shared(self.rng.randn(1, 1)) z = shared(self.rng.randn(1, 1))
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment