提交 3867a916 authored 作者: Dustin Webb's avatar Dustin Webb

Improved support for multiple alloc and dimshuffle alloc combinations and added assocaited tests.

上级 4b0a5c10
...@@ -1554,14 +1554,14 @@ def local_alloc_elemwise(node): ...@@ -1554,14 +1554,14 @@ def local_alloc_elemwise(node):
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern # Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true. # This is a supposition that I'm not sure is always true.
assert all([list(o.type.broadcastable) == list( assert all([o.type.broadcastable ==
node.outputs[0].type.broadcastable) for o in node.outputs[0].type.broadcastable for o in
node.outputs[1:]]) node.outputs[1:]])
# The broadcast pattern of the ouptut must match the broadcast pattern of # The broadcast pattern of the ouptut must match the broadcast pattern of
# at least one of the inputs. # at least one of the inputs.
if not any([list(i.type.broadcastable) == list( if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable) for i in node.inputs]): node.outputs[0].type.broadcastable for i in node.inputs]):
return False return False
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
...@@ -1593,12 +1593,20 @@ def local_alloc_elemwise(node): ...@@ -1593,12 +1593,20 @@ def local_alloc_elemwise(node):
if assert_op_idx < 0: if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When there is more # We want to optimize as many allocs as possible. When there is more
# than one then do all but one. # than one then do all but one.
l = [idx for idx, i in enumerate(node.inputs) # number of inputs with alloc or dimshuffle alloc
if i.type.broadcastable == node.outputs[0].type.broadcastable] l2 = [i for i in node.inputs
if len(l) > 1: if (i.owner and (isinstance(i.owner.op, T.Alloc)
or dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we will use for the shape
# So no alloc would be removed.
if len(l2) > 1:
# l containt inputs with alloc or dimshuffle alloc only.
# Its length will always be at least one, as we checked that before
l = [idx for idx, i in enumerate(node.inputs)
if i.type.broadcastable == node.outputs[0].type.broadcastable]
assert_op_idx = l[0] # The first one is as good as any to use. assert_op_idx = l[0] # The first one is as good as any to use.
else: else:
# Otherwise nothing can be done. # Nothing would be optimized!
return False return False
assert_op = node.inputs[assert_op_idx] assert_op = node.inputs[assert_op_idx]
......
...@@ -2620,6 +2620,47 @@ class Test_local_alloc_elemwise(unittest.TestCase): ...@@ -2620,6 +2620,47 @@ class Test_local_alloc_elemwise(unittest.TestCase):
self._verify_alloc_count(func, 0) self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0) self._verify_assert_count(func, 0)
def test_multi_input_single_alloc(self):
tv = T.alloc(self.vec, 5, 5)
tm = T.alloc(self.mat, 5, 5, 5)
func = function(
[self.vec, self.mat],
tv + tm,
mode='FAST_COMPILE'
)
self._verify_alloc_count(func, 2)
self._verify_assert_count(func, 0)
func = function(
[self.vec, self.mat],
tv + tm,
mode='FAST_RUN'
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
s = T.iscalar('s')
tv = T.alloc(self.vec, s, s)
tm = T.alloc(self.mat, 5, 5, 5)
func = function(
[self.vec, self.mat, s],
tv + tm,
mode='FAST_COMPILE'
)
self._verify_alloc_count(func, 2)
self._verify_assert_count(func, 0)
func = function(
[self.vec, self.mat, s],
tv + tm,
mode='FAST_RUN'
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 1)
def test_local_subtensor_of_alloc(): def test_local_subtensor_of_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论