Commit c9a6f69e authored by Brandon T. Willard, committed by Brandon T. Willard

Implement basic rewrites for Unique

Parent commit: b48b803d
...@@ -71,7 +71,7 @@ from aesara.tensor.basic import ( ...@@ -71,7 +71,7 @@ from aesara.tensor.basic import (
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import broadcast_shape from aesara.tensor.extra_ops import BroadcastTo, Repeat, Unique, broadcast_shape
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
...@@ -3495,3 +3495,160 @@ def local_Shape_i_of_broadcastable(fgraph, node): ...@@ -3495,3 +3495,160 @@ def local_Shape_i_of_broadcastable(fgraph, node):
if shape_arg.broadcastable[node.op.i]: if shape_arg.broadcastable[node.op.i]:
return [as_tensor_variable(1, dtype=np.int64)] return [as_tensor_variable(1, dtype=np.int64)]
@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_scalar(fgraph, node):
    """Replace ``unique(x)`` with ``x`` itself when ``x`` is a scalar.

    A zero-dimensional input is trivially unique, so the `Unique` `Op` can be
    dropped — but only when none of the auxiliary outputs were requested.
    """
    op = node.op
    if not isinstance(op, Unique):
        return False

    # Any of the extra outputs makes the rewrite invalid.
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    x = node.inputs[0]
    if x.ndim != 0:
        return False

    # Coerce the scalar back to the exact type of the original output.
    old_out = node.outputs[0]
    return [as_tensor_variable(x, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_Alloc_lift(fgraph, node):
    """Convert ``unique(alloc(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption": an `Alloc`
    only repeats the values of ``x``, so it cannot change the set of unique
    values when ``axis=None``.
    """
    if not isinstance(node.op, Unique):
        return False

    # The rewrite is only valid for the plain (single-output, flattened)
    # form of `Unique`.
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    alloc_var = node.inputs[0]
    if not (alloc_var.owner and isinstance(alloc_var.owner.op, Alloc)):
        return False

    # Apply `Unique` directly to the allocated value; the shape arguments of
    # the `Alloc` are irrelevant here, so don't bother unpacking them.
    alloced_var = alloc_var.owner.inputs[0]
    new_unique, *_ = node.op.make_node(alloced_var).outputs

    # Coerce the result back to the exact type of the original output.
    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_BroadcastTo_lift(fgraph, node):
    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False

    # Only the plain, flattened, single-output form of `Unique` qualifies.
    if op.axis is not None:
        return False
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    bcast_var = node.inputs[0]
    bcast_apply = bcast_var.owner
    if bcast_apply is None or not isinstance(bcast_apply.op, BroadcastTo):
        return False

    # Broadcasting only repeats values, so `Unique` can consume its input.
    bcasted_var = bcast_apply.inputs[0]
    new_unique = op.make_node(bcasted_var).outputs[0]

    # Match the type of the original output exactly.
    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_Repeat_lift(fgraph, node):
    """Convert ``unique(repeat(x, ...), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption".
    """
    op = node.op
    if not isinstance(op, Unique):
        return False

    # Only the plain, flattened, single-output form of `Unique` qualifies.
    if op.axis is not None:
        return False
    if op.return_index or op.return_inverse or op.return_counts:
        return False

    repeat_var = node.inputs[0]
    repeat_apply = repeat_var.owner
    if repeat_apply is None or not isinstance(repeat_apply.op, Repeat):
        return False

    # Repetition cannot introduce new values, so apply `Unique` directly to
    # the repeated input.
    repeated_var = repeat_apply.inputs[0]
    new_unique = op.make_node(repeated_var).outputs[0]

    # Match the type of the original output exactly.
    old_out = node.outputs[0]
    return [as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)]
@register_useless
@register_canonicalize
@local_optimizer([Unique])
def local_Unique_second(fgraph, node):
    """Convert ``unique(second(a, x), axis=None)`` to ``unique(x, axis=None)``.

    This isn't really so much a lift as a "reduction/consumption": `Second`
    only broadcasts the values of its second argument, so it cannot change
    the set of unique values when ``axis=None``.
    """
    if not isinstance(node.op, Unique):
        return False

    # The rewrite is only valid for the plain (single-output, flattened)
    # form of `Unique`.
    if (
        node.op.return_index
        or node.op.return_inverse
        or node.op.return_counts
        or node.op.axis is not None
    ):
        return False

    second_var = node.inputs[0]
    if not (
        second_var.owner
        and isinstance(second_var.owner.op, Elemwise)
        and isinstance(second_var.owner.op.scalar_op, aes.Second)
    ):
        return False

    # `second(a, x)` takes its values from its second input; the first input
    # only supplies the shape, so it can be ignored entirely.
    seconded_var = second_var.owner.inputs[1]
    new_unique, *_ = node.op.make_node(seconded_var).outputs

    # Coerce the result back to the exact type of the original output.
    old_out = node.outputs[0]
    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
    return [new_x]
...@@ -11,7 +11,7 @@ from aesara.assert_op import Assert ...@@ -11,7 +11,7 @@ from aesara.assert_op import Assert
from aesara.compile import optdb from aesara.compile import optdb
from aesara.compile.debugmode import DebugMode from aesara.compile.debugmode import DebugMode
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode from aesara.compile.mode import OPT_NONE, Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant, Variable from aesara.graph.basic import Apply, Constant, Variable
...@@ -30,8 +30,10 @@ from aesara.tensor.basic import ( ...@@ -30,8 +30,10 @@ from aesara.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
alloc,
as_tensor_variable, as_tensor_variable,
join, join,
second,
tile, tile,
) )
from aesara.tensor.basic_opt import ( from aesara.tensor.basic_opt import (
...@@ -49,6 +51,14 @@ from aesara.tensor.basic_opt import ( ...@@ -49,6 +51,14 @@ from aesara.tensor.basic_opt import (
register_specialize, register_specialize,
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
BroadcastTo,
Repeat,
Unique,
broadcast_to,
repeat,
unique,
)
from aesara.tensor.math import ( from aesara.tensor.math import (
add, add,
bitwise_and, bitwise_and,
...@@ -3293,3 +3303,302 @@ def test_apply_rebroadcast_opt(): ...@@ -3293,3 +3303,302 @@ def test_apply_rebroadcast_opt():
res = apply_rebroadcast_opt(rval) res = apply_rebroadcast_opt(rval)
assert res is rval assert res is rval
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_scalar(return_index, return_counts, return_inverse):
    """Check that `unique` applied to a scalar is reduced to the scalar itself."""
    x = dscalar()
    y = unique(
        x,
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=None,
    )

    fg = FunctionGraph(outputs=[y], copy_inputs=False)
    opt_fg = optimize_graph(
        fg, clone=False, include=["canonicalize", "local_Unique_scalar"]
    )
    y_opt = opt_fg.outputs[0]

    # Strip an initial `Rebroadcast` (it only makes the replacement match
    # the original type).
    opt_res = y_opt
    if isinstance(opt_res.owner.op, Rebroadcast):
        opt_res = opt_res.owner.inputs[0]

    assert isinstance(opt_res.owner.op, DimShuffle)
    assert opt_res.owner.inputs[0] == x

    # Compare against a reference computed without this rewrite.
    ref_mode = get_default_mode().excluding("local_Unique_scalar")
    fn = function([x], [y, y_opt], mode=ref_mode)

    x_val = np.array(-10.0, dtype=np.float64)
    y_exp_val, y_val = fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Alloc_lift(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """Make sure ``unique(alloc(x, ...), axis=None)`` is rewritten to ``unique(x, axis=None)``."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        alloc(x, *new_shape),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )
    # `unique` returns a list when any of the extra outputs are requested.
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Alloc_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt
    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]
    # The rewrite should have consumed the `Alloc`, leaving `Unique` applied
    # directly to `x`.
    assert isinstance(y_opt_start.owner.op, Unique)
    assert y_opt_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Alloc) for node in y_opt_fg.apply_nodes)
    default_mode = get_default_mode()
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    # The remaining exclusions simply allow us to perform the check below that
    # makes sure the original `Alloc` is present in our reference (sub)graph.
    opt_mode = default_mode.excluding(
        "local_useless_alloc", "local_canonicalize_alloc", "local_Unique_Alloc_lift"
    )
    y_fn = function([x], [y, y_opt], mode=opt_mode)
    # Make sure that the original `Alloc` is used to compute the reference `y`
    # result
    assert any(isinstance(node.op, Alloc) for node in y_fn.maker.fgraph.apply_nodes)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_BroadcastTo(
    x_val, axis, new_shape, return_index, return_counts, return_inverse
):
    """Make sure ``unique(broadcast_to(x, ...), axis=None)`` is rewritten to ``unique(x, axis=None)``."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        broadcast_to(x, tuple(new_shape)),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=axis,
    )
    # `unique` returns a list when any of the extra outputs are requested.
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt
    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]
    # The rewrite should have consumed the `BroadcastTo`, leaving `Unique`
    # applied directly to `x`.
    assert isinstance(y_opt_start.owner.op, Unique)
    assert y_opt_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, BroadcastTo) for node in y_opt_fg.apply_nodes)
    default_mode = get_default_mode()
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    opt_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
    y_fn = function([x], [y, y_opt], mode=opt_mode)
    # Make sure that the original `BroadcastTo` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
    )
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, repeats, repeat_axis",
    [
        (np.array([[-10, -3], [-10, 2]], dtype=np.int64), None, (1, 2), 0),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_Repeat(
    x_val,
    unique_axis,
    repeats,
    repeat_axis,
    return_index,
    return_counts,
    return_inverse,
):
    """Make sure ``unique(repeat(x, ...), axis=None)`` is rewritten to ``unique(x, axis=None)``."""
    x = as_tensor_variable(x_val).type()
    y = unique(
        repeat(x, tuple(repeats), axis=repeat_axis),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )
    # `unique` returns a list when any of the extra outputs are requested.
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        include=["canonicalize", "local_Unique_Repeat_lift"],
        exclude=["local_Unique_scalar"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt
    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]
    # The rewrite should have consumed the `Repeat`, leaving `Unique` applied
    # directly to `x`.
    assert isinstance(y_opt_start.owner.op, Unique)
    assert y_opt_start.owner.inputs[0] == x
    assert not any(isinstance(node.op, Repeat) for node in y_opt_fg.apply_nodes)
    default_mode = get_default_mode()
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    opt_mode = default_mode.excluding("local_Unique_Repeat_lift")
    y_fn = function([x], [y, y_opt], mode=opt_mode)
    # Make sure that the original `Repeat` is used to compute the
    # reference `y` result
    assert any(isinstance(node.op, Repeat) for node in y_fn.maker.fgraph.apply_nodes)
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
@pytest.mark.parametrize(
    "x_val, unique_axis, new_shape",
    [
        (np.array(-10, dtype=np.int64), None, ()),
        (np.array(-10, dtype=np.int64), None, (2, 3)),
        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
    ],
)
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
def test_local_Unique_second(
    x_val, unique_axis, new_shape, return_index, return_counts, return_inverse
):
    """Make sure ``unique(second(a, x), axis=None)`` is rewritten to ``unique(x, axis=None)``."""
    x = as_tensor_variable(x_val).type()
    a = np.zeros(tuple(new_shape), dtype=x.dtype)
    y = unique(
        second(a, x),
        return_index=return_index,
        return_counts=return_counts,
        return_inverse=return_inverse,
        axis=unique_axis,
    )
    # `unique` returns a list when any of the extra outputs are requested.
    if isinstance(y, list):
        y, *_ = y
    # This approach allows us to directly confirm that `x` is in the result.
    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
    y_opt_fg = optimize_graph(
        y_fg,
        clone=False,
        # The rewrite is registered under its function name,
        # `local_Unique_second` (there is no ``_lift`` suffix).
        include=["canonicalize", "local_Unique_second"],
        exclude=["local_Unique_scalar", "topo_constant_folding"],
    )
    y_opt = y_opt_fg.outputs[0]
    y_opt_start = y_opt
    # Ignore any initial `Rebroadcast`s (they serve to
    # make the replacement match the original type)
    if y_opt.owner and isinstance(y_opt.owner.op, Rebroadcast):
        y_opt_start = y_opt.owner.inputs[0]
    assert isinstance(y_opt_start.owner.op, Unique)
    # Walk past an optional `DimShuffle` down to the original `x`.
    y_opt_start = y_opt_start.owner.inputs[0]
    if y_opt_start.owner and isinstance(y_opt_start.owner.op, DimShuffle):
        y_opt_start = y_opt_start.owner.inputs[0]
    assert y_opt_start == x
    assert not any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_opt_fg.apply_nodes
        if isinstance(node.op, Elemwise)
    )
    # The optimization has already been applied to `y_opt`, so we can--and
    # should--exclude it from the compilation of both our reference, `y`, and
    # the optimized result, `y_opt`.
    y_fn = function([x], [y, y_opt], mode=Mode(optimizer=OPT_NONE))
    # Make sure that the original `Second` is used to compute the
    # reference `y` result
    assert any(
        isinstance(node.op.scalar_op, aes.Second)
        for node in y_fn.maker.fgraph.apply_nodes
        if isinstance(node.op, Elemwise)
    )
    y_exp_val, y_val = y_fn(x_val)
    assert np.array_equal(y_exp_val, y_val)
Markdown format is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment