提交 42a7adb9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove Unbroadcast Op

上级 a24f5345
...@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`. ...@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padleft(x, n_ones=1) .. function:: shape_padleft(x, n_ones=1)
Reshape `x` by left padding the shape with `n_ones` 1s. Note that all Reshape `x` by left padding the shape with `n_ones` 1s.
this new dimension will be broadcastable. To make them non-broadcastable All new dimensions will be broadcastable.
see the :func:`unbroadcast`.
:param x: variable to be reshaped :param x: variable to be reshaped
:type x: any `TensorVariable` (or compatible) :type x: any `TensorVariable` (or compatible)
...@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`. ...@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padright(x, n_ones=1) .. function:: shape_padright(x, n_ones=1)
Reshape `x` by right padding the shape with `n_ones` ones. Note that all Reshape `x` by right padding the shape with `n_ones` ones.
this new dimension will be broadcastable. To make them non-broadcastable All new dimensions will be broadcastable.
see the :func:`unbroadcast`.
:param x: variable to be reshaped :param x: variable to be reshaped
:type x: any TensorVariable (or compatible) :type x: any TensorVariable (or compatible)
...@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`. ...@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padaxis(t, axis) .. function:: shape_padaxis(t, axis)
Reshape `t` by inserting ``1`` at the dimension `axis`. Note that this new Reshape `t` by inserting ``1`` at the dimension `axis`.
dimension will be broadcastable. To make it non-broadcastable All new dimensions will be broadcastable.
see the :func:`unbroadcast`.
:type x: any `TensorVariable` (or compatible) :type x: any `TensorVariable` (or compatible)
:param x: variable to be reshaped :param x: variable to be reshaped
......
...@@ -292,14 +292,8 @@ def rebuild_collect_shared( ...@@ -292,14 +292,8 @@ def rebuild_collect_shared(
f" shared_var.type={store_into.type}," f" shared_var.type={store_into.type},"
f" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})." f" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
) )
err_sug = (
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)
raise TypeError(err_msg, err_sug) raise TypeError(err_msg)
assert store_into.type.is_super(update_val.type) assert store_into.type.is_super(update_val.type)
update_d[store_into] = update_val update_d[store_into] = update_val
......
...@@ -26,7 +26,7 @@ from pytensor.graph.op import _NoPythonOp ...@@ -26,7 +26,7 @@ from pytensor.graph.op import _NoPythonOp
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.type import HasDataType, HasShape from pytensor.graph.type import HasDataType, HasShape
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast from pytensor.tensor.shape import Reshape, Shape, SpecifyShape
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -481,7 +481,6 @@ acceptable_ops = ( ...@@ -481,7 +481,6 @@ acceptable_ops = (
Shape, Shape,
SpecifyShape, SpecifyShape,
Reshape, Reshape,
Unbroadcast,
pt.math.Dot, pt.math.Dot,
pt.math.Max, pt.math.Max,
pt.math.Argmax, pt.math.Argmax,
......
...@@ -4,7 +4,7 @@ from pytensor.graph import Constant ...@@ -4,7 +4,7 @@ from pytensor.graph import Constant
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
...@@ -104,11 +104,3 @@ def jax_funcify_SpecifyShape(op, node, **kwargs): ...@@ -104,11 +104,3 @@ def jax_funcify_SpecifyShape(op, node, **kwargs):
return x return x
return specifyshape return specifyshape
@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return unbroadcast
...@@ -17,7 +17,6 @@ from pytensor.tensor.basic import ( ...@@ -17,7 +17,6 @@ from pytensor.tensor.basic import (
Split, Split,
TensorFromScalar, TensorFromScalar,
) )
from pytensor.tensor.shape import Unbroadcast
@numba_funcify.register(AllocEmpty) @numba_funcify.register(AllocEmpty)
...@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}): ...@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn) return numba_basic.numba_njit(makevector_fn)
@numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
@numba_basic.numba_njit
def unbroadcast(x):
return x
return unbroadcast
@numba_funcify.register(TensorFromScalar) @numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs): def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
......
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@pytorch_funcify.register(Reshape) @pytorch_funcify.register(Reshape)
...@@ -56,11 +56,3 @@ def pytorch_funcify_SpecifyShape(op, node, **kwargs): ...@@ -56,11 +56,3 @@ def pytorch_funcify_SpecifyShape(op, node, **kwargs):
return x return x
return specifyshape return specifyshape
@pytorch_funcify.register(Unbroadcast)
def pytorch_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return unbroadcast
...@@ -15,7 +15,7 @@ from pytensor.scan.utils import expand_empty, safe_new, until ...@@ -15,7 +15,7 @@ from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft, unbroadcast from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.updates import OrderedUpdates from pytensor.updates import OrderedUpdates
...@@ -748,7 +748,7 @@ def scan( ...@@ -748,7 +748,7 @@ def scan(
# defined in scan utils # defined in scan utils
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
expand_empty( expand_empty(
unbroadcast(shape_padleft(actual_arg), 0), shape_padleft(actual_arg),
actual_n_steps, actual_n_steps,
) )
) )
...@@ -865,13 +865,13 @@ def scan( ...@@ -865,13 +865,13 @@ def scan(
if n_fixed_steps in (1, -1): if n_fixed_steps in (1, -1):
for pos, inner_out in enumerate(outputs): for pos, inner_out in enumerate(outputs):
# we need to see if we need to pad our sequences with an # we need to see if we need to pad our sequences with an
# unbroadcastable dimension; case example : we return an # extra dimension; case example : we return an
# output for which we want all intermediate. If n_steps is 1 # output for which we want all intermediate. If n_steps is 1
# then, if we return the output as given by the innner function # then, if we return the output as given by the innner function
# this will represent only a slice and it will have one # this will represent only a slice and it will have one
# dimension less. # dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = unbroadcast(shape_padleft(inner_out), 0) outputs[pos] = shape_padleft(inner_out)
if not return_list and len(outputs) == 1: if not return_list and len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
...@@ -1002,7 +1002,7 @@ def scan( ...@@ -1002,7 +1002,7 @@ def scan(
sit_sot_inner_inputs.append(new_var) sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
expand_empty( expand_empty(
unbroadcast(shape_padleft(input.variable), 0), shape_padleft(input.variable),
actual_n_steps, actual_n_steps,
) )
) )
......
...@@ -166,8 +166,7 @@ def check_broadcast(v1, v2): ...@@ -166,8 +166,7 @@ def check_broadcast(v1, v2):
"axis %d in `output_info`. This can happen if one of the " "axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still " "dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make " "variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using pytensor.tensor." "them consistent, e.g. using pytensor.tensor.specify_broadcastable."
"{unbroadcast, specify_broadcastable}."
) )
size = min(v1.type.ndim, v2.type.ndim) size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate( for n, (b1, b2) in enumerate(
......
...@@ -53,7 +53,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError ...@@ -53,7 +53,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
Shape_i, Shape_i,
Unbroadcast,
shape, shape,
shape_padaxis, shape_padaxis,
shape_padleft, shape_padleft,
...@@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value( ...@@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value(
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
op = v.owner.op op = v.owner.op
max_recur -= 1 max_recur -= 1
if isinstance( if isinstance(op, Alloc | DimShuffle | OutputGuard | DeepCopyOp):
op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
):
# OutputGuard is only used in debugmode but we # OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles # keep it here to avoid problems with old pickles
v = v.owner.inputs[0] v = v.owner.inputs[0]
...@@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value( ...@@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value(
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
gp_shape = grandparent.type.shape gp_shape = grandparent.type.shape
ndim = grandparent.type.ndim ndim = grandparent.type.ndim
if grandparent.owner and isinstance(
grandparent.owner.op, Unbroadcast
):
ggp_shape = grandparent.owner.inputs[0].type.shape
l = [
_get_underlying_scalar_constant_value(s) for s in ggp_shape
]
gp_shape = tuple(l)
if not (idx < ndim): if not (idx < ndim):
msg = ( msg = (
......
...@@ -42,9 +42,7 @@ from pytensor.tensor.shape import ( ...@@ -42,9 +42,7 @@ from pytensor.tensor.shape import (
Shape, Shape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast,
specify_shape, specify_shape,
unbroadcast,
) )
from pytensor.tensor.subtensor import Subtensor, get_idx_list from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
...@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node): ...@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node):
# structure. # structure.
replacement = shape_feature.scheduled[node] replacement = shape_feature.scheduled[node]
return [shape_feature.shape_of[replacement][node.op.i]] return [shape_feature.shape_of[replacement][node.op.i]]
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
if isinstance(node.op, Unbroadcast):
x = node.inputs[0]
if x.type.ndim == node.outputs[0].type.ndim and all(
s1 == s2
for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True)
if s1 == 1 or s2 == 1
):
# No broadcastable flag was modified
# No need to copy over stack trace,
# because x should already have a stack trace.
return [x]
else:
# Keep the flags that modify something
new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
if new_axes == node.op.axes:
# All flags are useful
return None
else:
r = unbroadcast(x, *new_axes)
# Copy over stacktrace from previous output
copy_stack_trace(node.outputs, r)
return [r]
@register_canonicalize
@register_specialize
@node_rewriter([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
"""
Lifts `Unbroadcast` through unary Elemwise operations,
and merges consecutive `Unbroadcast`s.
Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
TODO: Implement equivalent Elemwise lift for SpecifyShape
"""
op = node.op
if not isinstance(op, Unbroadcast):
return False
inp = node.inputs[0]
inode = inp.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
if len(fgraph.clients.get(inp, ())) == 1:
unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
copy_stack_trace(node.outputs, unbroadcasted)
rval = inode.op.make_node(unbroadcasted).outputs
# Copy over stacktrace from previous output (after unbroadcasting)
# and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the
# two ops.
copy_stack_trace(node.outputs + node.inputs, rval)
return rval
if inode and isinstance(inode.op, Unbroadcast):
# Merge axis of each unbroadcast
axis = tuple(set(inode.op.axes).union(set(op.axes)))
iinput = inode.inputs[0]
rval = [unbroadcast(iinput, *axis)]
# Copy over stacktrace from previous output (after second unbroadcasting)
# and from previous input (after first unbroadcasting) because an error in
# the new graph could have been caused by either of the two Unbroadcast ops.
copy_stack_trace(node.outputs + node.inputs, rval)
return rval
...@@ -59,11 +59,9 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -59,11 +59,9 @@ from pytensor.tensor.rewriting.basic import (
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
Unbroadcast,
shape_padleft, shape_padleft,
shape_tuple, shape_tuple,
specify_shape, specify_shape,
unbroadcast,
) )
from pytensor.tensor.sharedvar import TensorSharedVariable from pytensor.tensor.sharedvar import TensorSharedVariable
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
...@@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node): ...@@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node):
Handles the following unary ops: Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...) elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all when x,... are broadcasted scalar or not broadcasted at all
Unbroadcast(x)[idx] => Unbroadcast(x[idx])
""" """
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
...@@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node): ...@@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node):
copy_stack_trace([node.outputs[0], node.inputs[0]], ret) copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret] return [ret]
if isinstance(u.owner.op, Unbroadcast):
# Subtensor might reduce dim., adapt broadcast pattern accordingly
old_axes = u.owner.op.axes
new_axes = []
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
j = 0
for i, x in enumerate(node.op.idx_list):
# if it is not a slice, it will reduce the dimension, should
# not appear in the broascastable dimensions
if isinstance(x, slice):
if i in old_axes:
new_axes.append(j)
j += 1
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
for i in range(len(node.op.idx_list), len(u.broadcastable)):
if i in old_axes:
new_axes.append(j)
j += 1
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
# Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], subt_x)
rbcast_subt_x = unbroadcast(subt_x, *new_axes)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
return [rbcast_subt_x]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
......
...@@ -18,7 +18,6 @@ from pytensor.link.c.params_type import ParamsType ...@@ -18,7 +18,6 @@ from pytensor.link.c.params_type import ParamsType
from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor import basic as ptb from pytensor.tensor import basic as ptb
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst, NoneTypeT from pytensor.tensor.type_other import NoneConst, NoneTypeT
...@@ -1008,118 +1007,3 @@ def specify_broadcastable(x, *axes): ...@@ -1008,118 +1007,3 @@ def specify_broadcastable(x, *axes):
axes = normalize_axis_tuple(axes, x.type.ndim) axes = normalize_axis_tuple(axes, x.type.ndim)
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)] shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
return specify_shape(x, shape_info) return specify_shape(x, shape_info)
class Unbroadcast(COp):
"""
Mask static broadcastable dimensions of input as `None`
See Also
--------
unbroadcast <pytensor.tensor.shape.unbroadcast>
Examples
--------
``Unbroadcast((1,))(x)`` would make `x` second static dimension be `None`
"""
view_map = {0: [0]}
_f16_ok = True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version: dict = {}
check_input = False
__props__ = ("axes",)
_f16_ok = True
def __init__(self, *axis):
# Sort them to make sure we merge all possible case.
items = tuple(sorted(axis))
self.axes = items
for axis in self.axes:
if not isinstance(axis, np.integer | int):
raise TypeError(f"Unbroadcast needs integer axes. Got {axis}")
def __str__(self):
return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}"
def make_node(self, x):
x = as_tensor_variable(x)
if x.type.ndim <= max(self.axes):
raise ValueError("Trying to unbroadcast of non-existent dimension")
shape = [
None if (sh == 1 and i in self.axes) else sh
for i, sh in enumerate(x.type.shape)
]
return Apply(self, [x], [x.type.clone(shape=shape)()])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
out[0] = x
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
# restore the broadcasting pattern of the input
return [specify_shape(gz, x.type.shape)]
def infer_shape(self, fgraph, node, ishapes):
assert len(ishapes) == 1
return [tuple(ishapes[0])]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, return_list=True)
def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
return f"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
def c_code_cache_version(self):
return (3,)
def unbroadcast(x, *axes):
"""
Mask static broadcastable dimensions of input as `None`
Parameters
----------
x : tensor_like
Input pytensor tensor.
axis : an int or an iterable object such as list or tuple of int values
The broadcastable dimensions of x that should be unbroadcasted.
Returns
-------
tensor
A pytensor tensor, with static broadcastable dimensions masked as `None`
"""
x = as_tensor_variable(x)
unbroadcasted_axes = [axis for axis in axes if x.type.shape[axis] == 1]
if not unbroadcasted_axes:
return x
return Unbroadcast(*unbroadcasted_axes)(x)
@_vectorize_node.register(Unbroadcast)
def _vectorize_unbroadcast(
op: Unbroadcast, node: Apply, batch_x: TensorVariable
) -> Apply:
core_ndim = node.inputs[0].type.ndim
batch_ndim = batch_x.type.ndim - core_ndim
batch_axes = get_normalized_batch_axes(op.axes, core_ndim, batch_ndim)
return cast(Apply, unbroadcast(batch_x, *batch_axes).owner)
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.shape import Shape, Shape_i, reshape
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
...@@ -70,10 +70,6 @@ def test_jax_compile_ops(): ...@@ -70,10 +70,6 @@ def test_jax_compile_ops():
compare_jax_and_py([], [x], []) compare_jax_and_py([], [x], [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
compare_jax_and_py([], [x], [])
x = ViewOp()(pt.as_tensor_variable(x_np)) x = ViewOp()(pt.as_tensor_variable(x_np))
compare_jax_and_py([], [x], []) compare_jax_and_py([], [x], [])
...@@ -7,7 +7,6 @@ import pytensor.tensor.basic as ptb ...@@ -7,7 +7,6 @@ import pytensor.tensor.basic as ptb
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode from pytensor.compile import get_mode
from pytensor.scalar import Add from pytensor.scalar import Add
from pytensor.tensor.shape import Unbroadcast
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
compare_shape_dtype, compare_shape_dtype,
...@@ -75,16 +74,6 @@ def test_ScalarFromTensor(): ...@@ -75,16 +74,6 @@ def test_ScalarFromTensor():
) )
def test_Unbroadcast():
v, v_test = pt.row(), np.array([[1.0, 2.0]], dtype=config.floatX)
g = Unbroadcast(0)(v)
compare_numba_and_py(
[v],
g,
[v_test],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"vals, dtype", "vals, dtype",
[ [
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape from pytensor.tensor.shape import Shape, Shape_i, reshape
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py from tests.link.pytorch.test_basic import compare_pytorch_and_py
...@@ -50,10 +50,3 @@ def test_pytorch_Reshape_dynamic(): ...@@ -50,10 +50,3 @@ def test_pytorch_Reshape_dynamic():
compare_pytorch_and_py( compare_pytorch_and_py(
[a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
) )
def test_pytorch_unbroadcast():
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
compare_pytorch_and_py([], [x], [])
...@@ -36,32 +36,31 @@ def test_debugprint_sitsot(): ...@@ -36,32 +36,31 @@ def test_debugprint_sitsot():
│ │ │ │ │ ├─ k [id D] │ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{i} [id H] │ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I] │ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] │ │ │ │ │ │ └─ Second [id K]
│ │ │ │ │ │ └─ Second [id L] │ │ │ │ │ │ ├─ A [id L]
│ │ │ │ │ │ ├─ A [id M] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] │ │ │ │ │ │ └─ 1.0 [id N]
│ │ │ │ │ │ └─ 1.0 [id O] │ │ │ │ │ └─ 0 [id O]
│ │ │ │ │ └─ 0 [id P] │ │ │ │ └─ Subtensor{i} [id P]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id I] │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id R] │ │ │ │ └─ 1 [id Q]
│ │ │ ├─ Unbroadcast{0} [id J] │ │ │ ├─ ExpandDims{axis=0} [id J]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id S] │ │ │ └─ ScalarFromTensor [id R]
│ │ │ └─ Subtensor{i} [id H] │ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ··· │ │ │ └─ ···
│ │ └─ A [id M] (outer_in_non_seqs-0) │ │ └─ A [id L] (outer_in_non_seqs-0)
│ └─ 1 [id T] │ └─ 1 [id S]
└─ -1 [id U] └─ -1 [id T]
Inner graphs: Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id V] (inner_out_sit_sot-0) ← Mul [id U] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id V] -> [id E] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id W] -> [id L] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -94,32 +93,31 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -94,32 +93,31 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ ├─ k [id D] │ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{i} [id H] │ │ │ │ │ └─ Subtensor{i} [id H]
│ │ │ │ │ ├─ Shape [id I] │ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] │ │ │ │ │ │ └─ Second [id K]
│ │ │ │ │ │ └─ Second [id L] │ │ │ │ │ │ ├─ A [id L]
│ │ │ │ │ │ ├─ A [id M] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] │ │ │ │ │ │ └─ 1.0 [id N]
│ │ │ │ │ │ └─ 1.0 [id O] │ │ │ │ │ └─ 0 [id O]
│ │ │ │ │ └─ 0 [id P] │ │ │ │ └─ Subtensor{i} [id P]
│ │ │ │ └─ Subtensor{i} [id Q]
│ │ │ │ ├─ Shape [id I] │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ 1 [id R] │ │ │ │ └─ 1 [id Q]
│ │ │ ├─ Unbroadcast{0} [id J] │ │ │ ├─ ExpandDims{axis=0} [id J]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ ScalarFromTensor [id S] │ │ │ └─ ScalarFromTensor [id R]
│ │ │ └─ Subtensor{i} [id H] │ │ │ └─ Subtensor{i} [id H]
│ │ │ └─ ··· │ │ │ └─ ···
│ │ └─ A [id M] │ │ └─ A [id L]
│ └─ 1 [id T] │ └─ 1 [id S]
└─ -1 [id U] └─ -1 [id T]
Inner graphs: Inner graphs:
Scan{scan_fn, while_loop=False, inplace=none} [id C] Scan{scan_fn, while_loop=False, inplace=none} [id C]
← Mul [id V] ← Mul [id U]
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E] ├─ *0-<Vector(float64, shape=(?,))> [id V] -> [id E]
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M] └─ *1-<Vector(float64, shape=(?,))> [id W] -> [id L]
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -278,32 +276,31 @@ def test_debugprint_nested_scans(): ...@@ -278,32 +276,31 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1) │ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{i} [id BJ] │ │ │ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ │ │ ├─ Shape [id BK] │ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BL]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM] │ │ │ │ │ │ │ └─ Second [id BM]
│ │ │ │ │ │ │ └─ Second [id BN] │ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BN] -> [id W] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP] │ │ │ │ │ │ │ └─ 1.0 [id BP]
│ │ │ │ │ │ │ └─ 1.0 [id BQ] │ │ │ │ │ │ └─ 0 [id BQ]
│ │ │ │ │ │ └─ 0 [id BR] │ │ │ │ │ └─ Subtensor{i} [id BR]
│ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ ├─ Shape [id BK] │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BT] │ │ │ │ │ └─ 1 [id BS]
│ │ │ │ ├─ Unbroadcast{0} [id BL] │ │ │ │ ├─ ExpandDims{axis=0} [id BL]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BU] │ │ │ │ └─ ScalarFromTensor [id BT]
│ │ │ │ └─ Subtensor{i} [id BJ] │ │ │ │ └─ Subtensor{i} [id BJ]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) │ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BN] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BV] │ │ └─ 1 [id BU]
│ └─ -1 [id BW] │ └─ -1 [id BV]
└─ ExpandDims{axis=0} [id BX] └─ ExpandDims{axis=0} [id BW]
└─ *1-<Scalar(int64, shape=())> [id BY] -> [id U] (inner_in_seqs-1) └─ *1-<Scalar(int64, shape=())> [id BX] -> [id U] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BE] Scan{scan_fn, while_loop=False, inplace=none} [id BE]
← Mul [id BZ] (inner_out_sit_sot-0) ← Mul [id BY] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BG] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BO] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BN] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -375,34 +372,33 @@ def test_debugprint_nested_scans(): ...@@ -375,34 +372,33 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1) │ │ │ │ │ │ ├─ *3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{i} [id BL] │ │ │ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ │ │ ├─ Shape [id BM] │ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BN]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO] │ │ │ │ │ │ │ └─ Second [id BO]
│ │ │ │ │ │ │ └─ Second [id BP] │ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ ├─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ] │ │ │ │ │ │ │ └─ 1.0 [id BQ]
│ │ │ │ │ │ │ └─ 1.0 [id BR] │ │ │ │ │ │ └─ 0 [id BR]
│ │ │ │ │ │ └─ 0 [id BS] │ │ │ │ │ └─ Subtensor{i} [id BS]
│ │ │ │ │ └─ Subtensor{i} [id BT]
│ │ │ │ │ ├─ Shape [id BM] │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1 [id BU] │ │ │ │ │ └─ 1 [id BT]
│ │ │ │ ├─ Unbroadcast{0} [id BN] │ │ │ │ ├─ ExpandDims{axis=0} [id BN]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ScalarFromTensor [id BV] │ │ │ │ └─ ScalarFromTensor [id BU]
│ │ │ │ └─ Subtensor{i} [id BL] │ │ │ │ └─ Subtensor{i} [id BL]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) │ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ 1 [id BW] │ │ └─ 1 [id BV]
│ └─ -1 [id BX] │ └─ -1 [id BW]
└─ ExpandDims{axis=0} [id BY] └─ ExpandDims{axis=0} [id BX]
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1) └─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
Scan{scan_fn, while_loop=False, inplace=none} [id BH] Scan{scan_fn, while_loop=False, inplace=none} [id BH]
→ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BI] (inner_in_sit_sot-0) → *0-<Vector(float64, shape=(?,))> [id BY] -> [id BI] (inner_in_sit_sot-0)
→ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BA] (inner_in_non_seqs-0) → *1-<Vector(float64, shape=(?,))> [id BZ] -> [id BA] (inner_in_non_seqs-0)
← Mul [id CB] (inner_out_sit_sot-0) ← Mul [id CA] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id BZ] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id BY] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CA] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id BZ] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
...@@ -516,105 +512,104 @@ def test_debugprint_mitmot(): ...@@ -516,105 +512,104 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ ├─ k [id G] │ │ │ │ │ │ │ ├─ k [id G]
│ │ │ │ │ │ │ └─ Subtensor{i} [id K] │ │ │ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ │ │ ├─ Shape [id L] │ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] │ │ │ │ │ │ │ │ └─ Second [id N]
│ │ │ │ │ │ │ │ └─ Second [id O] │ │ │ │ │ │ │ │ ├─ A [id O]
│ │ │ │ │ │ │ │ ├─ A [id P] │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id P]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q] │ │ │ │ │ │ │ │ └─ 1.0 [id Q]
│ │ │ │ │ │ │ │ └─ 1.0 [id R] │ │ │ │ │ │ │ └─ 0 [id R]
│ │ │ │ │ │ │ └─ 0 [id S] │ │ │ │ │ │ └─ Subtensor{i} [id S]
│ │ │ │ │ │ └─ Subtensor{i} [id T]
│ │ │ │ │ │ ├─ Shape [id L] │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id U] │ │ │ │ │ │ └─ 1 [id T]
│ │ │ │ │ ├─ Unbroadcast{0} [id M] │ │ │ │ │ ├─ ExpandDims{axis=0} [id M]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ ScalarFromTensor [id V] │ │ │ │ │ └─ ScalarFromTensor [id U]
│ │ │ │ │ └─ Subtensor{i} [id K] │ │ │ │ │ └─ Subtensor{i} [id K]
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0) │ │ │ │ └─ A [id O] (outer_in_non_seqs-0)
│ │ │ └─ 0 [id W] │ │ │ └─ 0 [id V]
│ │ └─ 1 [id X] │ │ └─ 1 [id W]
│ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0) │ ├─ Subtensor{:stop} [id X] (outer_in_seqs-0)
│ │ ├─ Subtensor{::step} [id Z] │ │ ├─ Subtensor{::step} [id Y]
│ │ │ ├─ Subtensor{:stop} [id BA] │ │ │ ├─ Subtensor{:stop} [id Z]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BB] │ │ │ │ └─ -1 [id BA]
│ │ │ └─ -1 [id BC] │ │ │ └─ -1 [id BB]
│ │ └─ ScalarFromTensor [id BD] │ │ └─ ScalarFromTensor [id BC]
│ │ └─ Sub [id C] │ │ └─ Sub [id C]
│ │ └─ ··· │ │ └─ ···
│ ├─ Subtensor{:stop} [id BE] (outer_in_seqs-1) │ ├─ Subtensor{:stop} [id BD] (outer_in_seqs-1)
│ │ ├─ Subtensor{:stop} [id BF] │ │ ├─ Subtensor{:stop} [id BE]
│ │ │ ├─ Subtensor{::step} [id BG] │ │ │ ├─ Subtensor{::step} [id BF]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ -1 [id BH] │ │ │ │ └─ -1 [id BG]
│ │ │ └─ -1 [id BI] │ │ │ └─ -1 [id BH]
│ │ └─ ScalarFromTensor [id BJ] │ │ └─ ScalarFromTensor [id BI]
│ │ └─ Sub [id C] │ │ └─ Sub [id C]
│ │ └─ ··· │ │ └─ ···
│ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0) │ ├─ Subtensor{::step} [id BJ] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{start:} [id BL] │ │ ├─ IncSubtensor{start:} [id BK]
│ │ │ ├─ Second [id BM] │ │ │ ├─ Second [id BL]
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN] │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BM]
│ │ │ │ └─ 0.0 [id BO] │ │ │ │ └─ 0.0 [id BN]
│ │ │ ├─ IncSubtensor{i} [id BP] │ │ │ ├─ IncSubtensor{i} [id BO]
│ │ │ │ ├─ Second [id BQ] │ │ │ │ ├─ Second [id BP]
│ │ │ │ │ ├─ Subtensor{start:} [id BR] │ │ │ │ │ ├─ Subtensor{start:} [id BQ]
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ 1 [id BS] │ │ │ │ │ │ └─ 1 [id BR]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT] │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BS]
│ │ │ │ │ └─ 0.0 [id BU] │ │ │ │ │ └─ 0.0 [id BT]
│ │ │ │ ├─ Second [id BV] │ │ │ │ ├─ Second [id BU]
│ │ │ │ │ ├─ Subtensor{i} [id BW] │ │ │ │ │ ├─ Subtensor{i} [id BV]
│ │ │ │ │ │ ├─ Subtensor{start:} [id BR] │ │ │ │ │ │ ├─ Subtensor{start:} [id BQ]
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ -1 [id BX] │ │ │ │ │ │ └─ -1 [id BW]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BY] │ │ │ │ │ └─ ExpandDims{axis=0} [id BX]
│ │ │ │ │ └─ Second [id BZ] │ │ │ │ │ └─ Second [id BY]
│ │ │ │ │ ├─ Sum{axes=None} [id CA] │ │ │ │ │ ├─ Sum{axes=None} [id BZ]
│ │ │ │ │ │ └─ Subtensor{i} [id BW] │ │ │ │ │ │ └─ Subtensor{i} [id BV]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ 1.0 [id CB] │ │ │ │ │ └─ 1.0 [id CA]
│ │ │ │ └─ -1 [id BX] │ │ │ │ └─ -1 [id BW]
│ │ │ └─ 1 [id BS] │ │ │ └─ 1 [id BR]
│ │ └─ -1 [id CC] │ │ └─ -1 [id CB]
│ ├─ Alloc [id CD] (outer_in_sit_sot-0) │ ├─ Alloc [id CC] (outer_in_sit_sot-0)
│ │ ├─ 0.0 [id CE] │ │ ├─ 0.0 [id CD]
│ │ ├─ Add [id CF] │ │ ├─ Add [id CE]
│ │ │ ├─ Sub [id C] │ │ │ ├─ Sub [id C]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ 1 [id CG] │ │ │ └─ 1 [id CF]
│ │ └─ Subtensor{i} [id CH] │ │ └─ Subtensor{i} [id CG]
│ │ ├─ Shape [id CI] │ │ ├─ Shape [id CH]
│ │ │ └─ A [id P] │ │ │ └─ A [id O]
│ │ └─ 0 [id CJ] │ │ └─ 0 [id CI]
│ └─ A [id P] (outer_in_non_seqs-0) │ └─ A [id O] (outer_in_non_seqs-0)
└─ -1 [id CK] └─ -1 [id CJ]
Inner graphs: Inner graphs:
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B] Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
← Add [id CL] (inner_out_mit_mot-0-0) ← Add [id CK] (inner_out_mit_mot-0-0)
├─ Mul [id CM] ├─ Mul [id CL]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0) │ ├─ *2-<Vector(float64, shape=(?,))> [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
│ └─ *5-<Vector(float64, shape=(?,))> [id CO] -> [id P] (inner_in_non_seqs-0) │ └─ *5-<Vector(float64, shape=(?,))> [id CN] -> [id O] (inner_in_non_seqs-0)
└─ *3-<Vector(float64, shape=(?,))> [id CP] -> [id BK] (inner_in_mit_mot-0-1) └─ *3-<Vector(float64, shape=(?,))> [id CO] -> [id BJ] (inner_in_mit_mot-0-1)
← Add [id CQ] (inner_out_sit_sot-0) ← Add [id CP] (inner_out_sit_sot-0)
├─ Mul [id CR] ├─ Mul [id CQ]
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0) │ ├─ *2-<Vector(float64, shape=(?,))> [id CM] -> [id BJ] (inner_in_mit_mot-0-0)
│ └─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id Y] (inner_in_seqs-0) │ └─ *0-<Vector(float64, shape=(?,))> [id CR] -> [id X] (inner_in_seqs-0)
└─ *4-<Vector(float64, shape=(?,))> [id CT] -> [id CD] (inner_in_sit_sot-0) └─ *4-<Vector(float64, shape=(?,))> [id CS] -> [id CC] (inner_in_sit_sot-0)
Scan{scan_fn, while_loop=False, inplace=none} [id F] Scan{scan_fn, while_loop=False, inplace=none} [id F]
← Mul [id CU] (inner_out_sit_sot-0) ← Mul [id CT] (inner_out_sit_sot-0)
├─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id H] (inner_in_sit_sot-0) ├─ *0-<Vector(float64, shape=(?,))> [id CR] -> [id H] (inner_in_sit_sot-0)
└─ *1-<Vector(float64, shape=(?,))> [id CV] -> [id P] (inner_in_non_seqs-0) └─ *1-<Vector(float64, shape=(?,))> [id CU] -> [id O] (inner_in_non_seqs-0)
""" """
for truth, out in zip(expected_output.split("\n"), lines, strict=True): for truth, out in zip(expected_output.split("\n"), lines, strict=True):
......
...@@ -1621,7 +1621,7 @@ class TestSaveMem: ...@@ -1621,7 +1621,7 @@ class TestSaveMem:
np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100) np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21) np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1) np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
with pytest.raises(AssertionError, match="n_steps > 0"): with pytest.raises((AssertionError, IndexError)):
f(x0=0, seq=test_seq, n_steps=0) f(x0=0, seq=test_seq, n_steps=0)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly. # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
......
...@@ -77,9 +77,7 @@ from pytensor.tensor.shape import ( ...@@ -77,9 +77,7 @@ from pytensor.tensor.shape import (
Reshape, Reshape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast,
specify_shape, specify_shape,
unbroadcast,
) )
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -558,48 +556,6 @@ class TestTile: ...@@ -558,48 +556,6 @@ class TestTile:
f(data) f(data)
class TestUnbroadcast:
def setup_method(self):
self.mode = get_default_mode().including("canonicalize")
def test_local_useless_unbroadcast(self):
x1 = tensor(dtype="float64", shape=(1, 2))
x2 = tensor(dtype="float64", shape=(2, 1))
unbroadcast_op = Unbroadcast(0)
f = function([x1], unbroadcast_op(x1), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
f = function([x2], unbroadcast_op(x2), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 0
)
def test_local_unbroadcast_lift(self):
x = tensor(dtype="float64", shape=(1, 1))
y = unbroadcast(pt.exp(unbroadcast(x, 0)), 1)
assert (
sum(
isinstance(node.op, Unbroadcast)
for node in FunctionGraph([x], [y], copy_inputs=False).toposort()
)
== 2
)
f = function([x], y, mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
np.testing.assert_almost_equal(f([[1]]), np.exp([[1]]))
class TestUselessElemwise: class TestUselessElemwise:
def setup_method(self): def setup_method(self):
self.mode = get_default_mode().including("canonicalize", "local_fill_to_alloc") self.mode = get_default_mode().including("canonicalize", "local_fill_to_alloc")
......
...@@ -28,7 +28,6 @@ from pytensor.tensor.rewriting.subtensor import ( ...@@ -28,7 +28,6 @@ from pytensor.tensor.rewriting.subtensor import (
) )
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
SpecifyShape, SpecifyShape,
Unbroadcast,
_shape, _shape,
shape, shape,
specify_shape, specify_shape,
...@@ -55,7 +54,6 @@ from pytensor.tensor.type import ( ...@@ -55,7 +54,6 @@ from pytensor.tensor.type import (
lscalar, lscalar,
lscalars, lscalars,
matrix, matrix,
row,
scalar, scalar,
tensor, tensor,
tensor3, tensor3,
...@@ -921,64 +919,6 @@ class TestLocalSubtensorLift: ...@@ -921,64 +919,6 @@ class TestLocalSubtensorLift:
assert len(prog) == 2 assert len(prog) == 2
f([1, 2, 3], 4) # let debugmode test something f([1, 2, 3], 4) # let debugmode test something
def test_basic_8(self):
# Test that Subtensor(Unbroadcast(x)) gets optimized into
# Unbroadcast(Subtensor(x)).
# test basic case
x = row("x")
xval = np.random.random((1, 10)).astype(config.floatX)
assert x.broadcastable == (True, False)
newx = Unbroadcast(0)(x)
assert newx.broadcastable == (False, False)
f1 = function([x], newx[:2, :5], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast])
prog = f1.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f1(xval) == xval[:2, :5]).all()
# corner case 1: Unbroadcast changes dims which are dropped through subtensor
y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x")
yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
assert y.broadcastable == (True, False, True, False)
newy = Unbroadcast(0, 2)(y)
assert newy.broadcastable == (False, False, False, False)
f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast])
prog = f2.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f2(yval) == yval[:, 3, 0, :]).all()
# corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
f3 = function([y], newy[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast])
prog = f3.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f3(yval) == yval[:, 3, 0]).all()
# corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x")
zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
assert z.broadcastable == (False, False, False, True)
newz = Unbroadcast(3)(z)
assert newz.broadcastable == (False, False, False, False)
f4 = function([z], newz[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast])
prog = f4.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Unbroadcast)
assert (f4(zval) == zval[:, 3, 0]).all()
class TestLocalSubtensorMerge: class TestLocalSubtensorMerge:
def setup_method(self): def setup_method(self):
......
...@@ -287,7 +287,7 @@ TestAlloc13GradBroadcast = makeBroadcastTester( ...@@ -287,7 +287,7 @@ TestAlloc13GradBroadcast = makeBroadcastTester(
), ),
) )
# unbroadcast a row to a matrix # broadcast a row to a matrix
TestAllocb1GradBroadcast = makeBroadcastTester( TestAllocb1GradBroadcast = makeBroadcastTester(
name="Allocb1GradTester", name="Allocb1GradTester",
op=lambda x: alloc(x, s1, s2), op=lambda x: alloc(x, s1, s2),
...@@ -299,7 +299,7 @@ TestAllocb1GradBroadcast = makeBroadcastTester( ...@@ -299,7 +299,7 @@ TestAllocb1GradBroadcast = makeBroadcastTester(
), ),
) )
# unbroadcast a row to a tensor3 # broadcast a row to a tensor3
TestAllocb2GradBroadcast = makeBroadcastTester( TestAllocb2GradBroadcast = makeBroadcastTester(
name="Allocb2GradTester", name="Allocb2GradTester",
op=lambda x: alloc(x, s1, s2, s3), op=lambda x: alloc(x, s1, s2, s3),
...@@ -311,7 +311,7 @@ TestAllocb2GradBroadcast = makeBroadcastTester( ...@@ -311,7 +311,7 @@ TestAllocb2GradBroadcast = makeBroadcastTester(
), ),
) )
# unbroadcast a col to a matrix # broadcast a col to a matrix
TestAllocb3GradBroadcast = makeBroadcastTester( TestAllocb3GradBroadcast = makeBroadcastTester(
name="Allocb3GradTester", name="Allocb3GradTester",
op=lambda x: alloc(x, s1, s2), op=lambda x: alloc(x, s1, s2),
...@@ -323,7 +323,7 @@ TestAllocb3GradBroadcast = makeBroadcastTester( ...@@ -323,7 +323,7 @@ TestAllocb3GradBroadcast = makeBroadcastTester(
), ),
) )
# unbroadcast a col to a tensor3 # broadcast a col to a tensor3
TestAllocb4GradBroadcast = makeBroadcastTester( TestAllocb4GradBroadcast = makeBroadcastTester(
name="Allocb4GradTester", name="Allocb4GradTester",
op=lambda x: alloc(x, s1, s2, s3), op=lambda x: alloc(x, s1, s2, s3),
...@@ -336,7 +336,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester( ...@@ -336,7 +336,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
) )
# Partial unbroadcast of a dimshuffled input # Partial broadcast of a dimshuffled input
TestAllocDimshuffleGradBroadcast = makeBroadcastTester( TestAllocDimshuffleGradBroadcast = makeBroadcastTester(
name="Allocb4GradTester", name="Allocb4GradTester",
op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3), op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3),
......
...@@ -19,14 +19,12 @@ from pytensor.tensor.shape import ( ...@@ -19,14 +19,12 @@ from pytensor.tensor.shape import (
Shape, Shape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast,
_specify_shape, _specify_shape,
reshape, reshape,
shape, shape,
shape_tuple, shape_tuple,
specify_broadcastable, specify_broadcastable,
specify_shape, specify_shape,
unbroadcast,
) )
from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import ( from pytensor.tensor.type import (
...@@ -696,66 +694,6 @@ def test_get_vector_length(): ...@@ -696,66 +694,6 @@ def test_get_vector_length():
assert get_vector_length(x) == 10 assert get_vector_length(x) == 10
class TestUnbroadcast:
def test_basic(self):
x = matrix()
assert unbroadcast(x, 0) is x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
def test_infer_shape(self):
x = matrix()
y = unbroadcast(x, 0)
f = pytensor.function([x], y.shape)
assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, MakeVector)
x = row()
y = unbroadcast(x, 0)
f = pytensor.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
def test_error_checks(self):
with pytest.raises(TypeError, match="needs integer axes"):
Unbroadcast(0.0)
with pytest.raises(ValueError, match="^Trying to unbroadcast"):
Unbroadcast(1)(vector())
class TestUnbroadcastInferShape(utt.InferShapeTester):
def test_basic(self):
rng = np.random.default_rng(3453)
adtens4 = tensor(dtype="float64", shape=(1, 1, 1, None))
adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Unbroadcast(0, 2)(adtens4)],
[adtens4_val],
Unbroadcast,
warn=False,
)
def test_shape_tuple(): def test_shape_tuple():
x = Variable(MyType2(), None, None) x = Variable(MyType2(), None, None)
assert shape_tuple(x) == () assert shape_tuple(x) == ()
...@@ -882,16 +820,3 @@ class TestVectorize: ...@@ -882,16 +820,3 @@ class TestVectorize:
match="Invalid number of shape arguments passed into vectorize node of SpecifyShape", match="Invalid number of shape arguments passed into vectorize node of SpecifyShape",
): ):
vectorize_node(node, tns, *(5, 3, 2, x)) vectorize_node(node, tns, *(5, 3, 2, x))
def test_unbroadcast(self):
mat = tensor(
shape=(
1,
1,
)
)
tns = tensor(shape=(4, 1, 1, 1))
node = unbroadcast(mat, 0).owner
vect_node = vectorize_node(node, tns)
assert equal_computations(vect_node.outputs, [unbroadcast(tns, 2)])
...@@ -28,7 +28,6 @@ from pytensor.graph.basic import Apply ...@@ -28,7 +28,6 @@ from pytensor.graph.basic import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.tensor.math import argmax, dot from pytensor.tensor.math import argmax, dot
from pytensor.tensor.math import max as pt_max from pytensor.tensor.math import max as pt_max
from pytensor.tensor.shape import unbroadcast
from pytensor.tensor.type import matrix, vector from pytensor.tensor.type import matrix, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -252,13 +251,6 @@ class TestRopLop(RopLopChecker): ...@@ -252,13 +251,6 @@ class TestRopLop(RopLopChecker):
# vector # vector
self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,)) self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,))
def test_unbroadcast(self):
# I need the sum, because the setup expects the output to be a
# vector
self.check_rop_lop(
unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,)
)
def test_join(self): def test_join(self):
tv = np.asarray(self.rng.uniform(size=(10,)), pytensor.config.floatX) tv = np.asarray(self.rng.uniform(size=(10,)), pytensor.config.floatX)
t = pytensor.shared(tv) t = pytensor.shared(tv)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论