Commit 42a7adb9 authored by Ricardo Vieira, committed by Ricardo Vieira

Remove Unbroadcast Op

Parent a24f5345
......@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padleft(x, n_ones=1)
Reshape `x` by left padding the shape with `n_ones` 1s. Note that all
this new dimension will be broadcastable. To make them non-broadcastable
see the :func:`unbroadcast`.
Reshape `x` by left padding the shape with `n_ones` 1s.
All new dimensions will be broadcastable.
:param x: variable to be reshaped
:type x: any `TensorVariable` (or compatible)
......@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padright(x, n_ones=1)
Reshape `x` by right padding the shape with `n_ones` ones. Note that all
this new dimension will be broadcastable. To make them non-broadcastable
see the :func:`unbroadcast`.
Reshape `x` by right padding the shape with `n_ones` ones.
All new dimensions will be broadcastable.
:param x: variable to be reshaped
:type x: any TensorVariable (or compatible)
......@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padaxis(t, axis)
Reshape `t` by inserting ``1`` at the dimension `axis`. Note that this new
dimension will be broadcastable. To make it non-broadcastable
see the :func:`unbroadcast`.
Reshape `t` by inserting ``1`` at the dimension `axis`.
All new dimensions will be broadcastable.
:type t: any `TensorVariable` (or compatible)
:param t: variable to be reshaped
......
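For readers of the updated docstrings, a quick sketch of what the padding helpers produce (assuming the current `pytensor.tensor` API; the comments show static type shapes, illustrative only):

import pytensor.tensor as pt

x = pt.matrix("x")                       # static shape (None, None)
pt.shape_padleft(x, 2).type.shape        # (1, 1, None, None): new dims broadcastable
pt.shape_padright(x).type.shape          # (None, None, 1)
pt.shape_padaxis(x, axis=1).type.shape   # (None, 1, None)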
......@@ -292,14 +292,8 @@ def rebuild_collect_shared(
f" shared_var.type={store_into.type},"
f" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
)
err_sug = (
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)
raise TypeError(err_msg, err_sug)
raise TypeError(err_msg)
assert store_into.type.is_super(update_val.type)
update_d[store_into] = update_val
......
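The deleted suggestion used to point users at `unbroadcast` when a shared variable's update carried a less specific static shape than the variable itself. A hypothetical sketch of the failure mode and the remaining fix, `specify_shape` (variable names are illustrative, not from this commit):

import numpy as np
import pytensor
import pytensor.tensor as pt

s = pytensor.shared(np.zeros((1, 3)), shape=(1, 3))  # static shape (1, 3)
m = pt.matrix("m")                                   # static shape (None, None)
update = s + m  # broadcasts to static shape (None, 3), less specific than (1, 3)
update = pt.specify_shape(update, (1, 3))            # restore the required static shape
f = pytensor.function([m], [], updates={s: update})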
......@@ -26,7 +26,7 @@ from pytensor.graph.op import _NoPythonOp
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.type import HasDataType, HasShape
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape
if TYPE_CHECKING:
......@@ -481,7 +481,6 @@ acceptable_ops = (
Shape,
SpecifyShape,
Reshape,
Unbroadcast,
pt.math.Dot,
pt.math.Max,
pt.math.Argmax,
......
......@@ -4,7 +4,7 @@ from pytensor.graph import Constant
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.type import TensorType
......@@ -104,11 +104,3 @@ def jax_funcify_SpecifyShape(op, node, **kwargs):
return x
return specifyshape
@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return unbroadcast
......@@ -17,7 +17,6 @@ from pytensor.tensor.basic import (
Split,
TensorFromScalar,
)
from pytensor.tensor.shape import Unbroadcast
@numba_funcify.register(AllocEmpty)
......@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn)
@numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
@numba_basic.numba_njit
def unbroadcast(x):
return x
return unbroadcast
@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_basic.numba_njit(inline="always")
......
......@@ -2,7 +2,7 @@ import torch
from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@pytorch_funcify.register(Reshape)
......@@ -56,11 +56,3 @@ def pytorch_funcify_SpecifyShape(op, node, **kwargs):
return x
return specifyshape
@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return unbroadcast
......@@ -15,7 +15,7 @@ from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft, unbroadcast
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.updates import OrderedUpdates
......@@ -748,7 +748,7 @@ def scan(
# defined in scan utils
sit_sot_scan_inputs.append(
expand_empty(
unbroadcast(shape_padleft(actual_arg), 0),
shape_padleft(actual_arg),
actual_n_steps,
)
)
......@@ -865,13 +865,13 @@ def scan(
if n_fixed_steps in (1, -1):
for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an
# unbroadcastable dimension; case example : we return an
# extra dimension; case example : we return an
# output for which we want all intermediate. If n_steps is 1
# then, if we return the output as given by the inner function
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)
outputs[pos] = shape_padleft(inner_out)
if not return_list and len(outputs) == 1:
outputs = outputs[0]
......@@ -1002,7 +1002,7 @@ def scan(
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
expand_empty(
unbroadcast(shape_padleft(input.variable), 0),
shape_padleft(input.variable),
actual_n_steps,
)
)
......
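For context on these scan changes: `shape_padleft` gives the initial state a length-1 leading "time" axis, and `expand_empty` (the scan-internal helper visible in the hunks above) then grows that axis to hold the remaining steps, so the removed `unbroadcast` wrapper added nothing. A rough sketch of the shapes involved:

import pytensor.tensor as pt

x0 = pt.vector("x0")           # one step of the recurrence, shape (None,)
padded = pt.shape_padleft(x0)  # shape (1, None): length-1 leading time axis
# expand_empty(padded, n_steps) then allocates an empty buffer with room for
# n_steps more entries along axis 0.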
......@@ -166,8 +166,7 @@ def check_broadcast(v1, v2):
"axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using pytensor.tensor."
"{unbroadcast, specify_broadcastable}."
"them consistent, e.g. using pytensor.tensor.specify_broadcastable."
)
size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate(
......
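The trimmed error message now recommends only `specify_broadcastable`, which pins a dimension's static size to 1 so that the input and output patterns agree. A minimal usage sketch:

import pytensor.tensor as pt

out = pt.matrix("out")                  # static shape (None, None)
out = pt.specify_broadcastable(out, 0)  # static shape (1, None), matching a `row` input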
......@@ -53,7 +53,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import (
Shape,
Shape_i,
Unbroadcast,
shape,
shape_padaxis,
shape_padleft,
......@@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value(
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
op = v.owner.op
max_recur -= 1
if isinstance(
op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
):
if isinstance(op, Alloc | DimShuffle | OutputGuard | DeepCopyOp):
# OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles
v = v.owner.inputs[0]
......@@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value(
grandparent = leftmost_parent.owner.inputs[0]
gp_shape = grandparent.type.shape
ndim = grandparent.type.ndim
if grandparent.owner and isinstance(
grandparent.owner.op, Unbroadcast
):
ggp_shape = grandparent.owner.inputs[0].type.shape
l = [
_get_underlying_scalar_constant_value(s) for s in ggp_shape
]
gp_shape = tuple(l)
if not (idx < ndim):
msg = (
......
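The isinstance whitelist above lets constant extraction look through shape-preserving ops, and removing `Unbroadcast` simply shrinks it by one entry. A small sketch of the behavior (illustrative):

import pytensor.tensor as pt
from pytensor.tensor.basic import get_underlying_scalar_constant_value

x = pt.alloc(pt.constant(5.0), 3, 3)     # a scalar constant broadcast by Alloc
get_underlying_scalar_constant_value(x)  # array(5.): looks through Alloc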
......@@ -42,9 +42,7 @@ from pytensor.tensor.shape import (
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
specify_shape,
unbroadcast,
)
from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
......@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node):
# structure.
replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]]
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
if isinstance(node.op, Unbroadcast):
x = node.inputs[0]
if x.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True)
if s1 == 1 or s2 == 1
):
# No broadcastable flag was modified
# No need to copy over stack trace,
# because x should already have a stack trace.
return [x]
else:
# Keep the flags that modify something
new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
if new_axes == node.op.axes:
# All flags are useful
return None
else:
r = unbroadcast(x, *new_axes)
# Copy over stacktrace from previous output
copy_stack_trace(node.outputs, r)
return [r]
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
"""
Lifts `Unbroadcast` through unary Elemwise operations,
and merges consecutive `Unbroadcast`s.
Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
TODO: Implement equivalent Elemwise lift for SpecifyShape
"""
op = node.op
if not isinstance(op, Unbroadcast):
return False
inp = node.inputs[0]
inode = inp.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
if len(fgraph.clients.get(inp, ())) == 1:
unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
copy_stack_trace(node.outputs, unbroadcasted)
rval = inode.op.make_node(unbroadcasted).outputs
# Copy over stacktrace from previous output (after unbroadcasting)
# and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the
# two ops.
copy_stack_trace(node.outputs + node.inputs, rval)
return rval
if inode and isinstance(inode.op, Unbroadcast):
# Merge axis of each unbroadcast
axis = tuple(set(inode.op.axes).union(set(op.axes)))
iinput = inode.inputs[0]
rval = [unbroadcast(iinput, *axis)]
# Copy over stacktrace from previous output (after second unbroadcasting)
# and from previous input (after first unbroadcasting) because an error in
# the new graph could have been caused by either of the two Unbroadcast ops.
copy_stack_trace(node.outputs + node.inputs, rval)
return rval
......@@ -59,11 +59,9 @@ from pytensor.tensor.rewriting.basic import (
from pytensor.tensor.shape import (
Shape,
SpecifyShape,
Unbroadcast,
shape_padleft,
shape_tuple,
specify_shape,
unbroadcast,
)
from pytensor.tensor.sharedvar import TensorSharedVariable
from pytensor.tensor.subtensor import (
......@@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node):
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
Unbroadcast(x)[idx] => Unbroadcast(x[idx])
"""
if isinstance(node.op, Subtensor):
......@@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node):
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret]
if isinstance(u.owner.op, Unbroadcast):
# Subtensor might reduce dim., adapt broadcast pattern accordingly
old_axes = u.owner.op.axes
new_axes = []
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
j = 0
for i, x in enumerate(node.op.idx_list):
# if it is not a slice, it will reduce the dimension, should
# not appear in the broadcastable dimensions
if isinstance(x, slice):
if i in old_axes:
new_axes.append(j)
j += 1
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
for i in range(len(node.op.idx_list), len(u.broadcastable)):
if i in old_axes:
new_axes.append(j)
j += 1
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
# Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], subt_x)
rbcast_subt_x = unbroadcast(subt_x, *new_axes)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
return [rbcast_subt_x]
@register_canonicalize
@register_specialize
......
......@@ -18,7 +18,6 @@ from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as ptb
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst, NoneTypeT
......@@ -1008,118 +1007,3 @@ def specify_broadcastable(x, *axes):
axes = normalize_axis_tuple(axes, x.type.ndim)
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
return specify_shape(x, shape_info)
class Unbroadcast(COp):
"""
Mask static broadcastable dimensions of input as `None`
See Also
--------
unbroadcast <pytensor.tensor.shape.unbroadcast>
Examples
--------
``Unbroadcast(1)(x)`` would make `x`'s second static dimension `None`
"""
view_map = {0: [0]}
_f16_ok = True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version: dict = {}
check_input = False
__props__ = ("axes",)
_f16_ok = True
def __init__(self, *axis):
# Sort them to make sure we merge all possible cases.
items = tuple(sorted(axis))
self.axes = items
for axis in self.axes:
if not isinstance(axis, np.integer | int):
raise TypeError(f"Unbroadcast needs integer axes. Got {axis}")
def __str__(self):
return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}"
def make_node(self, x):
x = as_tensor_variable(x)
if x.type.ndim <= max(self.axes):
raise ValueError("Trying to unbroadcast of non-existent dimension")
shape = [
None if (sh == 1 and i in self.axes) else sh
for i, sh in enumerate(x.type.shape)
]
return Apply(self, [x], [x.type.clone(shape=shape)()])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
out[0] = x
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
# restore the broadcasting pattern of the input
return [specify_shape(gz, x.type.shape)]
def infer_shape(self, fgraph, node, ishapes):
assert len(ishapes) == 1
return [tuple(ishapes[0])]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, return_list=True)
def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
return f"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
def c_code_cache_version(self):
return (3,)
def unbroadcast(x, *axes):
"""
Mask static broadcastable dimensions of input as `None`
Parameters
----------
x : tensor_like
Input pytensor tensor.
axes : an int or an iterable object such as a list or tuple of int values
The broadcastable dimensions of x that should be unbroadcasted.
Returns
-------
tensor
A pytensor tensor, with static broadcastable dimensions masked as `None`
"""
x = as_tensor_variable(x)
unbroadcasted_axes = [axis for axis in axes if x.type.shape[axis] == 1]
if not unbroadcasted_axes:
return x
return Unbroadcast(*unbroadcasted_axes)(x)
@_vectorize_node.register(Unbroadcast)
def _vectorize_unbroadcast(
op: Unbroadcast, node: Apply, batch_x: TensorVariable
) -> Apply:
core_ndim = node.inputs[0].type.ndim
batch_ndim = batch_x.type.ndim - core_ndim
batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim)
return cast(Apply, unbroadcast(batch_x, *batch_axes).owner)
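For anyone migrating off the removed API: at runtime the Op was an identity view, and only the static type information changed. Its effect, sketched with type shapes as the deleted docstring and `unbroadcast` helper describe (pre-removal API, no longer available):

# Given x with x.type.shape == (1, 1, None):
# unbroadcast(x, 0).type.shape    == (None, 1, None)   # size-1 axis masked to None
# unbroadcast(x, 0, 2).type.shape == (None, 1, None)   # non-size-1 axes were ignored
# unbroadcast(x, 0, 1).type.shape == (None, None, None)
# The gradient restored the original pattern via specify_shape(gz, x.type.shape).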
......@@ -4,7 +4,7 @@ import pytest
import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.shape import Shape, Shape_i, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
......@@ -70,10 +70,6 @@ def test_jax_compile_ops():
compare_jax_and_py([], [x], [])
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
compare_jax_and_py([], [x], [])
x = ViewOp()(pt.as_tensor_variable(x_np))
compare_jax_and_py([], [x], [])
......@@ -7,7 +7,6 @@ import pytensor.tensor.basic as ptb
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import (
compare_numba_and_py,
compare_shape_dtype,
......@@ -75,16 +74,6 @@ def test_ScalarFromTensor():
)
def test_Unbroadcast():
v, v_test = pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX)
g = Unbroadcast(0)(v)
compare_numba_and_py(
[v],
g,
[v_test],
)
@pytest.mark.parametrize(
"vals, dtype",
[
......
......@@ -2,7 +2,7 @@ import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
from pytensor.tensor.shape import Shape, Shape_i, reshape
from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -50,10 +50,3 @@ def test_pytorch_Reshape_dynamic():
compare_pytorch_and_py(
[a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
)
def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
compare_pytorch_and_py([], [x], [])
......@@ -36,32 +36,31 @@ def test_debugprint_sitsot():
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ Second [id L]
│ │ │ │ │ │ ├─ A [id M]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ 1.0 [id O]
│ │ │ │ │ └─ 0 [id P]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id J]
│ │ │ │ │ │ └─ Second [id K]
│ │ │ │ │ │ ├─ A [id L]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ └─ 1.0 [id N]
│ │ │ │ │ └─ 0 [id O]
│ │ │ │ └─ Subtensor{i} [id P]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id R]
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ 1 [id Q]
│ │ │ ├─ ExpandDims{axis=0} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id S]
│ │ │ └─ ScalarFromTensor [id R]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ A [id M] (outer_in_non_seqs-0)
│ └─ 1 [id T]
└─ -1 [id U]
│ │ └─ A [id L] (outer_in_non_seqs-0)
│ └─ 1 [id S]
└─ -1 [id T]
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id V] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M] (inner_in_non_seqs-0)
← Mul [id U] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id V] -> [id E] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id W] -> [id L] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -94,32 +93,31 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ Second [id L]
│ │ │ │ │ │ ├─ A [id M]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ 1.0 [id O]
│ │ │ │ │ └─ 0 [id P]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id J]
│ │ │ │ │ │ └─ Second [id K]
│ │ │ │ │ │ ├─ A [id L]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ └─ 1.0 [id N]
│ │ │ │ │ └─ 0 [id O]
│ │ │ │ └─ Subtensor{i} [id P]
│ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id R]
│ │ │ ├─ Unbroadcast{0} [id J]
│ │ │ │ └─ 1 [id Q]
│ │ │ ├─ ExpandDims{axis=0} [id J]
│ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id S]
│ │ │ └─ ScalarFromTensor [id R]
│ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ···
│ │ └─ A [id M]
│ └─ 1 [id T]
└─ -1 [id U]
│ │ └─ A [id L]
│ └─ 1 [id S]
└─ -1 [id T]
Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id V]
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E]
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M]
← Mul [id U]
├─ *0-<Vector(float64, shape=(?,))> [id V] -> [id E]
└─ *1-<Vector(float64, shape=(?,))> [id W] -> [id L]
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -278,32 +276,31 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
│ │ │ │ │ │ │ └─ Second [id BN]
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ 1.0 [id BQ]
│ │ │ │ │ │ └─ 0 [id BR]
│ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BL]
│ │ │ │ │ │ │ └─ Second [id BM]
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BN] -> [id W] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
│ │ │ │ │ │ │ └─ 1.0 [id BP]
│ │ │ │ │ │ └─ 0 [id BQ]
│ │ │ │ │ └─ Subtensor{i} [id BR]
│ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BT]
│ │ │ │ ├─ Unbroadcast{0} [id BL]
│ │ │ │ │ └─ 1 [id BS]
│ │ │ │ ├─ ExpandDims{axis=0} [id BL]
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BU]
│ │ │ │ └─ ScalarFromTensor [id BT]
│ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BV]
│ └─ -1 [id BW]
└─ ExpandDims{axis=0} [id BX]
└─ *1-<Scalar(int64, shape=())> [id BY] -> [id U] (inner_in_seqs-1)
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BN] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BU]
│ └─ -1 [id BV]
└─ ExpandDims{axis=0} [id BW]
└─ *1-<Scalar(int64, shape=())> [id BX] -> [id U] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
← Mul [id BZ] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BO] (inner_in_non_seqs-0)
← Mul [id BY] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BN] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -375,34 +372,33 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
│ │ │ │ │ │ │ └─ Second [id BP]
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
│ │ │ │ │ │ │ └─ 1.0 [id BR]
│ │ │ │ │ │ └─ 0 [id BS]
│ │ │ │ │ └─ Subtensor{i} [id BT]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BN]
│ │ │ │ │ │ │ └─ Second [id BO]
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ 1.0 [id BQ]
│ │ │ │ │ │ └─ 0 [id BR]
│ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BU]
│ │ │ │ ├─ Unbroadcast{0} [id BN]
│ │ │ │ │ └─ 1 [id BT]
│ │ │ │ ├─ ExpandDims{axis=0} [id BN]
│ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BV]
│ │ │ │ └─ ScalarFromTensor [id BU]
│ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BW]
│ └─ -1 [id BX]
└─ ExpandDims{axis=0} [id BY]
│ │ └─ 1 [id BV]
│ └─ -1 [id BW]
└─ ExpandDims{axis=0} [id BX]
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
→ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BI] (inner_in_sit_sot-0)
→ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CB] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id BZ] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CA] (inner_in_non_seqs-0)
→ *0-<Vector(float64, shape=(?,))> [id BY] -> [id BI] (inner_in_sit_sot-0)
→ *1-<Vector(float64, shape=(?,))> [id BZ] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CA] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id BY] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id BZ] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......@@ -516,105 +512,104 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ ├─ k [id G]
│ │ │ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ │ │ └─ Second [id O]
│ │ │ │ │ │ │ │ ├─ A [id P]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
│ │ │ │ │ │ │ │ └─ 1.0 [id R]
│ │ │ │ │ │ │ └─ 0 [id S]
│ │ │ │ │ │ └─ Subtensor{i} [id T]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ │ │ └─ Second [id N]
│ │ │ │ │ │ │ │ ├─ A [id O]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id P]
│ │ │ │ │ │ │ │ └─ 1.0 [id Q]
│ │ │ │ │ │ │ └─ 0 [id R]
│ │ │ │ │ │ └─ Subtensor{i} [id S]
│ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id U]
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
│ │ │ │ │ │ └─ 1 [id T]
│ │ │ │ │ ├─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ScalarFromTensor [id V]
│ │ │ │ │ └─ ScalarFromTensor [id U]
│ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ └─ ···
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
│ │ │ └─ 0 [id W]
│ │ └─ 1 [id X]
│ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id Z]
│ │ │ ├─ Subtensor{:stop} [id BA]
│ │ │ │ └─ A [id O] (outer_in_non_seqs-0)
│ │ │ └─ 0 [id V]
│ │ └─ 1 [id W]
│ ├─ Subtensor{:stop} [id X] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id Y]
│ │ │ ├─ Subtensor{:stop} [id Z]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BB]
│ │ │ └─ -1 [id BC]
│ │ └─ ScalarFromTensor [id BD]
│ │ │ │ └─ -1 [id BA]
│ │ │ └─ -1 [id BB]
│ │ └─ ScalarFromTensor [id BC]
│ │ └─ Sub [id C]
│ │ └─ ···
│ ├─ Subtensor{:stop} [id BE] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BF]
│ │ │ ├─ Subtensor{::step} [id BG]
│ ├─ Subtensor{:stop} [id BD] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BE]
│ │ │ ├─ Subtensor{::step} [id BF]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BH]
│ │ │ └─ -1 [id BI]
│ │ └─ ScalarFromTensor [id BJ]
│ │ │ │ └─ -1 [id BG]
│ │ │ └─ -1 [id BH]
│ │ └─ ScalarFromTensor [id BI]
│ │ └─ Sub [id C]
│ │ └─ ···
│ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BL]
│ │ │ ├─ Second [id BM]
│ ├─ Subtensor{::step} [id BJ] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BK]
│ │ │ ├─ Second [id BL]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN]
│ │ │ │ └─ 0.0 [id BO]
│ │ │ ├─ IncSubtensor{i} [id BP]
│ │ │ │ ├─ Second [id BQ]
│ │ │ │ │ ├─ Subtensor{start:} [id BR]
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BM]
│ │ │ │ └─ 0.0 [id BN]
│ │ │ ├─ IncSubtensor{i} [id BO]
│ │ │ │ ├─ Second [id BP]
│ │ │ │ │ ├─ Subtensor{start:} [id BQ]
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id BS]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT]
│ │ │ │ │ └─ 0.0 [id BU]
│ │ │ │ ├─ Second [id BV]
│ │ │ │ │ ├─ Subtensor{i} [id BW]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BR]
│ │ │ │ │ │ └─ 1 [id BR]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BS]
│ │ │ │ │ └─ 0.0 [id BT]
│ │ │ │ ├─ Second [id BU]
│ │ │ │ │ ├─ Subtensor{i} [id BV]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BQ]
│ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ -1 [id BX]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BY]
│ │ │ │ │ └─ Second [id BZ]
│ │ │ │ │ ├─ Sum{axes=None} [id CA]
│ │ │ │ │ │ └─ Subtensor{i} [id BW]
│ │ │ │ │ │ └─ -1 [id BW]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BX]
│ │ │ │ │ └─ Second [id BY]
│ │ │ │ │ ├─ Sum{axes=None} [id BZ]
│ │ │ │ │ │ └─ Subtensor{i} [id BV]
│ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1.0 [id CB]
│ │ │ │ └─ -1 [id BX]
│ │ │ └─ 1 [id BS]
│ │ └─ -1 [id CC]
│ ├─ Alloc [id CD] (outer_in_sit_sot-0)
│ │ ├─ 0.0 [id CE]
│ │ ├─ Add [id CF]
│ │ │ │ │ └─ 1.0 [id CA]
│ │ │ │ └─ -1 [id BW]
│ │ │ └─ 1 [id BR]
│ │ └─ -1 [id CB]
│ ├─ Alloc [id CC] (outer_in_sit_sot-0)
│ │ ├─ 0.0 [id CD]
│ │ ├─ Add [id CE]
│ │ │ ├─ Sub [id C]
│ │ │ │ └─ ···
│ │ │ └─ 1 [id CG]
│ │ └─ Subtensor{i} [id CH]
│ │ ├─ Shape [id CI]
│ │ │ └─ A [id P]
│ │ └─ 0 [id CJ]
│ └─ A [id P] (outer_in_non_seqs-0)
└─ -1 [id CK]
│ │ │ └─ 1 [id CF]
│ │ └─ Subtensor{i} [id CG]
│ │ ├─ Shape [id CH]
│ │ │ └─ A [id O]
│ │ └─ 0 [id CI]
│ └─ A [id O] (outer_in_non_seqs-0)
└─ -1 [id CJ]
Inner graphs:
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
← Add [id CL] (inner_out_mit_mot-0-0)
├─ Mul [id CM]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
│ └─ *5-<Vector(float64, shape=(?,))> [id CO] -> [id P] (inner_in_non_seqs-0)
└─ *3-<Vector(float64, shape=(?,))> [id CP] -> [id BK] (inner_in_mit_mot-0-1)
← Add [id CQ] (inner_out_sit_sot-0)
├─ Mul [id CR]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
│ └─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id Y] (inner_in_seqs-0)
└─ *4-<Vector(float64, shape=(?,))> [id CT] -> [id CD] (inner_in_sit_sot-0)
← Add [id CK] (inner_out_mit_mot-0-0)
├─ Mul [id CL]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
│ └─ *5-<Vector(float64, shape=(?,))> [id CN] -> [id O] (inner_in_non_seqs-0)
└─ *3-<Vector(float64, shape=(?,))> [id CO] -> [id BJ] (inner_in_mit_mot-0-1)
← Add [id CP] (inner_out_sit_sot-0)
├─ Mul [id CQ]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
│ └─ *0-<Vector(float64, shape=(?,))> [id CR] -> [id X] (inner_in_seqs-0)
└─ *4-<Vector(float64, shape=(?,))> [id CS] -> [id CC] (inner_in_sit_sot-0)
Scan{scan_fn, while_loop=False, inplace=none} [id F]
← Mul [id CU] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id H] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CV] -> [id P] (inner_in_non_seqs-0)
← Mul [id CT] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CR] -> [id H] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CU] -> [id O] (inner_in_non_seqs-0)
"""
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......
......@@ -1621,7 +1621,7 @@ class TestSaveMem:
np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
with pytest.raises(AssertionError, match="n_steps > 0"):
with pytest.raises((AssertionError, IndexError)):
f(x0=0, seq=test_seq, n_steps=0)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
......
......@@ -77,9 +77,7 @@ from pytensor.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
Unbroadcast,
specify_shape,
unbroadcast,
)
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor1,
......@@ -558,48 +556,6 @@ class TestTile:
f(data)
class TestUnbroadcast:
def setup_method(self):
self.mode = get_default_mode().including("canonicalize")
def test_local_useless_unbroadcast(self):
x1 = tensor(dtype="float64", shape=(1, 2))
x2 = tensor(dtype="float64", shape=(2, 1))
unbroadcast_op = Unbroadcast(0)
f = function([x1], unbroadcast_op(x1), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
f = function([x2], unbroadcast_op(x2), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 0
)
def test_local_unbroadcast_lift(self):
x = tensor(dtype="float64", shape=(1, 1))
y = unbroadcast(pt.exp(unbroadcast(x, 0)), 1)
assert (
sum(
isinstance(node.op, Unbroadcast)
for node in FunctionGraph([x], [y], copy_inputs=False).toposort()
)
== 2
)
f = function([x], y, mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
np.testing.assert_almost_equal(f([[1]]), np.exp([[1]]))
class TestUselessElemwise:
def setup_method(self):
self.mode = get_default_mode().including("canonicalize", "local_fill_to_alloc")
......
......@@ -28,7 +28,6 @@ from pytensor.tensor.rewriting.subtensor import (
)
from pytensor.tensor.shape import (
SpecifyShape,
Unbroadcast,
_shape,
shape,
specify_shape,
......@@ -55,7 +54,6 @@ from pytensor.tensor.type import (
lscalar,
lscalars,
matrix,
row,
scalar,
tensor,
tensor3,
......@@ -921,64 +919,6 @@ class TestLocalSubtensorLift:
assert len(prog) == 2
f([1, 2, 3], 4) # let debugmode test something
def test_basic_8(self):
# Test that Subtensor(Unbroadcast(x)) gets optimized into
# Unbroadcast(Subtensor(x)).
# test basic case
x = row("x")
xval = np.random.random((1, 10)).astype(config.floatX)
assert x.broadcastable == (True, False)
newx = Unbroadcast(0)(x)
assert newx.broadcastable == (False, False)
f1 = function([x], newx[:2, :5], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast])
prog = f1.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f1(xval) == xval[:2, :5]).all()
# corner case 1: Unbroadcast changes dims which are dropped through subtensor
y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x")
yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
assert y.broadcastable == (True, False, True, False)
newy = Unbroadcast(0, 2)(y)
assert newy.broadcastable == (False, False, False, False)
f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast])
prog = f2.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f2(yval) == yval[:, 3, 0, :]).all()
# corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
f3 = function([y], newy[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast])
prog = f3.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f3(yval) == yval[:, 3, 0]).all()
# corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x")
zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
assert z.broadcastable == (False, False, False, True)
newz = Unbroadcast(3)(z)
assert newz.broadcastable == (False, False, False, False)
f4 = function([z], newz[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast])
prog = f4.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f4(zval) == zval[:, 3, 0]).all()
class TestLocalSubtensorMerge:
def setup_method(self):
......
......@@ -287,7 +287,7 @@ TestAlloc13GradBroadcast = makeBroadcastTester(
),
)
# unbroadcast a row to a matrix
# broadcast a row to a matrix
TestAllocb1GradBroadcast = makeBroadcastTester(
name="Allocb1GradTester",
op=lambda x: alloc(x, s1, s2),
......@@ -299,7 +299,7 @@ TestAllocb1GradBroadcast = makeBroadcastTester(
),
)
# unbroadcast a row to a tensor3
# broadcast a row to a tensor3
TestAllocb2GradBroadcast = makeBroadcastTester(
name="Allocb2GradTester",
op=lambda x: alloc(x, s1, s2, s3),
......@@ -311,7 +311,7 @@ TestAllocb2GradBroadcast = makeBroadcastTester(
),
)
# unbroadcast a col to a matrix
# broadcast a col to a matrix
TestAllocb3GradBroadcast = makeBroadcastTester(
name="Allocb3GradTester",
op=lambda x: alloc(x, s1, s2),
......@@ -323,7 +323,7 @@ TestAllocb3GradBroadcast = makeBroadcastTester(
),
)
# unbroadcast a col to a tensor3
# broadcast a col to a tensor3
TestAllocb4GradBroadcast = makeBroadcastTester(
name="Allocb4GradTester",
op=lambda x: alloc(x, s1, s2, s3),
......@@ -336,7 +336,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
)
# Partial unbroadcast of a dimshuffled input
# Partial broadcast of a dimshuffled input
TestAllocDimshuffleGradBroadcast = makeBroadcastTester(
name="Allocb4GradTester",
op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3),
......
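These testers exercise the gradient of `alloc` when a row or column is broadcast up to a larger shape; the comment fixes above only correct the direction of the wording. A tiny sketch of the pattern being tested (literal sizes stand in for the symbolic `s1`, `s2`):

import pytensor.tensor as pt

x = pt.row("x")        # static shape (1, None)
y = pt.alloc(x, 5, 3)  # broadcast the row up to shape (5, 3)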
......@@ -19,14 +19,12 @@ from pytensor.tensor.shape import (
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
_specify_shape,
reshape,
shape,
shape_tuple,
specify_broadcastable,
specify_shape,
unbroadcast,
)
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
......@@ -696,66 +694,6 @@ def test_get_vector_length():
assert get_vector_length(x) == 10
class TestUnbroadcast:
def test_basic(self):
x = matrix()
assert unbroadcast(x, 0) is x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
def test_infer_shape(self):
x = matrix()
y = unbroadcast(x, 0)
f = pytensor.function([x], y.shape)
assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, MakeVector)
x = row()
y = unbroadcast(x, 0)
f = pytensor.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
def test_error_checks(self):
with pytest.raises(TypeError, match="needs integer axes"):
Unbroadcast(0.0)
with pytest.raises(ValueError, match="^Trying to unbroadcast"):
Unbroadcast(1)(vector())
class TestUnbroadcastInferShape(utt.InferShapeTester):
def test_basic(self):
rng = np.random.default_rng(3453)
adtens4 = tensor(dtype="float64", shape=(1, 1, 1, None))
adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Unbroadcast(0, 2)(adtens4)],
[adtens4_val],
Unbroadcast,
warn=False,
)
def test_shape_tuple():
x = Variable(MyType2(), None, None)
assert shape_tuple(x) == ()
......@@ -882,16 +820,3 @@ class TestVectorize:
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
):
vectorize_node(node, tns, *(5, 3, 2, x))
def test_unbroadcast(self):
mat = tensor(
shape=(
1,
1,
)
)
tns = tensor(shape=(4, 1, 1, 1))
node = unbroadcast(mat, 0).owner
vect_node = vectorize_node(node, tns)
assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])
......@@ -28,7 +28,6 @@ from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.math import argmax, dot
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.shape import unbroadcast
from pytensor.tensor.type import matrix, vector
from tests import unittest_tools as utt
......@@ -252,13 +251,6 @@ class TestRopLop(RopLopChecker):
# vector
self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,))
def test_unbroadcast(self):
# I need the sum, because the setup expects the output to be a
# vector
self.check_rop_lop(
unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,)
)
def test_join(self):
tv = np.asarray(self.rng.uniform(size=(10,)), pytensor.config.floatX)
t = pytensor.shared(tv)
......