提交 7ba1082f authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

refactor local_useless_alloc and local_canonicalize_alloc opt

上级 a6994e93
......@@ -1749,12 +1749,33 @@ def local_useless_fill(node):
return [v]
@register_useless
@gof.local_optimizer([T.alloc])
def local_useless_alloc(node):
"""
If the input type is the same as the output type (dtype and broadcast)
there is no change in the shape of the input. So this is just a simple copy
of the input. This is not needed.
"""
op = node.op
if not isinstance(op, Alloc):
return False
input = node.inputs[0]
output = node.outputs[0]
# Check if dtype and broadcast remain the same.
if input.type == output.type:
# We don't need to copy over any stack traces here
return [input]
@register_specialize
@register_stabilize
@register_canonicalize
@register_useless
@gof.local_optimizer([T.alloc])
def local_useless_alloc(node):
def local_canonicalize_alloc(node):
"""
If the input type is the same as the output type (dtype and broadcast)
there is no change in the shape of the input. So this is just a simple copy
......@@ -1778,8 +1799,8 @@ def local_useless_alloc(node):
for client, i in clients:
if client != "output" and isinstance(client.op, Alloc):
return
# Check if alloc adds a broadcastable dimension with shape 1.
output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - input.ndim):
......
......@@ -3174,7 +3174,7 @@ class Test_local_elemwise_alloc(unittest.TestCase):
# Exclude local_useless_alloc, since it does not introduce
# assert in all the same cases.
self.fast_run_mode = self.fast_run_mode.excluding(
'local_useless_alloc')
'local_useless_alloc', 'local_canonicalize_alloc')
# No optimization on alloc
func = function(
[self.vec, self.mat],
......@@ -3669,7 +3669,7 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
self.assert_eqs_const(f, 0)
class Test_local_useless_alloc(unittest.TestCase):
class Test_local_canonicalize_alloc(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
......@@ -3691,11 +3691,11 @@ class Test_local_useless_alloc(unittest.TestCase):
self.assertRaises(ValueError, f)
# No need to check_stack_trace as the optimization
# local_useless_alloc only removes nodes.
# local_canonicalize_alloc only removes nodes.
def test1(self):
# Test that alloc never gets instantiated during optimization
mode = mode_opt.excluding('local_useless_alloc')
mode = mode_opt.excluding('local_canonicalize_alloc')
x = tensor.matrix('x')
xx = tensor.fill(x, x)
......@@ -3707,11 +3707,11 @@ class Test_local_useless_alloc(unittest.TestCase):
assert tensor.Alloc not in op_classes
# No need to check_stack_trace as the optimization
# local_useless_alloc only removes nodes.
# local_canonicalize_alloc only removes nodes.
def test2(self):
# Test that alloc never gets instantiated during optimization
mode = mode_opt.excluding('local_useless_alloc')
mode = mode_opt.excluding('local_canonicalize_alloc')
x = tensor.matrix('x')
y = tensor.tile(x, (1,)*2)
......@@ -3729,7 +3729,7 @@ class Test_local_useless_alloc(unittest.TestCase):
# The correct opt removes nodes, no need for check_stack_trace
def test_useless_alloc_with_shape_one(self):
alloc_lift = out2in(local_useless_alloc)
alloc_lift = out2in(local_canonicalize_alloc)
x = shared(self.rng.randn(2,))
y = shared(self.rng.randn())
z = shared(self.rng.randn(1, 1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论