提交 35ad2538 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove strict TensorType.broadcastable usage from local_elemwise_alloc

上级 8d5a8c8c
...@@ -68,7 +68,13 @@ from aesara.tensor.basic import ( ...@@ -68,7 +68,13 @@ from aesara.tensor.basic import (
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_shape,
broadcast_to,
)
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.shape import ( from aesara.tensor.shape import (
...@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1491,26 +1497,11 @@ def local_elemwise_alloc(fgraph, node):
introduces them as a canonicalization of `Alloc`'s with leading introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions. broadcastable dimensions.
""" """
if not isinstance(node.op, Elemwise):
return False
# Rewrite is only applicable when there are at least two inputs # Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1: if len(node.inputs) == 1:
return None return False
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true.
assert all(
o.type.broadcastable == node.outputs[0].type.broadcastable
for o in node.outputs[1:]
)
# The broadcast pattern of the output must match the broadcast
# pattern of at least one of the inputs.
if not any(
i.type.broadcastable == node.outputs[0].type.broadcastable for i in node.inputs
):
return False return False
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
...@@ -1523,103 +1514,82 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1523,103 +1514,82 @@ def local_elemwise_alloc(fgraph, node):
# At least one input must have an owner that is either a `Alloc` or a # At least one input must have an owner that is either a `Alloc` or a
# `DimShuffle` with an owner that is a `Alloc` -- otherwise there is # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any( alloc_idxs = [
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) idx
for i in node.inputs for idx, i in enumerate(node.inputs)
): if i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
]
if len(alloc_idxs) == 0:
return False return False
# Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
# baseline for the dimensions. # baseline for the dimensions.
assert_op_idx = None ref_var_idx = None
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
# `Alloc` so that all `Alloc`s can be optimized. # `Alloc`, so that all `Alloc`s can be optimized.
if not ( if idx not in alloc_idxs:
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) ref_var_idx = idx
):
assert_op_idx = idx
break break
# If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if assert_op_idx is None: if ref_var_idx is None:
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if (i.type.broadcastable == node.outputs[0].type.broadcastable) and ( # XXX: This broadcastable comparison doesn't work
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)) if (
): i.type.broadcastable == node.outputs[0].type.broadcastable
assert_op_idx = idx ) and idx in alloc_idxs:
ref_var_idx = idx
break break
assert_op_in = node.inputs[assert_op_idx] if not hasattr(fgraph, "shape_feature"):
cmp_op = assert_op_in return False
new_i = []
same_shape = fgraph.shape_feature.same_shape
for i in node.inputs:
# Remove `Alloc`
if i.owner and isinstance(i.owner.op, Alloc):
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
cond = []
for idx in range(i.type.ndim):
if not i.type.broadcastable[idx] and not same_shape(
i, cmp_op, idx, idx
):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
alloc_input = i.owner.inputs[0]
if alloc_input.ndim != i.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = i.ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
copy_stack_trace(i, alloc_input)
new_i.append(alloc_input)
# Remove `Alloc` in `DimShuffle` input_shapes = [
elif i.owner and dimshuffled_alloc(i): tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim))
assert i.type.ndim == cmp_op.type.ndim for i in node.inputs
if config.experimental__local_alloc_elemwise_assert:
assert_cond = [
eq(i.shape[idx], cmp_op.shape[idx])
for idx in range(i.type.ndim)
if not i.type.broadcastable[idx]
and not same_shape(i, cmp_op, idx, idx)
] ]
if assert_cond: bcasted_shape = broadcast_shape(
assert_op_in = assert_op(assert_op_in, *assert_cond) *input_shapes,
alloc_input = i.owner.inputs[0].owner.inputs[0] arrays_are_shapes=True,
if alloc_input.ndim != i.owner.inputs[0].ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
# We let later optimizations merge the nested `DimShuffle`s
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
) )
new_inputs = list(node.inputs)
for idx in alloc_idxs:
i = node.inputs[idx]
# Remove `Alloc`
if isinstance(i.owner.op, Alloc):
new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape)
# TODO FIXME: This shouldn't be handled here.
# `DimShuffle`s should be lifted through `Alloc`s
# by other, more general rewrites.
# Remove `Alloc` in `DimShuffle`
elif isinstance(i.owner.op, DimShuffle):
old_alloc = i.owner.inputs[0]
new_alloc = old_alloc.owner.inputs[0]
# We need to keep the old `DimShuffle`. It could swap axes or # We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere. # add dimensions anywhere.
r_i = i.owner.op(alloc_input) if new_alloc.ndim != old_alloc.ndim:
copy_stack_trace(i, r_i) # The `Alloc` can add dimensions to the value.
new_i.append(r_i) # We replace those cases with a `DimShuffle` here.
nb_dim_to_add = old_alloc.ndim - new_alloc.ndim
new_alloc = new_alloc.dimshuffle(
["x"] * nb_dim_to_add + list(range(new_alloc.ndim))
)
new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape)
else: copy_stack_trace(i, new_alloc)
new_i.append(i) new_inputs[idx] = new_alloc
new_i[assert_op_idx] = assert_op_in
# If this assert is triggered, it means we are recreating an equivalent graph # If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization. # which would result in a cyclical merge optimization.
if all(new is old for new, old in zip(new_i, node.inputs)): if all(new is old for new, old in zip(new_inputs, node.inputs)):
return return
ret = node.op(*new_i, return_list=True) ret = node.op(*new_inputs, return_list=True)
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
return ret return ret
......
...@@ -121,6 +121,7 @@ from aesara.tensor.type import ( ...@@ -121,6 +121,7 @@ from aesara.tensor.type import (
lvector, lvector,
matrices, matrices,
matrix, matrix,
row,
scalar, scalar,
scalars, scalars,
tensor, tensor,
...@@ -3569,9 +3570,65 @@ def test_Shape_i_canonicalize(): ...@@ -3569,9 +3570,65 @@ def test_Shape_i_canonicalize():
assert y_opt.owner.inputs[0] == x assert y_opt.owner.inputs[0] == x
@pytest.mark.parametrize( class TestLocalElemwiseAlloc:
"""
TODO FIXME: Remove redundant tests.
"""
dtype = config.floatX
def setup_method(self):
self.fast_compile_mode = get_mode("FAST_COMPILE")
self.fast_run_mode = get_mode("FAST_RUN")
self.vec = vector("vec", dtype=self.dtype)
self.mat = matrix("mat", dtype=self.dtype)
self.tens = tensor3("tens", dtype=self.dtype)
self.alloc_wo_dep = at.alloc(self.vec, 2, 2)
self.alloc_wo_dep_broad = at.alloc(self.vec, 1, 2)
self.alloc_w_dep = at.alloc(self.vec, *self.mat.shape)
self.alloc_w_dep_broad = at.alloc(self.vec, 1, *self.mat.shape)
self.alloc_w_dep_broad2 = at.alloc(
self.vec, self.mat.shape[0], self.mat.shape[1], 1
)
self.alloc_w_dep_tens = at.alloc(
self.vec, self.tens.shape[0], self.tens.shape[1]
)
self.tv_wo_dep = at.alloc(self.vec, 5, 5)
self.tm_wo_dep = at.alloc(self.mat, 5, 5, 5)
self.s = iscalar("s")
self.tv_w_dep = at.alloc(self.vec, self.s, self.s)
self.tm_w_dep = at.alloc(self.mat, 5, 5, 5)
self.row = row(dtype=self.dtype)
self.o = at.alloc(self.row, 5, 5)
@staticmethod
def verify_op_count(f, count, cls):
assert (
sum(
isinstance(elem.op, cls)
for elem in f.maker.fgraph.toposort()
if elem.op is not None
)
== count
)
@pytest.mark.parametrize(
"expr, x_shape, y_shape", "expr, x_shape, y_shape",
[ [
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)),
(lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)),
(
lambda x, y: at.mul(
at.alloc(x, 3).dimshuffle("x", 0), y.dimshuffle("x", "x")
),
(),
(),
),
pytest.param( pytest.param(
lambda x, y: at.mul(y, at.alloc(1, x)), lambda x, y: at.mul(y, at.alloc(1, x)),
(), (),
...@@ -3580,9 +3637,21 @@ def test_Shape_i_canonicalize(): ...@@ -3580,9 +3637,21 @@ def test_Shape_i_canonicalize():
), ),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1)), (
(lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2)), lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15)), (15, 1),
(15, 1),
),
(
lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)),
(15, 2),
(15, 2),
),
(
lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y),
(15, 2),
(2, 15),
),
(lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)), (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)),
( (
lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y), lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y),
...@@ -3590,10 +3659,10 @@ def test_Shape_i_canonicalize(): ...@@ -3590,10 +3659,10 @@ def test_Shape_i_canonicalize():
(2, 15), (2, 15),
), ),
], ],
) )
def test_local_elemwise_alloc(expr, x_shape, y_shape): def test_basic(self, expr, x_shape, y_shape):
x = at.tensor("int64", (False,) * len(x_shape)) x = at.tensor("int64", (False,) * len(x_shape), name="x")
y = at.tensor("int64", (False,) * len(y_shape)) y = at.tensor("int64", (False,) * len(y_shape), name="y")
z = expr(x, y) z = expr(x, y)
z_opt = aesara.function( z_opt = aesara.function(
...@@ -3603,7 +3672,9 @@ def test_local_elemwise_alloc(expr, x_shape, y_shape): ...@@ -3603,7 +3672,9 @@ def test_local_elemwise_alloc(expr, x_shape, y_shape):
on_unused_input="ignore", on_unused_input="ignore",
) )
assert not any(isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort()) assert not any(
isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort()
)
z_no_opt = aesara.function( z_no_opt = aesara.function(
[x, y], [x, y],
...@@ -3619,9 +3690,8 @@ def test_local_elemwise_alloc(expr, x_shape, y_shape): ...@@ -3619,9 +3690,8 @@ def test_local_elemwise_alloc(expr, x_shape, y_shape):
exp_res = z_no_opt(x_val, y_val) exp_res = z_no_opt(x_val, y_val)
assert np.array_equal(res, exp_res) assert np.array_equal(res, exp_res)
def test_single_input(self):
def test_local_elemwise_alloc_single_input(): """Test that rewrite is not triggered when there is only one `Alloc` in an `Elemwise`."""
# Test that rewrite is not triggered when there is only one Alloc in an Elemwise
x = at.matrix("x") x = at.matrix("x")
z = at.exp(at.alloc(x, 15, 1)) z = at.exp(at.alloc(x, 15, 1))
...@@ -3629,3 +3699,166 @@ def test_local_elemwise_alloc_single_input(): ...@@ -3629,3 +3699,166 @@ def test_local_elemwise_alloc_single_input():
z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"]) z_opt_fg = optimize_graph(z_fg, clone=False, include=["local_elemwise_alloc"])
assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes) assert any(isinstance(node.op, Alloc) for node in z_opt_fg.apply_nodes)
def test_remove_alloc_wo_dimshuffle(self):
# Exclude local_useless_alloc, since it does not introduce
# assert in all the same cases.
self.fast_run_mode = self.fast_run_mode.excluding(
"local_useless_alloc", "local_alloc_sink_dimshuffle"
)
# No optimization on alloc
func = function(
[self.vec, self.mat],
self.alloc_wo_dep + self.mat,
mode=self.fast_compile_mode,
)
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(func, ops_to_check="all")
# Optimization on alloc with assert
func = function(
[self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
# Optimization on alloc with assert and broadcast
func = function(
[self.vec, self.mat],
self.alloc_wo_dep_broad + self.mat,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
# No optimization on alloc without assert
func = function(
[self.vec, self.mat],
self.alloc_w_dep + self.mat,
mode=self.fast_compile_mode,
)
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on alloc without assert
func = function(
[self.vec, self.mat], self.alloc_w_dep + self.mat, mode=self.fast_run_mode
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on alloc without assert and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad + self.mat,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 0, Assert)
# This was previously not optimized, but it is now that we
# have `BroadcastTo`.
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad2 + self.mat,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
def test_remove_alloc_w_dimshuffle(self):
# No optimization on dimshuffle with assert
func = function(
[self.vec, self.tens],
self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_compile_mode,
)
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle with assert
# TODO FIXME: The `BroadcastTo` shapes should use the constants
# provided by the first/`Alloc` term, and not the unknown values from
# the `tens` term.
func = function(
[self.vec, self.tens],
self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
# No optimization on dimshuffle without assert
func = function(
[self.vec, self.tens],
self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_compile_mode,
)
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle without assert
func = function(
[self.vec, self.tens],
self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 0, Assert)
def test_multi_input_single_alloc(self):
# No optimization on dimshuffle with assert
func = function(
[self.vec, self.mat],
self.tv_wo_dep + self.tm_wo_dep,
mode=self.fast_compile_mode,
)
self.verify_op_count(func, 2, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle with assert
# TODO: When we support static shape constraints like `shape[i] != 1`,
# reproduce this with such a constraint on `mat` and make sure the
# `BroadcastTo` is removed.
func = function(
[self.vec, self.mat],
self.tv_wo_dep + self.tm_wo_dep,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 0, Assert)
# No optimization on dimshuffle without assert
func = function(
[self.vec, self.mat, self.s],
self.tv_w_dep + self.tm_w_dep,
mode=self.fast_compile_mode,
)
self.verify_op_count(func, 2, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle without assert
func = function(
[self.vec, self.mat, self.s],
self.tv_w_dep + self.tm_w_dep,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
def test_misc(self):
x = row(dtype=self.dtype)
y = tensor(dtype=self.dtype, shape=(False, False, True))
out = at.alloc(x, 5, 5).dimshuffle(0, 1, "x") + y
func = function([y, x], out, mode=self.fast_run_mode)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
y_val = np.random.random((5, 5, 1)).astype(self.dtype)
x_val = np.random.random((1, 5)).astype(self.dtype)
exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val
assert np.array_equal(func(y_val, x_val), exp_res)
...@@ -1489,7 +1489,7 @@ class TestLocalAdvSub1AdvIncSub1: ...@@ -1489,7 +1489,7 @@ class TestLocalAdvSub1AdvIncSub1:
assert check_stack_trace(f, ops_to_check=(Assert, aes.Cast)) assert check_stack_trace(f, ops_to_check=(Assert, aes.Cast))
class TestAllocZero: class TestSubtensorAllocRewrites:
def setup_method(self): def setup_method(self):
mode = get_default_mode() mode = get_default_mode()
self.mode = mode.including( self.mode = mode.including(
...@@ -1783,207 +1783,6 @@ def test_local_set_to_inc_subtensor(): ...@@ -1783,207 +1783,6 @@ def test_local_set_to_inc_subtensor():
assert check_stack_trace(f2, ops_to_check="all") assert check_stack_trace(f2, ops_to_check="all")
class TestLocalElemwiseAlloc:
dtype = config.floatX
def setup_method(self):
self.fast_compile_mode = get_mode("FAST_COMPILE")
self.fast_run_mode = get_mode("FAST_RUN")
self.vec = vector("vec", dtype=self.dtype)
self.mat = matrix("mat", dtype=self.dtype)
self.tens = tensor3("tens", dtype=self.dtype)
self.alloc_wo_dep = at.alloc(self.vec, 2, 2)
self.alloc_wo_dep_broad = at.alloc(self.vec, 1, 2)
self.alloc_w_dep = at.alloc(self.vec, *self.mat.shape)
self.alloc_w_dep_broad = at.alloc(self.vec, 1, *self.mat.shape)
self.alloc_w_dep_broad2 = at.alloc(
self.vec, self.mat.shape[0], self.mat.shape[1], 1
)
self.alloc_w_dep_tens = at.alloc(
self.vec, self.tens.shape[0], self.tens.shape[1]
)
self.tv_wo_dep = at.alloc(self.vec, 5, 5)
self.tm_wo_dep = at.alloc(self.mat, 5, 5, 5)
self.s = iscalar("s")
self.tv_w_dep = at.alloc(self.vec, self.s, self.s)
self.tm_w_dep = at.alloc(self.mat, 5, 5, 5)
self.row = row(dtype=self.dtype)
self.o = at.alloc(self.row, 5, 5)
def _verify_alloc_count(self, f, count):
assert (
sum(
isinstance(elem.op, Alloc)
for elem in f.maker.fgraph.toposort()
if elem.op is not None
)
== count
)
def _verify_assert_count(self, f, count):
assert (
sum(
isinstance(elem.op, Assert)
for elem in f.maker.fgraph.toposort()
if elem.op is not None
)
== count
)
def test_remove_alloc_wo_dimshuffle(self):
# Exclude local_useless_alloc, since it does not introduce
# assert in all the same cases.
self.fast_run_mode = self.fast_run_mode.excluding(
"local_useless_alloc", "local_alloc_sink_dimshuffle"
)
# No optimization on alloc
func = function(
[self.vec, self.mat],
self.alloc_wo_dep + self.mat,
mode=self.fast_compile_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(func, ops_to_check="all")
# Optimization on alloc with assert
func = function(
[self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# Optimization on alloc with assert and broadcast
func = function(
[self.vec, self.mat],
self.alloc_wo_dep_broad + self.mat,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# No optimization on alloc without assert
func = function(
[self.vec, self.mat],
self.alloc_w_dep + self.mat,
mode=self.fast_compile_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
# Optimization on alloc without assert
func = function(
[self.vec, self.mat], self.alloc_w_dep + self.mat, mode=self.fast_run_mode
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
# Optimization on alloc without assert and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad + self.mat,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
# Not optimized case on alloc and with broadcast
func = function(
[self.vec, self.mat],
self.alloc_w_dep_broad2 + self.mat,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
def test_remove_alloc_w_dimshuffle(self):
# No optimization on dimshuffle with assert
func = function(
[self.vec, self.tens],
self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_compile_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
# Optimization on dimshuffle with assert
func = function(
[self.vec, self.tens],
self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
# No optimization on dimshuffle without assert
func = function(
[self.vec, self.tens],
self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_compile_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
# Optimization on dimshuffle without assert
func = function(
[self.vec, self.tens],
self.alloc_w_dep_tens.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 0)
def test_multi_input_single_alloc(self):
# No optimization on dimshuffle with assert
func = function(
[self.vec, self.mat],
self.tv_wo_dep + self.tm_wo_dep,
mode=self.fast_compile_mode,
)
self._verify_alloc_count(func, 2)
self._verify_assert_count(func, 0)
# Optimization on dimshuffle with assert
func = function(
[self.vec, self.mat],
self.tv_wo_dep + self.tm_wo_dep,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 0)
# No optimization on dimshuffle without assert
func = function(
[self.vec, self.mat, self.s],
self.tv_w_dep + self.tm_w_dep,
mode=self.fast_compile_mode,
)
self._verify_alloc_count(func, 2)
self._verify_assert_count(func, 0)
# Optimization on dimshuffle without assert
func = function(
[self.vec, self.mat, self.s],
self.tv_w_dep + self.tm_w_dep,
mode=self.fast_run_mode,
)
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 1)
def test_error(self):
t3fft = tensor(dtype=self.dtype, shape=(False, False, True))
o = self.o.dimshuffle(0, 1, "x") + t3fft
func = function([t3fft, self.row], o, mode=self.fast_run_mode)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
d = np.random.random((5, 5, 1)).astype(self.dtype)
r = np.random.random((1, 5)).astype(self.dtype)
func(d, r)
def test_local_subtensor_of_alloc(): def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong. # DebugMode should detect if something goes wrong.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论