提交 3de303d2 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove Join view flag

Do not normalize constant axis in make_node and fix rewrite that assumed this would always be positive
上级 ff092688
...@@ -87,13 +87,6 @@ def jax_funcify_Join(op, **kwargs): ...@@ -87,13 +87,6 @@ def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors): def join(axis, *tensors):
# tensors could also be tuples, and in this case they don't have a ndim # tensors could also be tuples, and in this case they don't have a ndim
tensors = [jnp.asarray(tensor) for tensor in tensors] tensors = [jnp.asarray(tensor) for tensor in tensors]
view = op.view
if (view != -1) and all(
tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :]
):
return tensors[view]
else:
return jnp.concatenate(tensors, axis=axis) return jnp.concatenate(tensors, axis=axis)
return join return join
......
...@@ -117,17 +117,9 @@ def numba_funcify_ARange(op, **kwargs): ...@@ -117,17 +117,9 @@ def numba_funcify_ARange(op, **kwargs):
@numba_funcify.register(Join) @numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs): def numba_funcify_Join(op, **kwargs):
view = op.view
if view != -1:
# TODO: Where (and why) is this `Join.view` even being used? From a
# quick search, the answer appears to be "nowhere", so we should
# probably just remove it.
raise NotImplementedError("The `view` parameter to `Join` is not supported")
@numba_basic.numba_njit @numba_basic.numba_njit
def join(axis, *tensors): def join(axis, *tensors):
return np.concatenate(tensors, numba_basic.to_scalar(axis)) return np.concatenate(tensors, axis.item())
return join return join
......
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.tensor.basic import Join
from pytensor.tensor.math import ceil, eq, neq from pytensor.tensor.math import ceil, eq, neq
from pytensor.tensor.subtensor import set_subtensor from pytensor.tensor.subtensor import set_subtensor
...@@ -127,14 +126,12 @@ def scan_checkpoints( ...@@ -127,14 +126,12 @@ def scan_checkpoints(
# Pad the sequences if needed # Pad the sequences if needed
if padding: if padding:
# Since padding could be an empty tensor, Join returns a view of s.
join = Join(view=0)
for i, s in enumerate(sequences): for i, s in enumerate(sequences):
overshoots_by = s.shape[0] % save_every_N overshoots_by = s.shape[0] % save_every_N
overshoots = neq(overshoots_by, 0) overshoots = neq(overshoots_by, 0)
n = (save_every_N - overshoots_by) * overshoots n = (save_every_N - overshoots_by) * overshoots
z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype) z = ptb.zeros((n, *s.shape[1:]), dtype=s.dtype)
sequences[i] = join(0, s, z) sequences[i] = ptb.join(0, s, z)
# Establish the input variables of the outer scan # Establish the input variables of the outer scan
o_sequences = [ o_sequences = [
......
...@@ -2434,37 +2434,17 @@ class Join(COp): ...@@ -2434,37 +2434,17 @@ class Join(COp):
The axis has to be an index into the shape The axis has to be an index into the shape
>>> pt.join(2, x, y, z) >>> pt.join(2, x, y, z)
Traceback (most recent call last): Traceback (most recent call last):
ValueError: Axis value 2 is out of range for the given input dimensions numpy.exceptions.AxisError: axis 2 is out of bounds for array of dimension 2
Joined tensors must have the same rank Joined tensors must have the same rank
>>> pt.join(0, x, u) >>> pt.join(0, x, u)
Traceback (most recent call last): Traceback (most recent call last):
TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1]. TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1]
""" """
check_input = False check_input = False
__props__ = ("view",) __props__ = ()
def __init__(self, view=-1):
self.view = view
if view != -1:
# since the first input is always the axis, the tensors
# start from index 1.
self.view_map = {0: [1 + view]}
def __str__(self):
if self.view == -1:
return self.__class__.__name__
else:
classname = self.__class__.__name__
args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
return f"{classname}{{{args}}}"
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, "view"):
self.view = -1
def make_node(self, axis, *tensors): def make_node(self, axis, *tensors):
""" """
...@@ -2481,74 +2461,61 @@ class Join(COp): ...@@ -2481,74 +2461,61 @@ class Join(COp):
if not tensors: if not tensors:
raise ValueError("Cannot join an empty list of tensors") raise ValueError("Cannot join an empty list of tensors")
axis = as_tensor_variable(axis)
if axis.type.dtype not in int_dtypes:
raise TypeError(f"Axis {axis} must be an integer type.")
if axis.type.ndim > 0:
raise TypeError(f"Axis {axis} must be 0-d.")
tensors = [as_tensor_variable(x) for x in tensors] tensors = [as_tensor_variable(x) for x in tensors]
out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
if not builtins.all(targs.type.ndim for targs in tensors): if not builtins.all(targs.type.ndim > 0 for targs in tensors):
raise TypeError( raise TypeError(
"Join cannot handle arguments of dimension 0." "Join cannot handle scalar arguments of dimension 0."
" Use `stack` to join scalar values." " Use `stack` to join scalar values or promote the scalars to vectors."
) )
if len(tensors) == 1: if len(tensors) == 1:
out_shape = tensors[0].type.shape out_shape = tensors[0].type.shape
else: else:
# When the axis is fixed, a dimension should be
# broadcastable if at least one of the inputs is
# broadcastable on that dimension (see justification below),
# except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with
# the loops.
if not isinstance(axis, int):
try:
axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError:
pass
ndim = tensors[0].type.ndim ndim = tensors[0].type.ndim
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the
# converse does not hold. So we permit e.g. T/F/T
# joins, and if they fail at runtime they fail, but if
# they don't then it means that the argument where
# that broadcastable flag was False had length 1 along
# this dimension, and therefore this dimension should
# be broadcastable for the output.
if axis < -ndim:
raise IndexError(
f"Axis value {axis} is out of range for the given input dimensions"
)
if axis < 0:
axis += ndim
if axis > ndim - 1:
raise ValueError(
f"Axis value {axis} is out of range for the given input dimensions"
)
# NOTE: Constant negative axis can no longer be negative at this point.
in_shapes = [x.type.shape for x in tensors] if not builtins.all(x.ndim == ndim for x in tensors):
in_ndims = [len(s) for s in in_shapes]
if set(in_ndims) != {ndim}:
raise TypeError( raise TypeError(
"Only tensors with the same number of dimensions can be joined." "Only tensors with the same number of dimensions can be joined. "
f" Input ndims were: {in_ndims}." f"Input ndims were: {[x.ndim for x in tensors]}"
) )
try:
static_axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError:
static_axis = None
if static_axis is None:
# When axis isn't static, we can't conclude anything about output dimension
# (unless we had some degenerate zero arrays) that can be removed during rewrites.
# We could also raise errors if any dimensions are pairwise inconsistent across all the axes
# As no matter the join it would be invalid.
# However, dynamic axis is so rare that is not worth the trouble
out_shape = [None] * ndim
else: # We know the axis statically
static_axis = normalize_axis_index(static_axis, ndim)
static_shapes = [x.type.shape for x in tensors]
# Determine output shapes from a matrix of input shapes # Determine output shapes from a matrix of input shapes
in_shapes = np.array(in_shapes) static_shapes = np.array(static_shapes)
out_shape = [None] * ndim out_shape = [None] * ndim
for d in range(ndim): for d in range(ndim):
ins = in_shapes[:, d] ins = static_shapes[:, d]
if d == axis: if d == static_axis:
# Any unknown size along the axis means we can't sum # Any unknown size along the axis means we can't infer it
if None in ins: if None in ins:
out_shape[d] = None out_shape[d] = None
else: else:
out_shape[d] = sum(ins) out_shape[d] = sum(ins)
else: else:
inset = set(in_shapes[:, d]) inset = set(static_shapes[:, d])
# Other dims must match exactly, # Other dims must match exactly,
# or if a mix of None and ? the output will be ? # or if a mix of None and ? the output will be ?
# otherwise the input shapes are incompatible. # otherwise the input shapes are incompatible.
...@@ -2558,101 +2525,72 @@ class Join(COp): ...@@ -2558,101 +2525,72 @@ class Join(COp):
(out_shape[d],) = inset - {None} (out_shape[d],) = inset - {None}
else: else:
raise ValueError( raise ValueError(
f"all input array dimensions other than the specified `axis` ({axis})" f"all input array dimensions other than the specified `axis` ({static_axis})"
" must match exactly, or be unknown (None)," " must match exactly, or be unknown (None),"
f" but along dimension {d}, the inputs shapes are incompatible: {ins}" f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
) )
else:
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
out_shape = [None] * tensors[0].type.ndim
if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError(
"Only tensors with the same number of dimensions can be joined"
)
inputs = [as_tensor_variable(axis), *tensors]
if inputs[0].type.dtype not in int_dtypes:
raise TypeError(f"Axis value {inputs[0]} must be an integer type")
inputs = [axis, *tensors]
out_dtype = ps.upcast(*[x.type.dtype for x in tensors])
return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)]) return Apply(self, inputs, [tensor(dtype=out_dtype, shape=out_shape)])
def perform(self, node, axis_and_tensors, out_): def perform(self, node, inputs, output_storage):
(out,) = out_ axis, *arrays = inputs
view = self.view output_storage[0][0] = np.concatenate(
axis, tens = axis_and_tensors[0], axis_and_tensors[1:] arrays, axis=axis, dtype=node.outputs[0].type.dtype
# we check these tensors for being empty.
if (view != -1) and all(
tensor.shape[axis] == 0 for tensor in tens[0:view] + tens[view + 1 :]
):
out[0] = tens[view]
else:
ndim = tens[0].ndim
if axis < -ndim:
raise IndexError(
f"Join axis {int(axis)} out of bounds [0, {int(ndim)})"
)
out[0] = np.asarray(
np.concatenate(tens, axis=axis), dtype=node.outputs[0].type.dtype
) )
def c_code_cache_version(self): def c_code_cache_version(self):
return (5,) return (6,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
axis, tens = inputs[0], inputs[1:] axis, *arrays = inputs
view = self.view [out] = outputs
non_empty_tensor = tens[view] n = len(arrays)
input_1 = tens[0] ndim = node.outputs[0].type.ndim
l = len(tens)
(out,) = outputs
fail = sub["fail"] fail = sub["fail"]
adtype = node.inputs[0].type.dtype_specs()[1]
copy_to_list = (
f"""Py_INCREF({inp}); PyList_SetItem(list, {i}, (PyObject*){inp});"""
for i, inp in enumerate(tens)
)
copy_inputs_to_list = "\n".join(copy_to_list)
n = len(tens)
code = f""" # Most times axis is constant, inline it
int axis = (({adtype} *)PyArray_DATA({axis}))[0]; # This is safe to do because the hash of the c_code includes the constant signature
PyObject* list = PyList_New({l}); if isinstance(node.inputs[0], Constant):
{copy_inputs_to_list} static_axis = int(node.inputs[0].data)
int tensors_lens_sum; static_axis = normalize_axis_index(static_axis, ndim)
if({view} != -1) {{ axis_def = f"{static_axis};"
tensors_lens_sum = 0; axis_check = ""
else:
for(int i=0; i < {n}; i++){{ axis_ctype = node.inputs[0].type.dtype_specs()[1]
tensors_lens_sum += PyArray_DIM((PyArrayObject *)(PyList_GetItem(list, i)), axis); axis_def = f"(({axis_ctype} *)PyArray_DATA({axis}))[0];"
}} axis_check = f"""
tensors_lens_sum -= PyArray_DIM({non_empty_tensor}, axis); if (axis < 0){{
axis = {ndim} + axis;
}} }}
if({view} != -1 && tensors_lens_sum == 0) {{ if (axis >= {ndim} || axis < 0) {{
Py_XDECREF({out}); PyErr_SetString(PyExc_ValueError, "Join axis is out of bounds");
Py_INCREF({non_empty_tensor});
{out} = {non_empty_tensor};
}}else{{
//PyObject* PyArray_Concatenate(PyObject* obj, int axis)
int ndim = PyArray_NDIM({input_1});
if( axis < -ndim ){{
PyErr_Format(PyExc_IndexError,
"Join axis %d out of bounds [0, %d)", axis, ndim);
{fail} {fail}
}} }}
"""
copy_arrays_to_tuple = "\n".join(
(
f"""Py_INCREF({array}); PyTuple_SetItem(arrays_tuple, {i}, (PyObject*){array});"""
for i, array in enumerate(arrays)
)
)
code = f"""
int axis = {axis_def}
PyArrayObject* arrays[{n}] = {{{','.join(arrays)}}};
PyObject* arrays_tuple = PyTuple_New({n});
{axis_check}
Py_XDECREF({out}); Py_XDECREF({out});
{out} = (PyArrayObject *)PyArray_Concatenate(list, axis); {copy_arrays_to_tuple}
Py_DECREF(list); {out} = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
Py_DECREF(arrays_tuple);
if(!{out}){{ if(!{out}){{
{fail} {fail}
}} }}
}}
""" """
return code return code
...@@ -2661,22 +2599,21 @@ class Join(COp): ...@@ -2661,22 +2599,21 @@ class Join(COp):
return [None] return [None]
return self.make_node(inputs[0], *eval_points[1:]).outputs return self.make_node(inputs[0], *eval_points[1:]).outputs
def grad(self, axis_and_tensors, grads): def L_op(self, inputs, outputs, grads):
"""The gradient wrt a join op is a `Split`, used to partition """The gradient wrt a join op is a `Split`, used to partition
the gradient along the `axis` which was used for joining. the gradient along the `axis` which was used for joining.
""" """
(gz,) = grads [gz] = grads
axis, tens = axis_and_tensors[0], axis_and_tensors[1:] [out] = outputs
axis, *tensors = inputs
rval = [grad_undefined(self, 0, axis)] rval = [grad_undefined(self, 0, axis)]
out_dtype = out.type.dtype
dtypes = [as_tensor_variable(x).type.dtype for x in tens]
out_dtype = ps.upcast(*dtypes)
if "float" in out_dtype or "complex" in out_dtype: if "float" in out_dtype or "complex" in out_dtype:
# assume that this is differentiable # assume that this is differentiable
split = Split(len(tens)) split_sizes = stack([shape(x)[axis] for x in tensors])
split_gz = split(gz, axis, stack([shape(x)[axis] for x in tens])) split_gz = split(gz, split_sizes, n_splits=len(tensors), axis=axis)
# If there is only one split, it might not be in a list. # If there is only one split, it might not be in a list.
if not isinstance(split_gz, list): if not isinstance(split_gz, list):
split_gz = [split_gz] split_gz = [split_gz]
...@@ -2689,13 +2626,12 @@ class Join(COp): ...@@ -2689,13 +2626,12 @@ class Join(COp):
else specify_broadcastable( else specify_broadcastable(
g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1) g, *(ax for (ax, s) in enumerate(t.type.shape) if s == 1)
) )
for t, g in zip(tens, split_gz, strict=True) for t, g in zip(tensors, split_gz, strict=True)
] ]
rval = rval + split_gz rval = rval + split_gz
else: else:
# the output has integer type, so the gradient through it # the output has integer type, so the gradient through it is 0
# is 0 rval = rval + [t.zeros_like(dtype=config.floatX) for t in tensors]
rval = rval + [t.zeros_like(dtype=config.floatX) for t in tens]
return rval return rval
...@@ -2715,7 +2651,8 @@ class Join(COp): ...@@ -2715,7 +2651,8 @@ class Join(COp):
# An axis < -n_dim or >= ndim would be invalid, but this is # An axis < -n_dim or >= ndim would be invalid, but this is
# not checked here. A `CheckAndRaise` `Op` would be a way of # not checked here. A `CheckAndRaise` `Op` would be a way of
# addressing that, but it may disrupt optimizations. # addressing that, but it may disrupt optimizations.
join_dim = switch(ge(node.inputs[0], 0), node.inputs[0], node.inputs[0] + n_dim) axis = node.inputs[0]
join_dim = switch(ge(axis, 0), axis, axis + n_dim)
out_shapes = [] out_shapes = []
for dim in range(n_dim): for dim in range(n_dim):
# we have to deal with 2 possible cases in here : # we have to deal with 2 possible cases in here :
...@@ -2738,7 +2675,7 @@ class Join(COp): ...@@ -2738,7 +2675,7 @@ class Join(COp):
return [tuple(out_shapes)] return [tuple(out_shapes)]
join_ = Join() _join = Join()
pprint.assign(Join, printing.FunctionPrinter(["join"])) pprint.assign(Join, printing.FunctionPrinter(["join"]))
...@@ -2781,7 +2718,7 @@ def join(axis, *tensors_list): ...@@ -2781,7 +2718,7 @@ def join(axis, *tensors_list):
if len(tensors_list) == 1: if len(tensors_list) == 1:
return tensors_list[0] return tensors_list[0]
else: else:
return join_(axis, *tensors_list) return _join(axis, *tensors_list)
@_vectorize_node.register(Join) @_vectorize_node.register(Join)
......
...@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -41,6 +41,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter, node_rewriter,
) )
from pytensor.graph.rewriting.db import RewriteDatabase from pytensor.graph.rewriting.db import RewriteDatabase
from pytensor.npy_2_compat import normalize_axis_index
from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.raise_op import Assert, CheckAndRaise, assert_op
from pytensor.scalar.basic import Second from pytensor.scalar.basic import Second
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
...@@ -817,52 +818,38 @@ def local_join_1(fgraph, node): ...@@ -817,52 +818,38 @@ def local_join_1(fgraph, node):
return [tensors[0]] return [tensors[0]]
# TODO: merge in local_useless_join
@register_infer_shape
@register_useless @register_useless
@register_specialize
@register_canonicalize @register_canonicalize
@register_specialize
@node_rewriter([Join]) @node_rewriter([Join])
def local_join_empty(fgraph, node): def local_join_empty(fgraph, node):
"""Join(i, x, y, empty) => Join(i, x, y) """Join(i, x, y, empty) => Join(i, x, y)
Remove empty inputs to joins. The empty inputs can be anywhere. Remove empty inputs to joins. The empty inputs can be anywhere.
""" """
if not isinstance(node.op, Join): axis, *tensors = node.inputs
return
new_inputs = []
try: try:
join_idx = get_scalar_constant_value( static_axis = get_scalar_constant_value(
node.inputs[0], only_process_constants=True node.inputs[0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
return return
for idx in range(1, len(node.inputs)):
inp = node.inputs[idx]
# We can not use size == 0,, as this can change shape from 3,0
# to 2,0. This trigger DebugMode error. This happen with
# stack(...,[]) as this add a dimshuffle on [], that add a
# dimensions with shape 1.
if isinstance(inp, Constant) and inp.data.shape[join_idx] == 0:
continue
new_inputs.append(inp)
if len(new_inputs) < len(node.inputs) - 1:
if len(new_inputs) == 0:
# at.join do not work in that case.
# constant folding will take care of this case.
return
ret = join(node.inputs[0], *new_inputs)
o = node.outputs[0]
if ret.dtype != o.dtype:
# Join can upcast some inputs
return
# Copy over stacktrace from previous output (after join op) new_tensors = [tensor for tensor in tensors if tensor.type.shape[static_axis] != 0]
# to new output, because an error in the new op must be caused
# by an error in the old join op. # If there are zero tensors, the join is useless but so is any other operation
copy_stack_trace(node.outputs, ret) # Another rewrite will (one day) handle all those cases
if 0 < len(new_tensors) < len(tensors):
# join eagerly returns a tensor when there is only one, no need for us to check
ret = join(axis, *new_tensors)
[old_output] = node.outputs
if ret.dtype != old_output.dtype:
ret = ret.astype(old_output.dtype)
copy_stack_trace(old_output, ret)
return [ret] return [ret]
...@@ -1298,7 +1285,7 @@ def local_join_of_alloc(fgraph, node): ...@@ -1298,7 +1285,7 @@ def local_join_of_alloc(fgraph, node):
# Axis can never be lifted # Axis can never be lifted
# Non-axis allocated dimensions can be lifted if they are all broadcastable # Non-axis allocated dimensions can be lifted if they are all broadcastable
[out] = node.outputs [out] = node.outputs
axis = axis.data static_axis = normalize_axis_index(axis.data, tensors[0].type.ndim)
broadcasted_dims = list( broadcasted_dims = list(
zip( zip(
...@@ -1320,7 +1307,7 @@ def local_join_of_alloc(fgraph, node): ...@@ -1320,7 +1307,7 @@ def local_join_of_alloc(fgraph, node):
lifteable_alloc_dims = { lifteable_alloc_dims = {
dim dim
for dim in range(out.type.ndim) for dim in range(out.type.ndim)
if dim != axis and all(broadcasted_dims[dim]) if dim != static_axis and all(broadcasted_dims[dim])
} }
if not lifteable_alloc_dims: if not lifteable_alloc_dims:
...@@ -1337,13 +1324,13 @@ def local_join_of_alloc(fgraph, node): ...@@ -1337,13 +1324,13 @@ def local_join_of_alloc(fgraph, node):
copy_stack_trace(tensor, new_tensor) copy_stack_trace(tensor, new_tensor)
new_tensors.append(new_tensor) new_tensors.append(new_tensor)
new_join = node.op(axis, *new_tensors) new_join = node.op(static_axis, *new_tensors)
copy_stack_trace(node.outputs[0], new_join) copy_stack_trace(node.outputs[0], new_join)
# Reintroduce the lifted dims # Reintroduce the lifted dims
post_join_shape = [] post_join_shape = []
for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)): for i, alloc_dims in enumerate(zip(*alloc_shapes, strict=True)):
if i == axis: if i == static_axis:
# The alloc dim along the axis is the sum of all the pre-join alloc dims # The alloc dim along the axis is the sum of all the pre-join alloc dims
post_join_shape.append(add(*alloc_dims)) post_join_shape.append(add(*alloc_dims))
else: else:
......
...@@ -172,24 +172,6 @@ def test_Join(vals, axis): ...@@ -172,24 +172,6 @@ def test_Join(vals, axis):
) )
def test_Join_view():
vals, vals_test = zip(
*(
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
(pt.matrix(), rng.normal(size=(2, 2)).astype(config.floatX)),
),
strict=True,
)
g = ptb.Join(view=1)(1, *vals)
with pytest.raises(NotImplementedError):
compare_numba_and_py(
vals,
g,
vals_test,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"n_splits, axis, values, sizes", "n_splits, axis, values, sizes",
[ [
......
...@@ -1248,65 +1248,41 @@ def test_local_join_1(): ...@@ -1248,65 +1248,41 @@ def test_local_join_1():
def test_local_join_empty(): def test_local_join_empty():
# test for vector, vector, empty to vector # Vector case
empty_vec = np.asarray([], dtype=config.floatX) empty_vec = np.asarray([], dtype=config.floatX)
a = vector("a") vec = vector("vec")
s = pt.join(0, a, a, empty_vec) s = pt.join(0, vec, vec, empty_vec)
f = function([a], s, mode=rewrite_mode) new_s = rewrite_graph(s)
val = f([1]) assert equal_computations([new_s], [join(0, vec, vec)])
assert np.all(val == [1]) assert new_s.dtype == s.dtype
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1 # Matrix case
assert all( empty_mat = np.zeros((2, 0), dtype=config.floatX)
not isinstance(n.op, Join) or len(n.inputs) == 3 empty_sym_mat = matrix("m", shape=(2, 0))
for n in e mat = matrix("mat", shape=(2, 10))
if isinstance(n.op, Join) s = join(1, empty_mat, mat, empty_sym_mat, mat, mat)
) new_s = rewrite_graph(s)
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert equal_computations([new_s], [join(1, mat, mat, mat)])
assert new_s.dtype == s.dtype
# test for matrix join(1,a)
empty_mat = np.asarray([[]], dtype=config.floatX) # Join can be completely removed, but casting and specify_shape are propagated
m = matrix("m") int_mat = matrix("int_mat", dtype=int)
s = join(1, empty_mat, m, m, m) s = join(-1, empty_mat, int_mat, empty_sym_mat)
f = function([m], s, mode=rewrite_mode) new_s = rewrite_graph(s)
val = f([[1]]) assert equal_computations(
assert np.all(val == [[1]]) [new_s], [specify_shape(int_mat, (2, None)).astype(s.dtype)]
e = f.maker.fgraph.toposort() )
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all( # Dynamic axis, can't apply rewrite
not isinstance(n.op, Join) or len(n.inputs) == 4 axis = scalar("axis", dtype=int)
for n in e s = join(axis, empty_mat, int_mat, empty_sym_mat)
if isinstance(n.op, Join) new_s = rewrite_graph(s)
) assert equal_computations([new_s], [s])
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for vector, vector, empty to matrix # Stack introduces an expand_dims in the join, that's a nonzero dim!
# We can't rewrite this case. s = pt.stack([vec, vec, empty_vec])
s = pt.stack([a, a, empty_vec]) new_s = rewrite_graph(s)
f = function([a], s, mode=rewrite_mode) assert equal_computations([new_s], [s])
val = f([])
assert np.all(val == [1])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all(
not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e
if isinstance(n.op, Join)
)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test for matrix join(0,a)
# We can't rewrite this case.
s = join(0, m, np.asarray([[2.0]], dtype=config.floatX), m)
f = function([m], s, mode=rewrite_mode)
val = f([[1]])
assert np.all(val == [[1], [2], [1]])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Join)]) == 1
assert all(
not isinstance(n.op, Join) or len(n.inputs) == 4
for n in e
if isinstance(n.op, Join)
)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
def test_local_join_make_vector(): def test_local_join_make_vector():
......
...@@ -2118,28 +2118,6 @@ class TestJoinAndSplit: ...@@ -2118,28 +2118,6 @@ class TestJoinAndSplit:
y = Split(2)(x, 0, [s, 5 - s])[0] y = Split(2)(x, 0, [s, 5 - s])[0]
assert y.type.shape == (None,) assert y.type.shape == (None,)
def test_join_inplace(self):
# Test join to work inplace.
#
# This function tests the case when several elements are passed to the
# join function but all except one of them are empty. In this case join
# should work inplace and the output should be the view of the non-empty
# element.
s = lscalar()
x = vector("x")
z = ptb.zeros((s,))
join = Join(view=0)
c = join(0, x, z, z)
f = pytensor.function([In(x, borrow=True), s], Out(c, borrow=True))
data = np.array([3, 4, 5], dtype=config.floatX)
if config.mode not in ["DebugMode", "DEBUG_MODE"]:
assert f(data, 0) is data
assert np.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput(self): def test_join_oneInput(self):
# Test join when only 1 input is given. # Test join when only 1 input is given.
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论