Commit 3de303d2 authored by ricardoV94, committed by Ricardo Vieira

Remove Join view flag

Do not normalize constant axis in make_node and fix rewrite that assumed this would always be positive
Parent ff092688
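The user-visible effect of the second change: a constant negative axis is still used (after normalization) for static shape inference, but it is stored in the graph exactly as given, so rewrites can no longer assume it is non-negative. A minimal sketch of the intended behavior, assuming the public `pytensor.tensor` API (variable names are illustrative):

import pytensor.tensor as pt

x = pt.matrix("x", shape=(2, 3))
y = pt.matrix("y", shape=(2, 5))
z = pt.join(-1, x, y)
print(z.type.shape)       # (2, 8): shapes are summed along the normalized axis
print(z.owner.inputs[0])  # the axis constant stays -1; it is not rewritten to 1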
......@@ -87,14 +87,7 @@ def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors):
# tensors could also be tuples, and in that case they don't have an ndim
tensors = [jnp.asarray(tensor) for tensor in tensors]
view = op.view
if (view != -1) and all(
tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :]
):
return tensors[view]
else:
return jnp.concatenate(tensors, axis=axis)
return jnp.concatenate(tensors, axis=axis)
return join
......
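The deleted `view` branch returned the single non-empty input as-is; `jnp.concatenate` already produces the same values when the remaining inputs are empty, so only the aliasing behavior (which JAX does not expose anyway) is lost. A rough sanity check, assuming `jax` is installed:

import numpy as np
import jax.numpy as jnp

a = np.zeros((0, 3))  # empty along the join axis
b = np.ones((2, 3))
out = jnp.concatenate([jnp.asarray(a), jnp.asarray(b)], axis=0)
np.testing.assert_array_equal(np.asarray(out), b)  # same values, no view special case needed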
......@@ -117,17 +117,9 @@ def numba_funcify_ARange(op, **kwargs):
@numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
view = op.view
if view != -1:
# TODO: Where (and why) is this `Join.view` even being used? From a
# quick search, the answer appears to be "nowhere", so we should
# probably just remove it.
raise NotImplementedError("The `view` parameter to `Join` is not supported")
@numba_basic.numba_njit
def join(axis, *tensors):
return np.concatenate(tensors, numba_basic.to_scalar(axis))
return np.concatenate(tensors, axis.item())
return join
......
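In the Numba backend the axis still arrives as a 0-d integer array; `ndarray.item()` extracts the Python scalar that `np.concatenate` expects, replacing the `numba_basic.to_scalar` helper. A plain NumPy sketch of that conversion (illustrative only):

import numpy as np

axis = np.asarray(1)  # a 0-d array, as the Join axis input arrives
assert axis.item() == 1
out = np.concatenate([np.ones((2, 1)), np.zeros((2, 2))], axis=axis.item())
assert out.shape == (2, 3)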
import pytensor.tensor.basic as ptb
from pytensor.scan.basic import scan
from pytensor.tensor.basic import Join
from pytensor.tensor.math import ceil, eq, neq
from pytensor.tensor.subtensor import set_subtensor
......@@ -127,14 +126,12 @@ def scan_checkpoints(
# Pad the sequences if needed
if padding:
# Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0)
for i, s in enumerate(sequences):
overshoots_by = s.shape[0] % save_every_N
overshoots = neq(overshoots_by, 0)
n = (save_every_N - overshoots_by) * overshoots
z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
sequences[i] = join(0, s, z)
sequences[i] = ptb.join(0, s, z)
# Establish the input variables of the outer scan
o_sequences = [
......
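The padding logic above extends each sequence to the next multiple of `save_every_N`, doing nothing when the length already is one. A worked NumPy mirror of the symbolic expressions, with made-up sizes:

import numpy as np

save_every_N = 3
s = np.arange(7)
overshoots_by = s.shape[0] % save_every_N        # 7 % 3 == 1
overshoots = int(overshoots_by != 0)             # 1
n = (save_every_N - overshoots_by) * overshoots  # pad by 2
padded = np.concatenate([s, np.zeros(n, dtype=s.dtype)])
assert padded.shape[0] == 9  # next multiple of save_every_N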
......@@ -2434,37 +2434,17 @@ class Join(COp):
The axis has to be an index into the shape
>>> pt.join(2, x, y, z)
Traceback (most recent call last):
ValueError: Axis value 2 is out of range for the given input dimensions
numpy.exceptions.AxisError: axis 2 is out of bounds for array of dimension 2
Joined tensors must have the same rank
>>> pt.join(0, x, u)
Traceback (most recent call last):
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1].
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1]
"""
check_input = False
__props__ = ("view",)
def __init__(self, view=-1):
self.view = view
if view != -1:
# since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
def __str__(self):
if self.view == -1:
return self.__class__.__name__
else:
classname = self.__class__.__name__
args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
return f"{classname}{{{args}}}"
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, "view"):
self.view = -1
__props__ = ()
def make_node(self, axis, *tensors):
"""
......@@ -2481,74 +2461,61 @@ class Join(COp):
if not tensors:
raise ValueError("Cannot join an empty list of tensors")
axis = as_tensor_variable(axis)
if axis.type.dtype not in int_dtypes:
raise TypeError(f"Axis {axis} must be an integer type.")
if axis.type.ndim > 0:
raise TypeError(f"Axis {axis} must be 0-d.")
tensors = [as_tensor_variable(x) for x in tensors]
out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
if not builtins.all(targs.type.ndim for targs in tensors):
if not builtins.all(targs.type.ndim > 0 for targs in tensors):
raise TypeError(
"Join cannot handle arguments of dimension 0."
" Use `stack` to join scalar values."
"Join cannot handle scalar arguments of dimension 0."
" Use `stack` to join scalar values or promote the scalars to vectors."
)
if len(tensors) == 1:
out_shape = tensors[0].type.shape
else:
# When the axis is fixed, a dimension should be
# broadcastable if at least one of the inputs is
# broadcastable on that dimension (see justification below),
# except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with
# the loops.
if not isinstance(axis, int):
try:
axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError:
pass
ndim = tensors[0].type.ndim
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the
# converse does not hold. So we permit e.g. T/F/T
# joins, and if they fail at runtime they fail, but if
# they don't then it means that the argument where
# that broadcastable flag was False had length 1 along
# this dimension, and therefore this dimension should
# be broadcastable for the output.
if axis < -ndim:
raise IndexError(
f"Axis value {axis} is out of range for the given input dimensions"
)
if axis < 0:
axis += ndim
if axis > ndim - 1:
raise ValueError(
f"Axis value {axis} is out of range for the given input dimensions"
)
# NOTE: A constant axis can no longer be negative at this point.
in_shapes = [x.type.shape for x in tensors]
in_ndims = [len(s) for s in in_shapes]
if set(in_ndims) != {ndim}:
raise TypeError(
"Only tensors with the same number of dimensions can be joined."
f" Input ndims were: {in_ndims}."
)
if not builtins.all(x.ndim == ndim for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined. "
f"Input ndims were: {[x.ndim for x in tensors]}"
)
try:
static_axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError:
static_axis = None
if static_axis is None:
# When the axis isn't static, we can't conclude anything about the output dimensions
# (unless we had some degenerate zero arrays that can be removed during rewrites).
# We could also raise an error if any dimension is pairwise inconsistent across all the axes,
# as the join would be invalid no matter the axis.
# However, a dynamic axis is so rare that it is not worth the trouble.
out_shape = [None] * ndim
else: # We know the axis statically
static_axis = normalize_axis_index(static_axis, ndim)
static_shapes = [x.type.shape for x in tensors]
# Determine output shapes from a matrix of input shapes
in_shapes = np.array(in_shapes)
static_shapes = np.array(static_shapes)
out_shape = [None] * ndim
for d in range(ndim):
ins = in_shapes[:, d]
if d == axis:
# Any unknown size along the axis means we can't sum
ins = static_shapes[:, d]
if d == static_axis:
# Any unknown size along the axis means we can't infer it
if None in ins:
out_shape[d] = None
else:
out_shape[d] = sum(ins)
else:
inset = set(in_shapes[:, d])
inset = set(static_shapes[:, d])
# Other dims must match exactly,
# or if a mix of None and ? the output will be ?
# otherwise the input shapes are incompatible.
......@@ -2558,100 +2525,71 @@ class Join(COp):
(out_shape[d],) = inset - {None}
else:
raise ValueError(
f"all input array dimensions other than the specified `axis` ({axis})"
f"all input array dimensions other than the specified `axis` ({static_axis})"
" must match exactly, or be unknown (None),"
f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
)
else:
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
out_shape = [None] * tensors[0].type.ndim
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined"
)
inputs = [as_tensor_variable(axis), *tensors]
if inputs[0].type.dtype not in int_dtypes:
raise TypeError(f"Axis value {inputs[0]} must be an integer type")
inputs = [axis, *tensors]
out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
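A short sketch of what the two branches above imply for users (hedged, using the public `pt.join`): with a static axis the joined dimension is the sum of the input sizes, while a dynamic axis discards all static shape information.

import pytensor.tensor as pt

x = pt.matrix("x", shape=(2, 3))
y = pt.matrix("y", shape=(2, 5))
print(pt.join(1, x, y).type.shape)  # (2, 8): static axis, sizes summed along it

axis = pt.scalar("axis", dtype="int64")
print(pt.join(axis, x, y).type.shape)  # (None, None): dynamic axis, nothing inferred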
def perform(self, node, axis_and_tensors, out_):
(out,) = out_
view = self.view
axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
# we check these tensors for being empty.
if (view != -1) and all(
tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
):
out[0] = tens[view]
else:
ndim = tens[0].ndim
if axis < -ndim:
raise IndexError(
f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
)
out[0] = np.asarray(
np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
)
def perform(self, node, inputs, output_storage):
axis, *arrays = inputs
output_storage[0][0] = np.concatenate(
arrays, axis=axis, dtype=node.outputs[0].type.dtype
)
def c_code_cache_version(self):
return (5,)
return (6,)
def c_code(self, node, name, inputs, outputs, sub):
axis, tens = inputs[0], inputs[1:]
view = self.view
non_empty_tensor = tens[view]
input_1 = tens[0]
l = len(tens)
(out,) = outputs
axis, *arrays = inputs
[out] = outputs
n = len(arrays)
ndim = node.outputs[0].type.ndim
fail = sub["fail"]
adtype = node.inputs[0].type.dtype_specs()[1]
copy_to_list = (
f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
for i, inp in enumerate(tens)
)
# Most of the time the axis is constant, so inline it.
# This is safe because the hash of the c_code includes the constant signature.
if isinstance(node.inputs[0], Constant):
static_axis = int(node.inputs[0].data)
static_axis = normalize_axis_index(static_axis, ndim)
axis_def = f"{static_axis};"
axis_check = ""
else:
axis_ctype = node.inputs[0].type.dtype_specs()[1]
axis_def = f"(({axis_ctype} *)PyArray_DATA({axis}))[0];"
axis_check = f"""
if (axis < 0){{
axis = {ndim} + axis;
}}
if (axis >= {ndim} || axis < 0) {{
PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
{fail}
}}
"""
copy_inputs_to_list = "\n".join(copy_to_list)
n = len(tens)
copy_arrays_to_tuple = "\n".join(
(
f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});"""
for i, array in enumerate(arrays)
)
)
code = f"""
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
PyObject* list = PyList_New({l});
{copy_inputs_to_list}
int tensors_lens_sum;
if({view} != -1) {{
tensors_lens_sum = 0;
for(int i=0; i < {n}; i++){{
tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis);
}}
tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis);
}}
if({view} != -1 && tensors_lens_sum == 0) {{
Py_XDECREF({out});
Py_INCREF({non_empty_tensor});
{out} = {non_empty_tensor};
}}else{{
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int ndim = PyArray_NDIM({input_1});
if( axis < -ndim ){{
PyErr_Format(PyExc_IndexError,
"Join axis %d out of bounds [0, %d)", axis, ndim);
{fail}
}}
Py_XDECREF({out});
{out} = (PyArrayObject *)PyArray_Concatenate(list, axis);
Py_DECREF(list);
if(!{out}){{
{fail}
}}
int axis = {axis_def}
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
PyObject* arrays_tuple = PyTuple_New({n});
{axis_check}
Py_XDECREF({out});
{copy_arrays_to_tuple}
{out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
Py_DECREF(arrays_tuple);
if(!{out}){{
{fail}
}}
"""
return code
......@@ -2661,22 +2599,21 @@ class Join(COp):
return [None]
return self.make_node(inputs[0], *eval_points[1:]).outputs
def grad(self, axis_and_tensors, grads):
def L_op(self, inputs, outputs, grads):
"""The gradient wrt a join op is a `Split`, used to partition
the gradient along the `axis` which was used for joining.
"""
(gz,) = grads
axis, tens = axis_and_tensors[0], axis_and_tensors[1:]
[gz] = grads
[out] = outputs
axis, *tensors = inputs
rval = [grad_undefined(self, 0, axis)]
dtypes = [as_tensor_variable(x).type.dtype for x in tens]
out_dtype = ps.upcast(*dtypes)
out_dtype = out.type.dtype
if "float" in out_dtype or "complex" in out_dtype:
# assume that this is differentiable
split = Split(len(tens))
split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens]))
split_sizes = stack([shape(x)[axis] for x in tensors])
split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
# If there is only one split, it might not be in a list.
if not isinstance(split_gz, list):
split_gz = [split_gz]
......@@ -2689,13 +2626,12 @@ class Join(COp):
else specify_broadcastable(
g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
)
for t, g in zip(tens, split_gz, strict=True)
for t, g in zip(tensors, split_gz, strict=True)
]
rval = rval + split_gz
else:
# the output has integer type, so the gradient through it
# is 0
rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
# the output has integer type, so the gradient through it is 0
rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
return rval
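To see the `Split` structure of the gradient described in the docstring: differentiating a summed join hands each input a slice of the upstream gradient matching its own extent along the axis. A hedged sketch:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
gx, gy = pytensor.grad(pt.join(0, x, y).sum(), [x, y])
f = pytensor.function([x, y], [gx, gy])
print(f(np.zeros(2, dtype=x.dtype), np.zeros(3, dtype=x.dtype)))
# -> [array([1., 1.]), array([1., 1., 1.])]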
......@@ -2715,7 +2651,8 @@ class Join(COp):
# An axis < -n_dim or >= ndim would be invalid, but this is
# not checked here. A `CheckAndRaise` `Op` would be a way of
# addressing that, but it may disrupt optimizations.
join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim)
axis = node.inputs[0]
join_dim = switch(ge(axis, 0), axis, axis + n_dim)
out_shapes = []
for dim in range(n_dim):
# we have to deal with 2 possible cases in here :
......@@ -2738,7 +2675,7 @@ class Join(COp):
return [tuple(out_shapes)]
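Because the runtime axis may be negative, `infer_shape` normalizes it symbolically with the `switch` above rather than with Python arithmetic. A tiny illustration (hedged; `eval` is used only for demonstration):

import pytensor.tensor as pt

axis = pt.scalar("axis", dtype="int64")
n_dim = 3
join_dim = pt.switch(pt.ge(axis, 0), axis, axis + n_dim)
print(join_dim.eval({axis: -1}))  # array(2)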
join_ = Join()
_join = Join()
pprint.assign(Join, printing.FunctionPrinter(["join"]))
......@@ -2781,7 +2718,7 @@ def join(axis, *tensors_list):
if len(tensors_list) == 1:
return tensors_list[0]
else:
return join_(axis, *tensors_list)
return _join(axis, *tensors_list)
@_vectorize_node.register(Join)
......
......@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second
from pytensor.tensor.basic import (
......@@ -817,52 +818,38 @@ def local_join_1(fgraph, node):
return [tensors[0]]
# TODO: merge in local_useless_join
@register_infer_shape
@register_useless
@register_specialize
@register_canonicalize
@register_specialize
@node_rewriter([Join])
def local_join_empty(fgraph, node):
"""Join(i, x, y, empty) => Join(i, x, y)
Remove empty inputs to joins. The empty inputs can be anywhere.
"""
if not isinstance(node.op, Join):
return
new_inputs = []
axis, *tensors = node.inputs
try:
join_idx = get_scalar_constant_value(
static_axis = get_scalar_constant_value(
node.inputs[0], only_process_constants=True
)
except NotScalarConstantError:
return
for idx in range(1, len(node.inputs)):
inp = node.inputs[idx]
# We cannot use size == 0, as this can change the shape from (3, 0)
# to (2, 0), which triggers a DebugMode error. This happens with
# stack(..., []), as it adds a dimshuffle on [] that adds a
# dimension with shape 1.
if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0:
continue
new_inputs.append(inp)
if len(new_inputs) < len(node.inputs) - 1:
if len(new_inputs) == 0:
# at.join does not work in that case.
# constant folding will take care of this case.
return
ret = join(node.inputs[0], *new_inputs)
o = node.outputs[0]
if ret.dtype != o.dtype:
# Join can upcast some inputs
return
# Copy over stacktrace from previous output (after join op)
# to new output, because an error in the new op must be caused
# by an error in the old join op.
copy_stack_trace(node.outputs, ret)
new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0]
# If there are zero tensors, the join is useless, but so is any other operation.
# Another rewrite will (one day) handle all those cases.
if 0 < len(new_tensors) < len(tensors):
# join eagerly returns a tensor when there is only one, no need for us to check
ret = join(axis, *new_tensors)
[old_output] = node.outputs
if ret.dtype != old_output.dtype:
ret = ret.astype(old_output.dtype)
copy_stack_trace(old_output, ret)
return [ret]
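The rewritten `local_join_empty` in action (hedged: assumes `rewrite_graph`'s default rewrite set includes canonicalize, where this rewrite is registered):

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")
empty = np.asarray([], dtype=x.dtype)
out = rewrite_graph(pt.join(0, x, x, empty))
# out is equivalent to pt.join(0, x, x): the statically-empty input was dropped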
......@@ -1298,7 +1285,7 @@ def local_join_of_alloc(fgraph, node):
# Axis can never be lifted
# Non-axis allocated dimensions can be lifted if they are all broadcastable
[out] = node.outputs
axis = axis.data
static_axis = normalize_axis_index(axis.data, tensors[0].type.ndim)
broadcasted_dims = list(
zip(
......@@ -1320,7 +1307,7 @@ def local_join_of_alloc(fgraph, node):
lifteable_alloc_dims = {
dim
for dim in range(out.type.ndim)
if dim != axis and all(broadcasted_dims[dim])
if dim != static_axis and all(broadcasted_dims[dim])
}
if not lifteable_alloc_dims:
......@@ -1337,13 +1324,13 @@ def local_join_of_alloc(fgraph, node):
copy_stack_trace(tensor, new_tensor)
new_tensors.append(new_tensor)
new_join = node.op(axis, *new_tensors)
new_join = node.op(static_axis, *new_tensors)
copy_stack_trace(node.outputs[0], new_join)
# Reintroduce the lifted dims
post_join_shape = []
for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)):
if i == axis:
if i == static_axis:
# The alloc dim along the axis is the sum of all the pre-join alloc dims
post_join_shape.append(add(*alloc_dims))
else:
......
......@@ -172,24 +172,6 @@ def test_Join(vals, axis):
)
def test_Join_view():
vals, vals_test = zip(
*(
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
),
strict=True,
)
g = ptb.Join(view=1)(1, *vals)
with pytest.raises(NotImplementedError):
compare_numba_and_py(
vals,
g,
vals_test,
)
@pytest.mark.parametrize(
"n_splits, axis, values, sizes",
[
......
......@@ -1248,65 +1248,41 @@ def test_local_join_1():
def test_local_join_empty():
# test for vector, vector, empty to vector
# Vector case
empty_vec = np.asarray([], dtype=config.floatX)
a = vector("a")
s = pt.join(0, a, a, empty_vec)
f = function([a], s, mode=rewrite_mode)
val = f([1])
assert np.all(val == [1])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all(
not isinstance(n.op, Join) or len(n.inputs) == 3
for n in e
if isinstance(n.op, Join)
vec = vector("vec")
s = pt.join(0, vec, vec, empty_vec)
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(0, vec, vec)])
assert new_s.dtype == s.dtype
# Matrix case
empty_mat = np.zeros((2, 0), dtype=config.floatX)
empty_sym_mat = matrix("m", shape=(2, 0))
mat = matrix("mat", shape=(2, 10))
s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
new_s = rewrite_graph(s)
assert equal_computations([new_s], [join(1, mat, mat, mat)])
assert new_s.dtype == s.dtype
# Join can be completely removed, but casting and specify_shape are propagated
int_mat = matrix("int_mat", dtype=int)
s = join(-1, empty_mat, int_mat, empty_sym_mat)
new_s = rewrite_graph(s)
assert equal_computations(
[new_s], [specify_shape(int_mat, (2, None)).astype(s.dtype)]
)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for matrix join(1,a)
empty_mat = np.asarray([[]], dtype=config.floatX)
m = matrix("m")
s = join(1, empty_mat, m, m, m)
f = function([m], s, mode=rewrite_mode)
val = f([[1]])
assert np.all(val == [[1]])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all(
not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e
if isinstance(n.op, Join)
)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for vector, vector, empty to matrix
# We can't rewrite this case.
s = pt.stack([a, a, empty_vec])
f = function([a], s, mode=rewrite_mode)
val = f([])
assert np.all(val == [1])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all(
not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e
if isinstance(n.op, Join)
)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for matrix join(0,a)
# We can't rewrite this case.
s = join(0, m, np.asarray([[2.0]], dtype=config.floatX), m)
f = function([m], s, mode=rewrite_mode)
val = f([[1]])
assert np.all(val == [[1], [2], [1]])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all(
not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e
if isinstance(n.op, Join)
)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# Dynamic axis, can't apply rewrite
axis = scalar("axis", dtype=int)
s = join(axis, empty_mat, int_mat, empty_sym_mat)
new_s = rewrite_graph(s)
assert equal_computations([new_s], [s])
# Stack introduces an expand_dims in the join; that's a nonzero dim!
s = pt.stack([vec, vec, empty_vec])
new_s = rewrite_graph(s)
assert equal_computations([new_s], [s])
def test_local_join_make_vector():
......
......@@ -2118,28 +2118,6 @@ class TestJoinAndSplit:
y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,)
def test_join_inplace(self):
# Test that join works inplace.
#
# This function tests the case when several elements are passed to the
# join function but all except one of them are empty. In this case join
# should work inplace and the output should be the view of the non-empty
# element.
s = lscalar()
x = vector("x")
z = ptb.zeros((s,))
join = Join(view=0)
c = join(0, x, z, z)
f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
data = np.array([3, 4, 5], dtype=config.floatX)
if config.mode not in ["DebugMode", "DEBUG_MODE"]:
assert f(data, 0) is data
assert np.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput(self):
# Test join when only 1 input is given.
#
......