提交 3de303d2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove Join view flag

Do not normalize constant axis in make_node and fix rewrite that assumed this would always be positive
上级 ff092688
......@@ -87,14 +87,7 @@ def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors):
    """Concatenate `tensors` along `axis` using JAX.

    The pre-commit `op.view` fast path (returning one input when all the
    others were empty) was removed together with the `Join.view` flag;
    plain concatenation is always correct.
    """
    # tensors could also be tuples, and in this case they don't have a ndim
    tensors = [jnp.asarray(tensor) for tensor in tensors]
    return jnp.concatenate(tensors, axis=axis)
return join
......
......@@ -117,17 +117,9 @@ def numba_funcify_ARange(op, **kwargs):
@numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
    """Return a Numba-jitted implementation of ``Join``.

    The old ``op.view`` guard (which raised ``NotImplementedError``) is gone:
    the ``Join.view`` flag was removed from the Op, so concatenation is the
    only behavior to compile.
    """

    @numba_basic.numba_njit
    def join(axis, *tensors):
        # `axis` arrives as a 0-d array; `.item()` extracts the Python scalar
        # (replaces the previous `numba_basic.to_scalar(axis)` helper call).
        return np.concatenate(tensors, axis.item())

    return join
......
import pytensor.tensor.basic as ptb
from pytensor.scan.basic import scan
from pytensor.tensor.basic import Join
from pytensor.tensor.math import ceil, eq, neq
from pytensor.tensor.subtensor import set_subtensor
......@@ -127,14 +126,12 @@ def scan_checkpoints(
# Pad the sequences if needed
if padding:
# Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0)
for i, s in enumerate(sequences):
overshoots_by = s.shape[0] % save_every_N
overshoots = neq(overshoots_by, 0)
n = (save_every_N - overshoots_by) * overshoots
z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
sequences[i] = join(0, s, z)
sequences[i] = ptb.join(0, s, z)
# Establish the input variables of the outer scan
o_sequences = [
......
差异被折叠。
......@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.tensor.basic import (
......@@ -817,52 +818,38 @@ def local_join_1(fgraph, node):
return [tensors[0]]
# TODO: merge in local_useless_join
@register_infer_shape
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Join])
def local_join_empty(fgraph, node):
    """Join(i, x, y, empty) => Join(i, x, y)

    Remove empty inputs to joins. The empty inputs can be anywhere.
    """
    axis, *tensors = node.inputs

    try:
        static_axis = get_scalar_constant_value(
            node.inputs[0], only_process_constants=True
        )
    except NotScalarConstantError:
        # Axis is not known at compile time; we cannot tell which inputs
        # are empty along it, so leave the graph untouched.
        return

    # NOTE: `static_axis` may be negative (Join no longer normalizes the
    # constant axis in make_node); negative Python indexing into
    # `type.shape` handles that correctly.
    new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0]

    # If there are zero tensors, the join is useless but so is any other operation
    # Another rewrite will (one day) handle all those cases
    if 0 < len(new_tensors) < len(tensors):
        # join eagerly returns a tensor when there is only one, no need for us to check
        ret = join(axis, *new_tensors)

        [old_output] = node.outputs

        if ret.dtype != old_output.dtype:
            # Join can upcast some inputs; preserve the original output dtype.
            ret = ret.astype(old_output.dtype)

        # Copy over stacktrace from previous output (after join op)
        # to new output, because an error in the new op must be caused
        # by an error in the old join op.
        copy_stack_trace(old_output, ret)
        return [ret]
......@@ -1298,7 +1285,7 @@ def local_join_of_alloc(fgraph, node):
# Axis can never be lifted
# Non-axis allocated dimensions can be lifted if they are all broadcastable
[out] = node.outputs
axis = axis.data
static_axis = normalize_axis_index(axis.data, tensors[0].type.ndim)
broadcasted_dims = list(
zip(
......@@ -1320,7 +1307,7 @@ def local_join_of_alloc(fgraph, node):
lifteable_alloc_dims = {
dim
for dim in range(out.type.ndim)
if dim != axis and all(broadcasted_dims[dim])
if dim != static_axis and all(broadcasted_dims[dim])
}
if not lifteable_alloc_dims:
......@@ -1337,13 +1324,13 @@ def local_join_of_alloc(fgraph, node):
copy_stack_trace(tensor, new_tensor)
new_tensors.append(new_tensor)
new_join = node.op(axis, *new_tensors)
new_join = node.op(static_axis, *new_tensors)
copy_stack_trace(node.outputs[0], new_join)
# Reintroduce the lifted dims
post_join_shape = []
for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)):
if i == axis:
if i == static_axis:
# The alloc dim along the axis is the sum of all the pre-join alloc dims
post_join_shape.append(add(*alloc_dims))
else:
......
......@@ -172,24 +172,6 @@ def test_Join(vals, axis):
)
def test_Join_view():
vals, vals_test = zip(
*(
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
),
strict=True,
)
g = ptb.Join(view=1)(1, *vals)
with pytest.raises(NotImplementedError):
compare_numba_and_py(
vals,
g,
vals_test,
)
@pytest.mark.parametrize(
"n_splits, axis, values, sizes",
[
......
......@@ -1248,65 +1248,41 @@ def test_local_join_1():
def test_local_join_empty():
    """`local_join_empty` drops empty inputs from `Join` when the axis is constant."""
    # Vector case
    empty_vec = np.asarray([], dtype=config.floatX)
    vec = vector("vec")
    s = pt.join(0, vec, vec, empty_vec)
    new_s = rewrite_graph(s)
    assert equal_computations([new_s], [join(0, vec, vec)])
    assert new_s.dtype == s.dtype

    # Matrix case
    empty_mat = np.zeros((2, 0), dtype=config.floatX)
    empty_sym_mat = matrix("m", shape=(2, 0))
    mat = matrix("mat", shape=(2, 10))
    s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
    new_s = rewrite_graph(s)
    assert equal_computations([new_s], [join(1, mat, mat, mat)])
    assert new_s.dtype == s.dtype

    # Join can be completely removed, but casting and specify_shape are propagated
    int_mat = matrix("int_mat", dtype=int)
    s = join(-1, empty_mat, int_mat, empty_sym_mat)
    new_s = rewrite_graph(s)
    assert equal_computations(
        [new_s], [specify_shape(int_mat, (2, None)).astype(s.dtype)]
    )

    # Dynamic axis, can't apply rewrite
    axis = scalar("axis", dtype=int)
    s = join(axis, empty_mat, int_mat, empty_sym_mat)
    new_s = rewrite_graph(s)
    assert equal_computations([new_s], [s])

    # Stack introduces an expand_dims in the join, that's a nonzero dim!
    s = pt.stack([vec, vec, empty_vec])
    new_s = rewrite_graph(s)
    assert equal_computations([new_s], [s])
def test_local_join_make_vector():
......
......@@ -2118,28 +2118,6 @@ class TestJoinAndSplit:
y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,)
def test_join_inplace(self):
# Test join to work inplace.
#
# This function tests the case when several elements are passed to the
# join function but all except one of them are empty. In this case join
# should work inplace and the output should be the view of the non-empty
# element.
s = lscalar()
x = vector("x")
z = ptb.zeros((s,))
join = Join(view=0)
c = join(0, x, z, z)
f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
data = np.array([3, 4, 5], dtype=config.floatX)
if config.mode not in ["DebugMode", "DEBUG_MODE"]:
assert f(data, 0) is data
assert np.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput(self):
# Test join when only 1 input is given.
#
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论