提交 f4004b27 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a rewrite that changes constant indices to unsigned integers

上级 0845fa48
......@@ -1749,3 +1749,94 @@ def local_join_subtensors(fgraph, node):
return [concatenate(new_joined_tensors, axis=axis)]
else:
return [merged_subtensors]
@register_specialize
@node_rewriter(
[
Subtensor,
AdvancedSubtensor1,
AdvancedSubtensor,
IncSubtensor,
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
]
)
def local_uint_constant_indices(fgraph, node):
"""Convert constant indices to unsigned dtypes."""
if isinstance(node.op, (IncSubtensor, AdvancedIncSubtensor, AdvancedIncSubtensor1)):
x, y, *indices = node.inputs
else:
x, *indices = node.inputs
y = None
idx_list = getattr(node.op, "idx_list", None)
new_indices = list(indices_from_subtensor(indices, idx_list))
has_new_index = False
for i, index in enumerate(new_indices):
if not isinstance(index, Constant):
continue
index_val = index.data
if index_val is None or isinstance(index_val, slice):
# TODO: If slice index dtypes matter, we can consider converting
# those, as well.
continue
assert isinstance(index_val, (np.generic, np.ndarray))
if index_val.size == 0:
continue
if index_val.dtype == bool:
continue
if np.ndim(index_val) > 0:
minval = index_val.min()
else:
minval = index_val
if minval >= 0:
maxval = index_val.max()
dtype = np.min_scalar_type(maxval)
else:
# If we can't convert to unsigned, then don't attempt to minimize
# the type size either--at least not for now.
# dtype = np.min_scalar_type(-max(-minval, maxval))
continue
if dtype == index_val.dtype:
continue
if index_val.ndim > 0:
new_index = aesara.tensor.as_tensor_variable(
index_val.astype(dtype), dtype=dtype
)
else:
new_index = aes.constant(index_val.astype(dtype), dtype=dtype)
new_indices[i] = new_index
has_new_index = True
if not has_new_index:
return False
new_out = x[tuple(new_indices)]
if y is not None:
new_out = inc_subtensor(
new_out,
y,
inplace=node.op.inplace,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=getattr(node.op, "ignore_duplicates", False),
)
new_outs = new_out.owner.outputs
copy_stack_trace(node.outputs, new_outs)
return new_outs
......@@ -618,7 +618,7 @@ def test_debugprint_compiled_fn():
forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J]
> |Subtensor{int64, int64, int64} [id K]
> |Subtensor{int64, int64, uint8} [id K]
> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |ScalarFromTensor [id M]
> | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
......
......@@ -31,6 +31,7 @@ from aesara.tensor.subtensor import (
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
advanced_inc_subtensor1,
inc_subtensor,
set_subtensor,
)
......@@ -52,7 +53,7 @@ from aesara.tensor.type import (
tensor4,
vector,
)
from aesara.tensor.type_other import slicetype
from aesara.tensor.type_other import make_slice, slicetype
from tests import unittest_tools as utt
from tests.unittest_tools import create_aesara_param
......@@ -2160,3 +2161,146 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.subtensor_opt import get_advsubtensor_axis # noqa: F401 F811
def test_local_uint_constant_indices():
mode = get_default_mode().including("specialize", "local_uint_constant_indices")
rng = np.random.default_rng(20900)
# Subtensor, don't convert
x = at.vector("x")
idx = at.as_tensor_variable(np.array(-1, np.int64))
z = x[idx]
z_fn = aesara.function([x], z, mode=mode)
deepcopy_node = z_fn.maker.fgraph.outputs[0].owner
subtensor_node = deepcopy_node.inputs[0].owner
assert isinstance(subtensor_node.op, Subtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "int64"
# `Subtensor`, one index, convert
x = at.vector("x")
idx = at.as_tensor_variable(np.array(1, np.int64))
z = x[idx]
z_fn = aesara.function([x], z, mode=mode)
deepcopy_node = z_fn.maker.fgraph.outputs[0].owner
subtensor_node = deepcopy_node.inputs[0].owner
assert isinstance(subtensor_node.op, Subtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `Subtensor`, two indices, one slice, convert
x = at.matrix("x")
indices = (at.as_tensor_variable(np.array(1, np.int64)), slice(None, 10))
z = x[indices]
z_fn = aesara.function([x], z, mode=mode)
deepcopy_node = z_fn.maker.fgraph.outputs[0].owner
subtensor_node = deepcopy_node.inputs[0].owner
assert isinstance(subtensor_node.op, Subtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor`, two indices, one symbolic slice, convert
x = at.matrix("x")
indices = (
at.as_tensor_variable(np.array(1, np.int64)),
make_slice(slice(None, 10)),
)
z = x[indices]
z_fn = aesara.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor1`, convert
x = at.vector("x")
idx = at.as_tensor_variable(rng.integers(0, 10, size=10).astype(np.int64))
z = x[idx]
z_fn = aesara.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor1)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# AdvancedSubtensor, empty, convert
x = at.matrix("x")
idx = at.as_tensor_variable(1, dtype=np.int64)
z = x[idx, []]
z_fn = aesara.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# AdvancedSubtensor, bool, don't convert
x = at.matrix("x")
idx = at.as_tensor_variable(np.array([True]), dtype=bool)
z = x[idx, []]
z_fn = aesara.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "bool"
# `IncSubtensor`, convert
x = at.vector("x")
y = at.scalar("y")
idx = at.as_tensor_variable(1, dtype=np.int64)
z = inc_subtensor(x[idx], y)
z_fn = aesara.function([x, y], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, IncSubtensor)
new_index = subtensor_node.inputs[2]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `AdvancedIncSubtensor1`, convert
x = at.vector("x")
y = at.vector("y")
idx = at.as_tensor_variable(rng.integers(0, 10, size=10).astype(np.int64))
z = advanced_inc_subtensor1(x, y, idx)
z_fn = aesara.function([x, y], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedIncSubtensor1)
new_index = subtensor_node.inputs[2]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `AdvancedIncSubtensor1`, convert
x = at.vector("x")
idx = at.as_tensor_variable(rng.integers(0, 10, size=10).astype(np.int64))
z = x[idx, None]
z_fn = aesara.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor)
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论