Commit 35ad2538, authored by Brandon T. Willard, committed by Brandon T. Willard

Remove strict TensorType.broadcastable usage from local_elemwise_alloc

Parent commit: 8d5a8c8c
...@@ -68,7 +68,13 @@ from aesara.tensor.basic import ( ...@@ -68,7 +68,13 @@ from aesara.tensor.basic import (
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_shape,
broadcast_to,
)
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.shape import ( from aesara.tensor.shape import (
...@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node):
introduces them as a canonicalization of `Alloc`'s with leading introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions. broadcastable dimensions.
""" """
if not isinstance(node.op, Elemwise):
return False
# Rewrite is only applicable when there are at least two inputs # Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1: if len(node.inputs) == 1:
return None return False
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true.
assert all(
o.type.broadcastable == node.outputs[0].type.broadcastable
for o in node.outputs[1:]
)
# The broadcast pattern of the output must match the broadcast
# pattern of at least one of the inputs.
if not any(
i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs
):
return False return False
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
...@@ -1523,103 +1514,82 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1523,103 +1514,82 @@ def local_elemwise_alloc(fgraph, node):
# At least one input must have an owner that is either a `Alloc` or a # At least one input must have an owner that is either a `Alloc` or a
# `DimShuffle` with an owner that is a `Alloc` -- otherwise there is # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any( alloc_idxs = [
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) idx
for i in node.inputs for idx, i in enumerate(node.inputs)
): if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
]
if len(alloc_idxs) == 0:
return False return False
# Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
# baseline for the dimensions. # baseline for the dimensions.
assert_op_idx = None ref_var_idx = None
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc` so that all `Alloc`s can be optimized. # `Alloc`, so that all `Alloc`s can be optimized.
if not ( if idx not in alloc_idxs:
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) ref_var_idx = idx
):
assert_op_idx = idx
break break
# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if assert_op_idx is None: if ref_var_idx is None:
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if (i.type.broadcastable == node.outputs[0].type.broadcastable) and ( # XXX: This broadcastable comparison doesn't work
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) if (
): i.type.broadcastable == node.outputs[0].type.broadcastable
assert_op_idx = idx ) and idx in alloc_idxs:
ref_var_idx = idx
break break
assert_op_in = node.inputs[assert_op_idx] if not hasattr(fgraph, "shape_feature"):
cmp_op = assert_op_in return False
new_i = []
same_shape = fgraph.shape_feature.same_shape input_shapes = [
for i in node.inputs: tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
for i in node.inputs
]
bcasted_shape = broadcast_shape(
*input_shapes,
arrays_are_shapes=True,
)
new_inputs = list(node.inputs)
for idx in alloc_idxs:
i = node.inputs[idx]
# Remove `Alloc` # Remove `Alloc`
if i.owner and isinstance(i.owner.op, Alloc): if isinstance(i.owner.op, Alloc):
assert i.type.ndim == cmp_op.ndim new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
cond = []
for idx in range(i.type.ndim):
if not i.type.broadcastable[idx] and not same_shape(
i, cmp_op, idx, idx
):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
alloc_input = i.owner.inputs[0]
if alloc_input.ndim != i.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = i.ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
copy_stack_trace(i, alloc_input)
new_i.append(alloc_input)
# TODO FIXME: This shouldn't be handled here.
# `DimShuffle`s should be lifted through `Alloc`s
# by other, more general rewrites.
# Remove `Alloc` in `DimShuffle` # Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i): elif isinstance(i.owner.op, DimShuffle):
assert i.type.ndim == cmp_op.type.ndim old_alloc = i.owner.inputs[0]
if config.experimental__local_alloc_elemwise_assert: new_alloc = old_alloc.owner.inputs[0]
assert_cond = [ # We need to keep the old `DimShuffle`. It could swap axes or
eq(i.shape[idx], cmp_op.shape[idx]) # add dimensions anywhere.
for idx in range(i.type.ndim) if new_alloc.ndim != old_alloc.ndim:
if not i.type.broadcastable[idx]
and not same_shape(i, cmp_op, idx, idx)
]
if assert_cond:
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
# The `Alloc` can add dimensions to the value. # The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here. # We replace those cases with a `DimShuffle` here.
# We let later optimizations merge the nested `DimShuffle`s nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim new_alloc = new_alloc.dimshuffle(
alloc_input = alloc_input.dimshuffle( ["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
) )
new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
# We need to keep the old `DimShuffle`. It could swap axes or copy_stack_trace(i, new_alloc)
# add dimensions anywhere. new_inputs[idx] = new_alloc
r_i = i.owner.op(alloc_input)
copy_stack_trace(i, r_i)
new_i.append(r_i)
else:
new_i.append(i)
new_i[assert_op_idx] = assert_op_in
# If this assert is triggered, it means we are recreating an equivalent graph # If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization. # which would result in a cyclical merge optimization.
if all(new is old for new, old in zip(new_i, node.inputs)): if all(new is old for new, old in zip(new_inputs, node.inputs)):
return return
ret = node.op(*new_i, return_list=True) ret = node.op(*new_inputs, return_list=True)
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
return ret return ret
......
...@@ -1489,7 +1489,7 @@ class TestLocalAdvSub1AdvIncSub1: ...@@ -1489,7 +1489,7 @@ class TestLocalAdvSub1AdvIncSub1:
assert check_stack_trace(f, ops_to_check=(Assert, aes.Cast)) assert check_stack_trace(f, ops_to_check=(Assert, aes.Cast))
class TestAllocZero: class TestSubtensorAllocRewrites:
def setup_method(self): def setup_method(self):
mode = get_default_mode() mode = get_default_mode()
self.mode = mode.including( self.mode = mode.including(
...@@ -1783,207 +1783,6 @@ def test_local_set_to_inc_subtensor(): ...@@ -1783,207 +1783,6 @@ def test_local_set_to_inc_subtensor():
assert check_stack_trace(f2, ops_to_check="all") assert check_stack_trace(f2, ops_to_check="all")
class TestLocalElemwiseAlloc:
    """Tests for the ``local_elemwise_alloc`` rewrite, which removes ``Alloc``s
    (and ``DimShuffle``s of ``Alloc``s) from the inputs of ``Elemwise`` nodes,
    optionally inserting ``Assert``s that guard the removed shape information.

    NOTE(review): this class was recovered from a diff view with indentation
    stripped; the indentation below is reconstructed -- confirm against the
    original file.
    """

    # Shared dtype for every symbolic variable built in setup_method.
    dtype = config.floatX

    def setup_method(self):
        """Build the modes, input variables, and ``Alloc`` graphs used by the tests."""
        self.fast_compile_mode = get_mode("FAST_COMPILE")
        self.fast_run_mode = get_mode("FAST_RUN")

        self.vec = vector("vec", dtype=self.dtype)
        self.mat = matrix("mat", dtype=self.dtype)
        self.tens = tensor3("tens", dtype=self.dtype)

        # Allocs whose shapes are constants ("wo_dep" = without dependency on
        # another input's shape).
        self.alloc_wo_dep = at.alloc(self.vec, 2, 2)
        self.alloc_wo_dep_broad = at.alloc(self.vec, 1, 2)
        # Allocs whose shapes depend on self.mat / self.tens ("w_dep").
        self.alloc_w_dep = at.alloc(self.vec, *self.mat.shape)
        self.alloc_w_dep_broad = at.alloc(self.vec, 1, *self.mat.shape)
        self.alloc_w_dep_broad2 = at.alloc(
            self.vec, self.mat.shape[0], self.mat.shape[1], 1
        )
        self.alloc_w_dep_tens = at.alloc(
            self.vec, self.tens.shape[0], self.tens.shape[1]
        )
        # Pairs of Allocs used by test_multi_input_single_alloc.
        self.tv_wo_dep = at.alloc(self.vec, 5, 5)
        self.tm_wo_dep = at.alloc(self.mat, 5, 5, 5)
        self.s = iscalar("s")
        self.tv_w_dep = at.alloc(self.vec, self.s, self.s)
        self.tm_w_dep = at.alloc(self.mat, 5, 5, 5)
        self.row = row(dtype=self.dtype)
        self.o = at.alloc(self.row, 5, 5)

    def _verify_alloc_count(self, f, count):
        """Assert that exactly `count` ``Alloc`` nodes remain in `f`'s compiled graph."""
        assert (
            sum(
                isinstance(elem.op, Alloc)
                for elem in f.maker.fgraph.toposort()
                if elem.op is not None
            )
            == count
        )

    def _verify_assert_count(self, f, count):
        """Assert that exactly `count` ``Assert`` nodes remain in `f`'s compiled graph."""
        assert (
            sum(
                isinstance(elem.op, Assert)
                for elem in f.maker.fgraph.toposort()
                if elem.op is not None
            )
            == count
        )

    def test_remove_alloc_wo_dimshuffle(self):
        """Check Alloc removal (and Assert insertion) when no DimShuffle is involved."""
        # Exclude local_useless_alloc, since it does not introduce
        # assert in all the same cases.
        self.fast_run_mode = self.fast_run_mode.excluding(
            "local_useless_alloc", "local_alloc_sink_dimshuffle"
        )
        # No optimization on alloc
        func = function(
            [self.vec, self.mat],
            self.alloc_wo_dep + self.mat,
            mode=self.fast_compile_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 0)
        # Check stacktrace was copied over correctly after opt was applied
        assert check_stack_trace(func, ops_to_check="all")

        # Optimization on alloc with assert
        func = function(
            [self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode
        )
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 1)

        # Optimization on alloc with assert and broadcast
        func = function(
            [self.vec, self.mat],
            self.alloc_wo_dep_broad + self.mat,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 1)

        # No optimization on alloc without assert
        func = function(
            [self.vec, self.mat],
            self.alloc_w_dep + self.mat,
            mode=self.fast_compile_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 0)

        # Optimization on alloc without assert
        func = function(
            [self.vec, self.mat], self.alloc_w_dep + self.mat, mode=self.fast_run_mode
        )
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 0)

        # Optimization on alloc without assert and with broadcast
        func = function(
            [self.vec, self.mat],
            self.alloc_w_dep_broad + self.mat,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 0)

        # Not optimized case on alloc and with broadcast
        func = function(
            [self.vec, self.mat],
            self.alloc_w_dep_broad2 + self.mat,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 0)

    def test_remove_alloc_w_dimshuffle(self):
        """Check Alloc removal when the Alloc sits under a DimShuffle."""
        # No optimization on dimshuffle with assert
        func = function(
            [self.vec, self.tens],
            self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
            mode=self.fast_compile_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 0)

        # Optimization on dimshuffle with assert
        func = function(
            [self.vec, self.tens],
            self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 1)

        # No optimization on dimshuffle without assert
        func = function(
            [self.vec, self.tens],
            self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens,
            mode=self.fast_compile_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 0)

        # Optimization on dimshuffle without assert
        func = function(
            [self.vec, self.tens],
            self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 0)

    def test_multi_input_single_alloc(self):
        """When several inputs are Allocs, only one should be removable per rewrite."""
        # No optimization on dimshuffle with assert
        func = function(
            [self.vec, self.mat],
            self.tv_wo_dep + self.tm_wo_dep,
            mode=self.fast_compile_mode,
        )
        self._verify_alloc_count(func, 2)
        self._verify_assert_count(func, 0)

        # Optimization on dimshuffle with assert
        func = function(
            [self.vec, self.mat],
            self.tv_wo_dep + self.tm_wo_dep,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 0)

        # No optimization on dimshuffle without assert
        func = function(
            [self.vec, self.mat, self.s],
            self.tv_w_dep + self.tm_w_dep,
            mode=self.fast_compile_mode,
        )
        self._verify_alloc_count(func, 2)
        self._verify_assert_count(func, 0)

        # Optimization on dimshuffle without assert
        func = function(
            [self.vec, self.mat, self.s],
            self.tv_w_dep + self.tm_w_dep,
            mode=self.fast_run_mode,
        )
        self._verify_alloc_count(func, 1)
        self._verify_assert_count(func, 1)

    def test_error(self):
        """The rewritten graph must still run on concrete (broadcast-compatible) data."""
        t3fft = tensor(dtype=self.dtype, shape=(False, False, True))
        o = self.o.dimshuffle(0, 1, "x") + t3fft
        func = function([t3fft, self.row], o, mode=self.fast_run_mode)
        self._verify_alloc_count(func, 0)
        self._verify_assert_count(func, 1)
        d = np.random.random((5, 5, 1)).astype(self.dtype)
        r = np.random.random((1, 5)).astype(self.dtype)
        # Should execute without raising despite the removed Allocs.
        func(d, r)
def test_local_subtensor_of_alloc(): def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong. # DebugMode should detect if something goes wrong.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论