Commit c9a6f69e authored by Brandon T. Willard, committed by Brandon T. Willard

Implement basic rewrites for Unique

Parent b48b803d
@@ -71,7 +71,7 @@ from aesara.tensor.basic import (
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
-from aesara.tensor.extra_ops import broadcast_shape
+from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
@@ -3495,3 +3495,160 @@ def local_Shape_i_of_broadcastable(fgraph, node):
    if shape_arg.broadcastable[node.op.i]:
        return [as_tensor_variable(1, dtype=np.int64)]

@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_scalar(fgraph, node):
    """Convert ``unique(x)`` to ``x`` when ``x`` is a scalar."""
    if not isinstance(node.op, Unique):
        return False

    if node.op.return_index or node.op.return_inverse or node.op.return_counts:
        return False

    uniqued_var = node.inputs[0]

    if uniqued_var.ndim != 0:
        return False

    old_out = node.outputs[0]
    res = as_tensor_variable(uniqued_var, ndim=old_out.ndim, dtype=old_out.dtype)
    return [res]
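
A minimal sketch of what this rewrite does, driving `unique` and `optimize_graph` the same way the tests below do (the `aesara.graph.fg` and `aesara.graph.opt_utils` import paths here are assumptions):

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.extra_ops import unique
from aesara.tensor.type import dscalar

x = dscalar("x")
# `unique` of a scalar can only ever return that scalar.
fg = FunctionGraph(outputs=[unique(x)], copy_inputs=False)
fg = optimize_graph(fg, clone=False, include=["canonicalize"])
# After rewriting, no `Unique` node remains: the output is `x`, reshaped back
# to the one-element vector type that `unique` originally returned.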

@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    alloc_var = node.inputs[0]

    if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)):
        return False

    alloced_var, *alloc_shape = alloc_var.owner.inputs

    new_unique, *_ = node.op.make_node(alloced_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
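
Sketched the same way (same assumed import paths as the previous snippet): an `Alloc` only broadcasts its input, so it cannot introduce values that are not already in `x`.

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic import alloc
from aesara.tensor.extra_ops import unique
from aesara.tensor.type import lvector

x = lvector("x")
# `alloc(x, 2, 3)` broadcasts the length-3 vector `x` into a (2, 3) tensor.
fg = FunctionGraph(outputs=[unique(alloc(x, 2, 3), axis=None)], copy_inputs=False)
fg = optimize_graph(fg, clone=False, include=["canonicalize"])
# The `Alloc` is consumed: the rewritten graph computes `unique(x)` directly.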

@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    bcast_var = node.inputs[0]

    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
        return False

    bcasted_var, *bcast_shape = bcast_var.owner.inputs

    new_unique, *_ = node.op.make_node(bcasted_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
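
The same sketch for `BroadcastTo` (assumed import paths as above), since broadcasting repeats values without adding new ones:

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.extra_ops import broadcast_to, unique
from aesara.tensor.type import lvector

x = lvector("x")
fg = FunctionGraph(outputs=[unique(broadcast_to(x, (2, 3)))], copy_inputs=False)
fg = optimize_graph(fg, clone=False, include=["canonicalize"])
# The `BroadcastTo` is gone; `unique` is applied to `x` itself.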

@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    repeat_var = node.inputs[0]

    if not (repeat_var.owner and isinstance(repeat_var.owner.op, Repeat)):
        return False

    repeated_var, *repeat_shape = repeat_var.owner.inputs

    new_unique, *_ = node.op.make_node(repeated_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
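
Likewise for `Repeat` (assumed import paths as above): repeating elements never changes the set of distinct values.

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.extra_ops import repeat, unique
from aesara.tensor.type import lvector

x = lvector("x")
# `repeat(x, 2)` duplicates each element of `x`.
fg = FunctionGraph(outputs=[unique(repeat(x, 2))], copy_inputs=False)
fg = optimize_graph(fg, clone=False, include=["canonicalize"])
# The `Repeat` is consumed and `unique` operates on `x` directly.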

@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(a, x), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    if not isinstance(node.op, Unique):
        return False

    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    second_var = node.inputs[0]

    if not (
        second_var.owner
        and isinstance(second_var.owner.op, Elemwise)
        and isinstance(second_var.owner.op.scalar_op, aes.Second)
    ):
        return False

    shape_var, seconded_var = second_var.owner.inputs

    new_unique, *_ = node.op.make_node(seconded_var).outputs

    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
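
`second(a, x)` is the `Elemwise` form of `aes.Second` (i.e. `fill`): it returns `x` broadcast to `a`'s shape, so its distinct values are exactly those of `x`. A sketch under the same import-path assumptions as above:

from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic import second
from aesara.tensor.extra_ops import unique
from aesara.tensor.type import lmatrix, lvector

a = lmatrix("a")
x = lvector("x")
# `second(a, x)` broadcasts `x` to the shape of `a`.
fg = FunctionGraph(outputs=[unique(second(a, x))], copy_inputs=False)
fg = optimize_graph(fg, clone=False, include=["canonicalize"])
# The `Second` node is dropped; only `unique(x)` (plus any shape-restoring
# `DimShuffle`/`Rebroadcast`) remains.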

@@ -11,7 +11,7 @@ from aesara.assert_op import Assert
from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function
-from aesara.compile.mode import Mode, get_default_mode, get_mode
+from aesara.compile.mode import OPT_NONE, Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable
@@ -30,8 +30,10 @@ from aesara.tensor.basic import (
    ScalarFromTensor,
    Split,
    TensorFromScalar,
    alloc,
    as_tensor_variable,
    join,
    second,
    tile,
)
from aesara.tensor.basic_opt import (
@@ -49,6 +51,14 @@ from aesara.tensor.basic_opt import (
    register_specialize,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
    BroadcastTo,
    Repeat,
    Unique,
    broadcast_to,
    repeat,
    unique,
)
from aesara.tensor.math import (
    add,
    bitwise_and,
@@ -3293,3 +3303,302 @@ def test_apply_rebroadcast_opt():
    res = apply_rebroadcast_opt(rval)
    assert res is rval

@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_scalar(return_index, return_counts, return_inverse):
    x = dscalar()
    y = unique(
        x,
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=None,
    )

    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt

    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]

    assert isinstance(y_opt_start.owner.op, DimShuffle)
    assert y_opt_start.owner.inputs[0] == x

    default_mode = get_default_mode()
    opt_mode = default_mode.excluding("local_Unique_scalar")
    y_fn = function([x], [y, y_opt], mode=opt_mode)

    x_val = np.array(-10.0, dtype=np.float64)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)

@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Alloc_lift(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    x = as_tensor_variable(x_val).type()
    y = unique(
        alloc(x, *new_shape),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Alloc_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt

    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]

    assert isinstance(y_opt_start.owner.op, Unique)
    assert y_opt_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Alloc) for node in y_opt_fg.apply_nodes)

    default_mode = get_default_mode()
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    # The remaining exclusions simply allow us to perform the check below that
    # makes sure the original `Alloc` is present in our reference (sub)graph.
    opt_mode = default_mode.excluding(
        "local_useless_alloc", "local_canonicalize_alloc", "local_Unique_Alloc_lift"
    )
    y_fn = function([x], [y, y_opt], mode=opt_mode)

    # Make sure that the original `Alloc` is used to compute the reference `y`
    # result
    assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)

@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    x = as_tensor_variable(x_val).type()
    y = unique(
        broadcast_to(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt

    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]

    assert isinstance(y_opt_start.owner.op, Unique)
    assert y_opt_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, BroadcastTo) for node in y_opt_fg.apply_nodes)

    default_mode = get_default_mode()
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    opt_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
    y_fn = function([x], [y, y_opt], mode=opt_mode)

    # Make sure that the original `BroadcastTo` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
    )

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)

@pytest.mark.parametrize(
    "x_val, unique_axis, repeats, repeat_axis",
    [
        (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Repeat(
    x_val,
    unique_axis,
    repeats,
    repeat_axis,
    return_index,
    return_counts,
    return_inverse,
):
    x = as_tensor_variable(x_val).type()
    y = unique(
        repeat(x, tuple(repeats), axis=repeat_axis),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Repeat_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt

    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]

    assert isinstance(y_opt_start.owner.op, Unique)
    assert y_opt_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Repeat) for node in y_opt_fg.apply_nodes)

    default_mode = get_default_mode()
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    opt_mode = default_mode.excluding("local_Unique_Repeat_lift")
    y_fn = function([x], [y, y_opt], mode=opt_mode)

    # Make sure that the original `Repeat` is used to compute the reference
    # `y` result
    assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes)

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)

@pytest.mark.parametrize(
    "x_val, unique_axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_second(
    x_val, unique_axis, new_shape, return_index, return_counts, return_inverse
):
    x = as_tensor_variable(x_val).type()
    a = np.zeros(tuple(new_shape), dtype=x.dtype)
    y = unique(
        second(a, x),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )

    if isinstance(y, list):
        y, *_ = y

    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_second"],
        exclude=["local_Unique_scalar", "topo_constant_folding"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt

    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if y_opt.owner and isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]

    assert isinstance(y_opt_start.owner.op, Unique)

    y_opt_start = y_opt_start.owner.inputs[0]

    if y_opt_start.owner and isinstance(y_opt_start.owner.op, DimShuffle):
        y_opt_start = y_opt_start.owner.inputs[0]

    assert y_opt_start == x
    assert not any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_opt_fg.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    y_fn = function([x], [y, y_opt], mode=Mode(optimizer=OPT_NONE))

    # Make sure that the original `Second` is used to compute the reference
    # `y` result
    assert any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_fn.maker.fgraph.apply_nodes
        if isinstance(node.op, Elemwise)
    )

    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)