提交 f9a3234e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rewrite Blockwise IncSubtensor

Also cover cases of AdvancedIncSubtensor with batch indices that were not supported before
上级 5046519a
...@@ -24,6 +24,7 @@ from pytensor.tensor.basic import ( ...@@ -24,6 +24,7 @@ from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
alloc, alloc,
arange,
cast, cast,
concatenate, concatenate,
expand_dims, expand_dims,
...@@ -34,9 +35,10 @@ from pytensor.tensor.basic import ( ...@@ -34,9 +35,10 @@ from pytensor.tensor.basic import (
switch, switch,
) )
from pytensor.tensor.basic import constant as tensor_constant from pytensor.tensor.basic import constant as tensor_constant
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import ( from pytensor.tensor.math import (
add, add,
and_, and_,
...@@ -58,6 +60,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -58,6 +60,7 @@ from pytensor.tensor.rewriting.basic import (
) )
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
shape_padleft, shape_padleft,
shape_padright,
shape_tuple, shape_tuple,
) )
from pytensor.tensor.sharedvar import TensorSharedVariable from pytensor.tensor.sharedvar import TensorSharedVariable
...@@ -1578,6 +1581,9 @@ def local_blockwise_of_subtensor(fgraph, node): ...@@ -1578,6 +1581,9 @@ def local_blockwise_of_subtensor(fgraph, node):
"""Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor. """Rewrite Blockwise of Subtensor, where the only batch input is the indexed tensor.
Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none Blockwise(Subtensor{a: b})(x, a, b) -> x[:, a:b] when x has one batch dimension, and a/b none
TODO: Handle batched indices like we do with blockwise of inc_subtensor
TODO: Extend to AdvancedSubtensor
""" """
if not isinstance(node.op.core_op, Subtensor): if not isinstance(node.op.core_op, Subtensor):
return return
...@@ -1598,64 +1604,151 @@ def local_blockwise_of_subtensor(fgraph, node): ...@@ -1598,64 +1604,151 @@ def local_blockwise_of_subtensor(fgraph, node):
@register_stabilize("shape_unsafe") @register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
def local_blockwise_advanced_inc_subtensor(fgraph, node): def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices.""" """Rewrite blockwised inc_subtensors.
if not isinstance(node.op.core_op, AdvancedIncSubtensor):
return None
x, y, *idxs = node.inputs Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites
# It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
if any( and can be safely rewritten without Blockwise.
( """
isinstance(idx, SliceType | NoneTypeT) core_op = node.op.core_op
or (idx.type.dtype == "bool" and idx.type.ndim > 0) if not isinstance(core_op, AdvancedIncSubtensor | IncSubtensor):
)
for idx in idxs
):
return None return None
op: Blockwise = node.op # type: ignore x, y, *idxs = node.inputs
batch_ndim = op.batch_ndim(node) [out] = node.outputs
if isinstance(node.op.core_op, AdvancedIncSubtensor):
new_idxs = [] if any(
for idx in idxs: (
if all(idx.type.broadcastable[:batch_ndim]): # Blockwise requires all inputs to be tensors so it is not possible
new_idxs.append(idx.squeeze(tuple(range(batch_ndim)))) # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
else: # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
# Rewrite does not apply # are separated by basic indices
isinstance(idx, SliceType | NoneTypeT)
# Also get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
or (idx.type.dtype == "bool")
)
for idx in idxs
):
return None return None
x_batch_bcast = x.type.broadcastable[:batch_ndim] batch_ndim = node.op.batch_ndim(node)
y_batch_bcast = y.type.broadcastable[:batch_ndim] idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast, strict=True)): max_idx_core_ndim = max(idxs_core_ndim, default=0)
# Need to broadcast batch x dims
batch_shape = tuple( # Step 1. Broadcast buffer to batch_shape
x_dim if (not xb or yb) else y_dim if x.type.broadcastable != out.type.broadcastable:
for xb, x_dim, yb, y_dim in zip( batch_shape = [1] * batch_ndim
x_batch_bcast, for inp in node.inputs:
for i, (broadcastable, batch_dim) in enumerate(
zip(inp.type.broadcastable[:batch_ndim], tuple(inp.shape)[:batch_ndim])
):
if broadcastable:
# This dimension is broadcastable, it doesn't provide shape information
continue
if batch_shape[i] != 1:
# We already found a source of shape for this batch dimension
continue
batch_shape[i] = batch_dim
x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
assert x.type.broadcastable == out.type.broadcastable
# Step 2. Massage indices so they respect blockwise semantics
if isinstance(core_op, IncSubtensor):
# For basic IncSubtensor there are two cases:
# 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
# 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
# in case we can end up with a basic IncSubtensor again
core_idxs = []
counter = 0
for idx in core_op.idx_list:
if isinstance(idx, slice):
# Squeeze away dummy dimensions so we can convert to slice
new_entries = [None, None, None]
for i, entry in enumerate((idx.start, idx.stop, idx.step)):
if entry is None:
continue
else:
new_entries[i] = new_entry = idxs[counter].squeeze()
counter += 1
if new_entry.ndim > 0:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice
# We could try to convert to equivalent integer indices, but nothing guarantees
# that the slice is "square".
return None
core_idxs.append(slice(*new_entries))
else:
core_idxs.append(_squeeze_left(idxs[counter]))
counter += 1
else:
# For AdvancedIncSubtensor we have tensor integer indices,
# We need to expand batch indexes on the right, so they don't interact with core index dimensions
# We still squeeze on the left in case that allows us to use simpler indices
core_idxs = [
_squeeze_left(
shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
stop_at_dim=batch_ndim,
)
for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
]
# Step 3. Create new indices for the new batch dimension of x
if not all(
all(idx.type.broadcastable[:batch_ndim])
for idx in idxs
if not isinstance(idx, slice)
):
# If the indices have batch dimensions, they will interact with the new batch dimensions of x
# We build vectorized indexing with new arange indices that do not interact with core indices or each other
# (i.e., they broadcast)
# Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
# we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
# even if not all batch dimensions have corresponding batch indices.
batch_slices = [
shape_padright(arange(x_batch_shape, dtype="int64"), n)
for (x_batch_shape, n) in zip(
tuple(x.shape)[:batch_ndim], tuple(x.shape)[:batch_ndim],
y_batch_bcast, reversed(range(max_idx_core_ndim, max_idx_core_ndim + batch_ndim)),
tuple(y.shape)[:batch_ndim],
strict=True,
) )
) ]
core_shape = tuple(x.shape)[batch_ndim:] else:
x = alloc(x, *batch_shape, *core_shape) # In the case we don't have batch indices,
# we can use slice(None) to broadcast the core indices to each new batch dimension of x / y
new_idxs = [slice(None)] * batch_ndim + new_idxs batch_slices = [slice(None)] * batch_ndim
x_view = x[tuple(new_idxs)]
new_idxs = (*batch_slices, *core_idxs)
# We need to introduce any implicit expand_dims on core dimension of y x_view = x[new_idxs]
y_core_ndim = y.type.ndim - batch_ndim
if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0: # Step 4. Introduce any implicit expand_dims on core dimension of y
missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim)) missing_y_core_ndim = x_view.type.ndim - y.type.ndim
y = expand_dims(y, missing_axes) implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
symbolic_idxs = x_view.owner.inputs[1:]
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs if isinstance(core_op, IncSubtensor):
copy_stack_trace(node.outputs, new_out) # Check if we can still use a basic IncSubtensor
return new_out if isinstance(x_view.owner.op, Subtensor):
new_props = core_op._props_dict()
new_props["idx_list"] = x_view.owner.op.idx_list
new_core_op = type(core_op)(**new_props)
symbolic_idxs = x_view.owner.inputs[1:]
new_out = new_core_op(x, y, *symbolic_idxs)
else:
# We need to use AdvancedSet/IncSubtensor
if core_op.set_instead_of_inc:
new_out = x[new_idxs].set(y)
else:
new_out = x[new_idxs].inc(y)
else:
# AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
symbolic_idxs = x_view.owner.inputs[1:]
new_out = core_op(x, y, *symbolic_idxs)
copy_stack_trace(out, new_out)
return [new_out]
@node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor]) @node_rewriter(tracks=[AdvancedSubtensor, AdvancedIncSubtensor])
......
...@@ -1417,7 +1417,6 @@ class SubtensorPrinter(Printer): ...@@ -1417,7 +1417,6 @@ class SubtensorPrinter(Printer):
pprint.assign(Subtensor, SubtensorPrinter()) pprint.assign(Subtensor, SubtensorPrinter())
# TODO: Implement similar vectorize for Inc/SetSubtensor
@_vectorize_node.register(Subtensor) @_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs): def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices.""" """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
......
...@@ -1790,101 +1790,204 @@ def test_local_uint_constant_indices(): ...@@ -1790,101 +1790,204 @@ def test_local_uint_constant_indices():
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
@pytest.mark.parametrize("core_y_implicitly_batched", (False, True)) class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize("set_instead_of_inc", (True, False)) @staticmethod
def test_local_blockwise_advanced_inc_subtensor( def compile_fn_and_ref(*args, **kwargs):
set_instead_of_inc, core_y_implicitly_batched fn = pytensor.function(*args, **kwargs, mode="FAST_RUN")
): ref_fn = pytensor.function(
rng = np.random.default_rng([1764, set_instead_of_inc, core_y_implicitly_batched]) *args, **kwargs, mode=Mode(linker="py", optimizer=None)
)
def np_inplace_f(x, idx, y): return fn, ref_fn
if core_y_implicitly_batched:
y = y[..., None]
if set_instead_of_inc:
x[idx] = y
else:
x[idx] += y
core_y_shape = () if core_y_implicitly_batched else (3,)
core_x = tensor("x", shape=(6,))
core_y = tensor("y", shape=core_y_shape, dtype=int)
core_idxs = [0, 2, 4]
if set_instead_of_inc:
core_graph = set_subtensor(core_x[core_idxs], core_y)
else:
core_graph = inc_subtensor(core_x[core_idxs], core_y)
# Only x is batched @staticmethod
x = tensor("x", shape=(5, 2, 6)) def has_blockwise(fn):
y = tensor("y", shape=core_y_shape, dtype=int) return any(
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
assert isinstance(out.owner.op, Blockwise) )
fn = pytensor.function([x, y], out, mode="FAST_RUN") @pytest.mark.parametrize(
assert not any( "core_y_implicitly_batched", (False, True), ids=["y_explicit", "y_implicit"]
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
) )
@pytest.mark.parametrize("set_instead_of_inc", (True, False), ids=["set", "inc"])
@pytest.mark.parametrize("basic_idx", (True, False), ids=["basic_idx", "adv_idx"])
def test_idxs_not_vectorized(
self, basic_idx, set_instead_of_inc, core_y_implicitly_batched
):
rng = np.random.default_rng(
[1764, set_instead_of_inc, core_y_implicitly_batched, basic_idx]
)
test_x = np.ones(x.type.shape, dtype=x.type.dtype) core_y_shape = () if core_y_implicitly_batched else (3,)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype) core_x = tensor("x", shape=(6, 6))
expected_out = test_x.copy() core_y = tensor("y", shape=core_y_shape, dtype=int)
np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y) core_idxs = (-1, slice(None, 3)) if basic_idx else (-1, [0, 2, 4])
np.testing.assert_allclose(fn(test_x, test_y), expected_out) if set_instead_of_inc:
core_graph = set_subtensor(core_x[core_idxs], core_y)
# Only y is batched else:
x = tensor("y", shape=(6,)) core_graph = inc_subtensor(core_x[core_idxs], core_y)
y = tensor("y", shape=(2, *core_y_shape), dtype=int) assert isinstance(
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) core_graph.owner.op, IncSubtensor if basic_idx else AdvancedIncSubtensor
assert isinstance(out.owner.op, Blockwise) )
fn = pytensor.function([x, y], out, mode="FAST_RUN")
assert not any(
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
)
test_x = np.ones(x.type.shape, dtype=x.type.dtype) # Only x is batched
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype) x = tensor("x", shape=(5, 2, 6, 6))
expected_out = np.ones((2, *x.type.shape)) y = tensor("y", shape=core_y_shape, dtype=int)
np_inplace_f(expected_out, np.s_[:, core_idxs], test_y) out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
np.testing.assert_allclose(fn(test_x, test_y), expected_out) fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn)
# Both x and y are batched, and do not need to be broadcasted assert not self.has_blockwise(fn)
x = tensor("y", shape=(2, 6)) test_x = np.ones(x.type.shape, dtype=x.type.dtype)
y = tensor("y", shape=(2, *core_y_shape), dtype=int) test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
assert isinstance(out.owner.op, Blockwise)
# Only y is batched
fn = pytensor.function([x, y], out, mode="FAST_RUN") x = tensor("y", shape=(6, 6))
assert not any( y = tensor("y", shape=(2, *core_y_shape), dtype=int)
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
# Both x and y are batched, and do not need to be broadcasted
x = tensor("y", shape=(2, 6, 6))
y = tensor("y", shape=(2, *core_y_shape), dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
# Both x and y are batched, but must be broadcasted
x = tensor("y", shape=(5, 1, 6, 6))
y = tensor("y", shape=(1, 2, *core_y_shape), dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
@pytest.mark.parametrize("basic_idx", (True, False), ids=["basic_idx", "adv_idx"])
@pytest.mark.parametrize(
"batched_y", (False, True), ids=("unbatched_y", "batched_y")
) )
@pytest.mark.parametrize(
test_x = np.ones(x.type.shape, dtype=x.type.dtype) "batched_x", (False, True), ids=("unbatched_x", "batched_x")
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
expected_out = test_x.copy()
np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
# Both x and y are batched, but must be broadcasted
x = tensor("y", shape=(5, 1, 6))
y = tensor("y", shape=(1, 2, *core_y_shape), dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
assert isinstance(out.owner.op, Blockwise)
fn = pytensor.function([x, y], out, mode="FAST_RUN")
assert not any(
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
) )
def test_vectorized_idxs(
    self,
    basic_idx,
    batched_y,
    batched_x,
):
    """Check the rewrite of Blockwise inc_subtensor when the indices are vectorized.

    Two scenarios are exercised: batched indices that share a single batch
    dimension, and batched indices that must broadcast against each other.
    In both, the optimized function must contain no Blockwise node and must
    agree numerically with the reference function that still contains one.
    """
    rng = np.random.default_rng([1874, basic_idx, batched_y, batched_x])

    core_x = tensor("x", shape=(6, 6))
    core_y = tensor("y", shape=(), dtype=int)
    scalar_idx = scalar("scalar_idx", dtype="int64")
    vector_idx = vector("vector_idx", dtype="int64")
    core_idxs = (
        (slice(None, 3), scalar_idx) if basic_idx else (scalar_idx, vector_idx)
    )
    core_graph = inc_subtensor(core_x[core_idxs], core_y)
    assert isinstance(
        core_graph.owner.op, IncSubtensor if basic_idx else AdvancedIncSubtensor
    )

    # Indices don't broadcast with each other
    x = pt.tensor("x", shape=(4, 1, *core_x.type.shape)) if batched_x else core_x
    y = pt.tensor("y", shape=(2,), dtype=int) if batched_y else core_y
    out = vectorize_graph(
        core_graph,
        replace={
            scalar_idx: pt.constant([0, -1]),
            vector_idx: pt.constant([[0, 2, 4], [1, 3, 5]]),
            core_x: x,
            core_y: y,
        },
    )
    fn, ref_fn = self.compile_fn_and_ref([x, y], out)
    assert self.has_blockwise(ref_fn)
    assert not self.has_blockwise(fn)
    test_x = np.ones(x.type.shape, dtype=core_x.type.dtype)
    test_y = rng.integers(1, 10, size=y.type.shape)
    # Compare the rewritten function against the reference one. Comparing
    # ref_fn with itself would pass trivially and never test the rewrite.
    np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))

    # Indices broadcast with each other
    x = core_x
    y = pt.tensor("y", shape=(2,), dtype=int) if batched_y else core_y
    out = vectorize_graph(
        core_graph,
        replace={
            scalar_idx: pt.constant([0, -1, 0, -1])[:, None],
            vector_idx: pt.constant([[0, 2, 4], [1, 3, 5]])[None, :],
            core_x: x,
            core_y: y,
        },
    )
    fn, ref_fn = self.compile_fn_and_ref([x, y], out)
    assert self.has_blockwise(ref_fn)
    assert not self.has_blockwise(fn)
    test_x = np.ones(core_x.type.shape, dtype=x.type.dtype)
    test_y = rng.integers(1, 10, size=y.type.shape)
    np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
test_x = np.ones(x.type.shape, dtype=x.type.dtype) @pytest.mark.parametrize(
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype) "basic_idx",
final_shape = ( [
*np.broadcast_shapes(x.type.shape[:2], y.type.shape[:2]), True,
x.type.shape[-1], pytest.param(
False,
marks=pytest.mark.xfail(
reason="AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids=["basic_idx", "adv_idx"],
)
@pytest.mark.parametrize(
"vectorize_idx", (False, True), ids=lambda x: f"vectorize_idx={x}"
) )
expected_out = np.broadcast_to(test_x, final_shape).copy() def test_non_consecutive_integer_indices(self, vectorize_idx, basic_idx):
np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y) """Test numpy special behavior of transposing non-consecutive advanced indices to the front.
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
Either in the original graph (id adv_idx) or in the induced graph after rewrite
"""
core_a = pt.tensor("a", shape=(4, 3, 2))
core_v = pt.tensor("v", dtype="float64", shape=(3,) if basic_idx else (2, 3))
core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
# The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with a new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out = core_a[0, :, core_idx].set(core_v)
vec_a = pt.tensor(shape=(2, 2, 4, 3, 2))
vec_idx = pt.constant([0, -1]) if vectorize_idx else pt.constant(-1, dtype=int)
vec_v = pt.constant([[0, 1, 2], [2, 1, 0]])
if not basic_idx:
vec_idx = pt.repeat(vec_idx[..., None], 2, axis=-1)
vec_v = pt.repeat(vec_v[None], repeats=2, axis=0)
vec_out = vectorize_graph(
core_out,
{core_a: vec_a, core_v: vec_v, core_idx: vec_idx},
)
fn, ref_fn = self.compile_fn_and_ref([vec_a], vec_out)
assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_vec_a = np.arange(np.prod(vec_a.type.shape), dtype=vec_a.dtype).reshape(
vec_a.type.shape
)
np.testing.assert_allclose(fn(test_vec_a), ref_fn(test_vec_a))
class TestUselessSlice: class TestUselessSlice:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论