Commit 7f8af9bc authored by Ricardo, committed by Brandon T. Willard

Deprecate remaining uses of Rebroadcast in favor of Unbroadcast

Parent commit: ac52d689
...@@ -147,7 +147,7 @@ from aesara.updates import OrderedUpdates ...@@ -147,7 +147,7 @@ from aesara.updates import OrderedUpdates
def get_scalar_constant_value(v): def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`. """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
this function digs through them. this function digs through them.
If ``aesara.sparse`` is also there, we will look over CSM `Op`. If ``aesara.sparse`` is also there, we will look over CSM `Op`.
......
...@@ -204,8 +204,8 @@ def rebuild_collect_shared( ...@@ -204,8 +204,8 @@ def rebuild_collect_shared(
err_sug = ( err_sug = (
"If the difference is related to the broadcast pattern," "If the difference is related to the broadcast pattern,"
" you can call the" " you can call the"
" tensor.unbroadcast(var, axis_to_unbroadcast[, ...])" " tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to remove broadcastable dimensions." " function to mask broadcastable dimensions."
) )
raise TypeError(err_msg, err_sug) raise TypeError(err_msg, err_sug)
......
...@@ -23,8 +23,7 @@ from aesara.configdefaults import config ...@@ -23,8 +23,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.tensor import basic from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
from aesara.tensor.shape import Reshape, Shape, SpecifyShape
__docformat__ = "restructedtext en" __docformat__ = "restructedtext en"
...@@ -451,7 +450,7 @@ acceptable_ops = ( ...@@ -451,7 +450,7 @@ acceptable_ops = (
Shape, Shape,
SpecifyShape, SpecifyShape,
Reshape, Reshape,
basic.Rebroadcast, Unbroadcast,
at.math.Dot, at.math.Dot,
at.math.MaxAndArgmax, at.math.MaxAndArgmax,
at.subtensor.Subtensor, at.subtensor.Subtensor,
......
...@@ -29,7 +29,6 @@ from aesara.tensor.basic import ( ...@@ -29,7 +29,6 @@ from aesara.tensor.basic import (
Eye, Eye,
Join, Join,
MakeVector, MakeVector,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
) )
...@@ -50,7 +49,7 @@ from aesara.tensor.math import Dot, MaxAndArgmax ...@@ -50,7 +49,7 @@ from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -347,20 +346,12 @@ def jax_funcify_SpecifyShape(op, **kwargs): ...@@ -347,20 +346,12 @@ def jax_funcify_SpecifyShape(op, **kwargs):
return specifyshape return specifyshape
@jax_funcify.register(Rebroadcast) @jax_funcify.register(Unbroadcast)
def jax_funcify_Rebroadcast(op, **kwargs): def jax_funcify_Unbroadcast(op, **kwargs):
op_axis = op.axis def unbroadcast(x):
def rebroadcast(x):
for axis, value in op_axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
"Dimension %s in Rebroadcast's input was"
" supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
)
return x return x
return rebroadcast return unbroadcast
@jax_funcify.register(ViewOp) @jax_funcify.register(ViewOp)
......
...@@ -14,10 +14,10 @@ from aesara.tensor.basic import ( ...@@ -14,10 +14,10 @@ from aesara.tensor.basic import (
Eye, Eye,
Join, Join,
MakeVector, MakeVector,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
) )
from aesara.tensor.shape import Unbroadcast
@numba_funcify.register(AllocEmpty) @numba_funcify.register(AllocEmpty)
...@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}): ...@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn) return numba_basic.numba_njit(makevector_fn)
@numba_funcify.register(Rebroadcast) @numba_funcify.register(Unbroadcast)
def numba_funcify_Rebroadcast(op, **kwargs): def numba_funcify_Unbroadcast(op, **kwargs):
# Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())
@numba_basic.numba_njit @numba_basic.numba_njit
def rebroadcast(x): def unbroadcast(x):
for axis, value in op_axis:
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
)
return x return x
return rebroadcast return unbroadcast
@numba_funcify.register(TensorFromScalar) @numba_funcify.register(TensorFromScalar)
......
...@@ -14,7 +14,7 @@ from aesara.scan.utils import expand_empty, safe_new, until ...@@ -14,7 +14,7 @@ from aesara.scan.utils import expand_empty, safe_new, until
from aesara.tensor.basic import get_scalar_constant_value from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
from aesara.tensor.shape import shape_padleft from aesara.tensor.shape import shape_padleft, unbroadcast
from aesara.tensor.type import TensorType, integer_dtypes from aesara.tensor.type import TensorType, integer_dtypes
from aesara.updates import OrderedUpdates from aesara.updates import OrderedUpdates
...@@ -751,7 +751,7 @@ def scan( ...@@ -751,7 +751,7 @@ def scan(
# defined in scan utils # defined in scan utils
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
expand_empty( expand_empty(
at.unbroadcast(shape_padleft(actual_arg), 0), unbroadcast(shape_padleft(actual_arg), 0),
actual_n_steps, actual_n_steps,
) )
) )
...@@ -881,7 +881,7 @@ def scan( ...@@ -881,7 +881,7 @@ def scan(
# this will represent only a slice and it will have one # this will represent only a slice and it will have one
# dimension less. # dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1: if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0) outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)
if not return_list and len(outputs) == 1: if not return_list and len(outputs) == 1:
outputs = outputs[0] outputs = outputs[0]
...@@ -1010,7 +1010,7 @@ def scan( ...@@ -1010,7 +1010,7 @@ def scan(
sit_sot_inner_inputs.append(new_var) sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append( sit_sot_scan_inputs.append(
expand_empty( expand_empty(
at.unbroadcast(shape_padleft(input.variable), 0), unbroadcast(shape_padleft(input.variable), 0),
actual_n_steps, actual_n_steps,
) )
) )
......
...@@ -10,7 +10,7 @@ import warnings ...@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Dict, Optional, Tuple, Union from typing import Optional, Tuple, Union
from typing import cast as type_cast from typing import cast as type_cast
import numpy as np import numpy as np
...@@ -44,6 +44,7 @@ from aesara.tensor.exceptions import NotScalarConstantError ...@@ -44,6 +44,7 @@ from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.shape import ( from aesara.tensor.shape import (
Shape, Shape,
Shape_i, Shape_i,
Unbroadcast,
shape, shape,
shape_padaxis, shape_padaxis,
shape_padleft, shape_padleft,
...@@ -254,7 +255,7 @@ def get_scalar_constant_value( ...@@ -254,7 +255,7 @@ def get_scalar_constant_value(
): ):
"""Return the constant scalar(0-D) value underlying variable `v`. """Return the constant scalar(0-D) value underlying variable `v`.
If `v` is the output of dimshuffles, fills, allocs, rebroadcasts, If `v` is the output of dimshuffles, fills, allocs, etc,
cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
and some pattern with Subtensor, this function digs through them. and some pattern with Subtensor, this function digs through them.
...@@ -323,7 +324,7 @@ def get_scalar_constant_value( ...@@ -323,7 +324,7 @@ def get_scalar_constant_value(
( (
Alloc, Alloc,
DimShuffle, DimShuffle,
Rebroadcast, Unbroadcast,
# outputguard is only used in debugmode but we # outputguard is only used in debugmode but we
# keep it here to avoid problems with old pickels. # keep it here to avoid problems with old pickels.
compile.ops.OutputGuard, compile.ops.OutputGuard,
...@@ -495,7 +496,7 @@ def get_scalar_constant_value( ...@@ -495,7 +496,7 @@ def get_scalar_constant_value(
gp_broadcastable = grandparent.type.broadcastable gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim ndim = grandparent.type.ndim
if grandparent.owner and isinstance( if grandparent.owner and isinstance(
grandparent.owner.op, Rebroadcast grandparent.owner.op, Unbroadcast
): ):
ggp_broadcastable = grandparent.owner.inputs[0].broadcastable ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
l = [ l = [
...@@ -616,185 +617,6 @@ class ScalarFromTensor(COp): ...@@ -616,185 +617,6 @@ class ScalarFromTensor(COp):
scalar_from_tensor = ScalarFromTensor() scalar_from_tensor = ScalarFromTensor()
class Rebroadcast(COp):
    """
    Change the input's broadcastable fields in some predetermined way.

    The op is a view on its input: `perform` only validates the runtime
    shape (each axis marked ``True`` must have length 1) and passes the
    array through unchanged.

    See Also
    --------
    unbroadcast <aesara.tensor.unbroadcast>

    Notes
    -----
    Works inplace and works for CudaNdarrayType.

    Examples
    --------
    ``Rebroadcast((0, True), (1, False))(x)`` would make `x` broadcastable in
    axis 0 and not broadcastable in axis 1.

    """

    # Output 0 is a view of input 0 (no copy is made).
    view_map = {0: [0]}
    _f16_ok = True
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    c_code_and_version: Dict = {}
    check_input = False
    __props__ = ("axis",)
    # NOTE(review): `_f16_ok` is assigned twice in this class body; this
    # second assignment is redundant.
    _f16_ok = True

    def __init__(self, *axis):
        # `axis` is a sequence of (axis_index, new_broadcastable_flag) pairs.
        # Sort them to make sure we merge all possible case.
        items = sorted(axis)
        self.axis = dict(items)
        for axis, broad in self.axis.items():
            if not isinstance(axis, (np.integer, int)):
                raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")

            if not isinstance(broad, (np.bool_, bool)):
                raise TypeError(
                    f"Rebroadcast needs bool for new broadcast pattern. Got {broad}"
                )

    def __hash__(self):
        # Need special __hash__ as dict aren't hashable.
        # no ambiguity because each item key is unique
        items = sorted(self.axis.items())
        return hash((type(self), tuple(items)))

    def __str__(self):
        # e.g. "Rebroadcast{(0, True),(1, False)}"
        return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}"

    def make_node(self, x):
        """Build an Apply whose output type carries the new broadcast flags."""
        if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
            raise ValueError("Trying to rebroadcast non-existent dimension")
        # Override the input's broadcastable pattern on the requested axes.
        t = x.type.clone(
            shape=[self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable)]
        )
        return Apply(self, [x], [t()])

    def perform(self, node, inp, out_):
        """Validate that every axis marked broadcastable has length 1, then pass through."""
        (x,) = inp
        (out,) = out_
        for axis, value in self.axis.items():
            if value and x.shape[axis] != 1:
                raise ValueError(
                    f"Dimension {axis} in Rebroadcast's input was"
                    f" supposed to be 1 (got {x.shape[axis]} instead)"
                )
        out[0] = x

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        # restore the broadcasting pattern of the input
        return (
            Rebroadcast(
                *[
                    (axis, x.type.broadcastable[axis])
                    for axis, value in self.axis.items()
                ]
            )(gz),
        )

    def infer_shape(self, fgraph, node, ishapes):
        """Axes forced broadcastable have symbolic shape 1; others are unchanged."""
        assert len(ishapes) == 1
        l = []
        one = aesara.tensor.basic.constant(1)
        for ax in range(len(ishapes[0])):
            if self.axis.get(ax, False):
                l.append(one)
            else:
                l.append(ishapes[0][ax])
        return [tuple(l)]

    def R_op(self, inputs, eval_points):
        # The op is linear (identity on values), so the R-op is the op itself.
        if eval_points[0] is None:
            return [None]
        return self(*eval_points, return_list=True)

    def c_code(self, node, nodename, inp, out, sub):
        """Emit the per-type shape check (registered via `register_rebroadcast_c_code`)
        for each axis being made broadcastable, then alias output to input."""
        (iname,) = inp
        (oname,) = out
        fail = sub["fail"]

        itype = node.inputs[0].type.__class__
        if itype in self.c_code_and_version:
            code, version = self.c_code_and_version[itype]
            final_code = ""
            for axis, value in self.axis.items():
                if value:
                    # `code` is a %-template using iname/oname/axis/fail.
                    final_code += code % locals()
            return (
                final_code
                + f"""
            Py_XDECREF({oname});
            {oname} = {iname};
            Py_XINCREF({oname});
            """
            )
        raise NotImplementedError()

    def c_code_cache_version(self):
        version = []
        # If any of the c code is unversioned, we have to return ()
        # Else, we will return a list of (type name, version) pairs.
        for t, (c, v) in sorted(
            self.c_code_and_version.items(), key=lambda pair: str(pair[0])
        ):
            if not v:
                warnings.warn(
                    f"Type {t} has C code for Rebroadcast, but it "
                    "has no version. You should add a 'version' "
                    "keyword arg when calling "
                    "register_rebroadcast_c_code.",
                    stacklevel=2,
                )
                return ()
            version.append((str(t), v))
        if version:
            # Bump this trailing number to invalidate all cached C code.
            version.append(1)
        return tuple(version)
def register_rebroadcast_c_code(typ, code, version=()):
    """
    Tell Rebroadcast how to generate C code for an Aesara Type.

    Parameters
    ----------
    typ : Aesara type
        It must be the Aesara class itself and not an instance of the class.
    code : C code
        That checks if the dimension %(axis)s is of shape 1 for the Aesara type
        'typ'. Use %(iname)s and %(oname)s for the input and output C variable
        names respectively, and %(axis)s for the axis that we need to check.
        This code is put in a loop for all axes.
    version
        A number indicating the version of the code, for cache.

    """
    # Registered entries are consulted by `Rebroadcast.c_code` and
    # `Rebroadcast.c_code_cache_version`.
    Rebroadcast.c_code_and_version[typ] = (code, version)
register_rebroadcast_c_code(
TensorType,
"""
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""",
version=1,
)
# to be removed as we get the epydoc routine-documenting thing going # to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924 # -JB 20080924
def _conversion(real_value: Op, name: str) -> Op: def _conversion(real_value: Op, name: str) -> Op:
...@@ -2254,36 +2076,6 @@ class Split(COp): ...@@ -2254,36 +2076,6 @@ class Split(COp):
) )
def unbroadcast(x, *axes):
    """
    Make the input impossible to broadcast in the specified axes.

    For example, ``unbroadcast(x, 0)`` makes the first dimension of `x`
    non-broadcastable. At execution time, a ValueError is raised if the
    length of `x` along any of those dimensions is not 1.

    Parameters
    ----------
    x : tensor_like
        Input Aesara tensor.
    axes : an int or an iterable object such as list or tuple of int values
        The dimensions along which `x` should be made unbroadcastable.

    Returns
    -------
    tensor
        An Aesara tensor that is unbroadcastable along the specified dimensions.

    """
    tensor_x = as_tensor_variable(x)
    # Each requested axis is switched to a non-broadcastable (False) flag.
    flag_pairs = [(ax, False) for ax in axes]
    rebroadcasted = Rebroadcast(*flag_pairs)(tensor_x)
    # Simplify the Rebroadcast right away so the graph is not polluted.
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rebroadcasted)
class Join(COp): class Join(COp):
r""" r"""
Concatenate multiple `TensorVariable`\s along some axis. Concatenate multiple `TensorVariable`\s along some axis.
...@@ -4195,7 +3987,6 @@ __all__ = [ ...@@ -4195,7 +3987,6 @@ __all__ = [
"stack", "stack",
"roll", "roll",
"join", "join",
"unbroadcast",
"split", "split",
"transpose", "transpose",
"extract_constant", "extract_constant",
......
...@@ -48,7 +48,6 @@ from aesara.tensor.basic import ( ...@@ -48,7 +48,6 @@ from aesara.tensor.basic import (
AllocEmpty, AllocEmpty,
Join, Join,
MakeVector, MakeVector,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
...@@ -77,9 +76,11 @@ from aesara.tensor.shape import ( ...@@ -77,9 +76,11 @@ from aesara.tensor.shape import (
Shape, Shape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast,
shape_i, shape_i,
shape_padleft, shape_padleft,
specify_shape, specify_shape,
unbroadcast,
) )
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
...@@ -2226,10 +2227,13 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -2226,10 +2227,13 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Rebroadcast]) @local_optimizer([Unbroadcast])
def local_useless_rebroadcast(fgraph, node): def local_useless_unbroadcast(fgraph, node):
"""Remove `Rebroadcast` if it does not actually change the broadcasting pattern.""" """Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
if isinstance(node.op, Rebroadcast):
TODO: Implement equivalent rewrite for SpecifyShape
"""
if isinstance(node.op, Unbroadcast):
x = node.inputs[0] x = node.inputs[0]
if x.broadcastable == node.outputs[0].broadcastable: if x.broadcastable == node.outputs[0].broadcastable:
# No broadcastable flag was modified # No broadcastable flag was modified
...@@ -2238,15 +2242,12 @@ def local_useless_rebroadcast(fgraph, node): ...@@ -2238,15 +2242,12 @@ def local_useless_rebroadcast(fgraph, node):
return [x] return [x]
else: else:
# Keep the flags that modify something # Keep the flags that modify something
new_axis = {} new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
for dim, bc in node.op.axis.items(): if new_axes == node.op.axes:
if x.broadcastable[dim] != bc:
new_axis[dim] = bc
if new_axis == node.op.axis:
# All flags are useful # All flags are useful
return return None
else: else:
r = Rebroadcast(*new_axis.items())(x) r = unbroadcast(x, *new_axes)
# Copy over stacktrace from previous output # Copy over stacktrace from previous output
copy_stack_trace(node.outputs, r) copy_stack_trace(node.outputs, r)
return [r] return [r]
...@@ -2254,93 +2255,49 @@ def local_useless_rebroadcast(fgraph, node): ...@@ -2254,93 +2255,49 @@ def local_useless_rebroadcast(fgraph, node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@local_optimizer([Rebroadcast]) @local_optimizer([Unbroadcast])
def local_rebroadcast_lift(fgraph, node): def local_unbroadcast_lift(fgraph, node):
""" """
Lifts Rebroadcast through unary Elemwise operations, Lifts `Unbroadcast` through unary Elemwise operations,
and merges consecutive Rebroadcasts. and merges consecutive `Unbroadcast`s.
Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x)) Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
Rebroadcast(Rebroadcast(x)) => Rebroadcast(x) Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
TODO: Implement equivalent Elemwise lift for SpecifyShape
""" """
op = node.op op = node.op
if not isinstance(op, Rebroadcast): if not isinstance(op, Unbroadcast):
return False return False
inp = node.inputs[0] inp = node.inputs[0]
inode = inp.owner inode = inp.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
# It may happen that `input` has no client because this optimization
# is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function
# compilation phase.
if len(fgraph.clients.get(inp, ())) == 1: if len(fgraph.clients.get(inp, ())) == 1:
rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0]) unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
# Copy over stacktrace from previous output (after rebroadcasting) copy_stack_trace(node.outputs, unbroadcasted)
# to new output, because an error in the new graph right after
# rebroadcasting must have been caused by the previous rebroadcasting.
copy_stack_trace(node.outputs, rebroadcasted)
rval = inode.op.make_node(rebroadcasted).outputs rval = inode.op.make_node(unbroadcasted).outputs
# Copy over stacktrace from previous output (after rebroadcasting) # Copy over stacktrace from previous output (after unbroadcasting)
# and input (after elemwise operation) to new output, because an # and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the # error in the new graph could have been caused by either of the
# two ops. # two ops.
copy_stack_trace(node.outputs + node.inputs, rval) copy_stack_trace(node.outputs + node.inputs, rval)
return rval return rval
if inode and isinstance(inode.op, Rebroadcast):
# the "axis" specification in the outer Rebroadcast overrides
# the axis of the inner one
axis = inode.op.axis.copy()
axis.update(op.axis)
iinput = inode.inputs[0]
rval = [Rebroadcast(*list(axis.items()))(iinput)]
# Copy over stacktrace from previous output (after second rebroadcast) if inode and isinstance(inode.op, Unbroadcast):
# and from previous input (after first rebroadcast op) because an error in # Merge axis of each unbroadcast
# the new graph could have been caused by either of the two axis = tuple(set(inode.op.axes).union(set(op.axes)))
# rebroadcast ops. iinput = inode.inputs[0]
rval = [unbroadcast(iinput, *axis)]
# Copy over stacktrace from previous output (after second unbroadcasting)
# and from previous input (after first unbroadcasting) because an error in
# the new graph could have been caused by either of the two Unbroadcast ops.
copy_stack_trace(node.outputs + node.inputs, rval) copy_stack_trace(node.outputs + node.inputs, rval)
return rval return rval
def apply_rebroadcast_opt(rval):
"""
Apply as many times as required the optimization local_useless_rebroadcast
and local_rebroadcast_lift.
Parameters
----------
rval: a Variable
Returns
-------
A Variable (the same if no optimization can be applied)
"""
fg = FunctionGraph([], [])
changed = True
while changed and rval.owner:
changed = False
rval2 = local_useless_rebroadcast.transform(fg, rval.owner)
if rval2:
assert len(rval2) == 1
rval = rval2[0]
changed = True
if rval.owner:
rval2 = local_rebroadcast_lift.transform(fg, rval.owner)
if rval2:
assert len(rval2) == 1
rval = rval2[0]
changed = True
return rval
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@register_useless @register_useless
......
...@@ -926,3 +926,108 @@ def specify_broadcastable(x, *axes): ...@@ -926,3 +926,108 @@ def specify_broadcastable(x, *axes):
shape_info = [1 if i in axes else None for i in range(len(x.type.shape))] shape_info = [1 if i in axes else None for i in range(len(x.type.shape))]
return specify_shape(x, shape_info) return specify_shape(x, shape_info)
class Unbroadcast(COp):
    """
    Mask static broadcastable dimensions of input as `None`.

    The op is a pure view on its input: `perform` and the C code simply
    alias the output to the input; only the static type information
    changes.

    See Also
    --------
    unbroadcast <aesara.tensor.shape.unbroadcast>

    Examples
    --------
    ``Unbroadcast(1)(x)`` would make `x`'s second static dimension be `None`.

    """

    # Output 0 is a view of input 0 (no copy is made).
    view_map = {0: [0]}
    _f16_ok = True
    # Mapping from Type to C code (and version) to use.
    # In the C code, the name of the input variable is %(iname)s,
    # the output variable is %(oname)s.
    # NOTE(review): this mapping appears unused by `c_code` below;
    # presumably kept for interface parity — confirm before removing.
    c_code_and_version: Dict = {}
    check_input = False
    __props__ = ("axes",)

    def __init__(self, *axis):
        """`axis` are the integer axes whose static shape should be masked."""
        # Sort them to make sure we merge all possible case.
        items = tuple(sorted(axis))
        self.axes = items
        for axis in self.axes:
            if not isinstance(axis, (np.integer, int)):
                raise TypeError(f"Unbroadcast needs integer axes. Got {axis}")

    def __str__(self):
        # e.g. "Unbroadcast{0,2}"
        return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}"

    def make_node(self, x):
        """Build an Apply whose output type has `None` on the masked axes."""
        x = as_tensor_variable(x)
        # Guard `self.axes` first: `max(())` would raise a confusing ValueError
        # when the op was constructed with no axes.
        if self.axes and x.type.ndim <= max(self.axes):
            raise ValueError("Trying to unbroadcast of non-existent dimension")
        # Only statically-length-1 dimensions can be masked; others keep
        # their known static shape.
        shape = [
            None if (sh == 1 and i in self.axes) else sh
            for i, sh in enumerate(x.type.shape)
        ]
        return Apply(self, [x], [x.type.clone(shape=shape)()])

    def perform(self, node, inp, out_):
        # Identity at runtime: only the static type changed.
        (x,) = inp
        (out,) = out_
        out[0] = x

    def grad(self, inp, grads):
        (x,) = inp
        (gz,) = grads
        # restore the broadcasting pattern of the input
        return [specify_shape(gz, x.type.shape)]

    def infer_shape(self, fgraph, node, ishapes):
        # The symbolic shape is unchanged; only static type info differs.
        assert len(ishapes) == 1
        return [tuple(ishapes[0])]

    def R_op(self, inputs, eval_points):
        # The op is the identity on values, so the R-op is the op itself.
        if eval_points[0] is None:
            return [None]
        return self(*eval_points, return_list=True)

    def c_code(self, node, nodename, inp, out, sub):
        (iname,) = inp
        (oname,) = out
        # Alias the output to the input, fixing up reference counts.
        return f"""
        Py_XDECREF({oname});
        {oname} = {iname};
        Py_XINCREF({oname});
        """

    def c_code_cache_version(self):
        return (3,)
def unbroadcast(x, *axes):
    """
    Mask static broadcastable dimensions of input as `None`.

    Parameters
    ----------
    x : tensor_like
        Input Aesara tensor.
    axes : an int or an iterable object such as list or tuple of int values
        The broadcastable dimensions of `x` that should be unbroadcasted.

    Returns
    -------
    tensor
        An Aesara tensor, with static broadcastable dimensions masked as `None`.

    """
    x = as_tensor_variable(x)
    # Only dimensions that are statically known to have length 1 can be
    # masked; any other requested axis is silently skipped.
    axes_to_mask = []
    for ax in axes:
        if x.type.shape[ax] == 1:
            axes_to_mask.append(ax)
    if axes_to_mask:
        return Unbroadcast(*axes_to_mask)(x)
    # Nothing to mask: return the input untouched.
    return x
...@@ -14,7 +14,6 @@ from aesara.tensor.basic import ( ...@@ -14,7 +14,6 @@ from aesara.tensor.basic import (
ARange, ARange,
Join, Join,
MakeVector, MakeVector,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
TensorFromScalar, TensorFromScalar,
alloc, alloc,
...@@ -50,9 +49,11 @@ from aesara.tensor.math import ( ...@@ -50,9 +49,11 @@ from aesara.tensor.math import (
from aesara.tensor.shape import ( from aesara.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
Unbroadcast,
shape_padleft, shape_padleft,
shape_tuple, shape_tuple,
specify_shape, specify_shape,
unbroadcast,
) )
from aesara.tensor.sharedvar import TensorSharedVariable from aesara.tensor.sharedvar import TensorSharedVariable
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
...@@ -370,7 +371,7 @@ def local_subtensor_lift(fgraph, node): ...@@ -370,7 +371,7 @@ def local_subtensor_lift(fgraph, node):
Handles the following unary ops: Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...) elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx]) Unbroadcast(x)[idx] => Unbroadcast(x[idx])
""" """
if isinstance(node.op, Subtensor): if isinstance(node.op, Subtensor):
...@@ -429,34 +430,34 @@ def local_subtensor_lift(fgraph, node): ...@@ -429,34 +430,34 @@ def local_subtensor_lift(fgraph, node):
copy_stack_trace([node.outputs[0], node.inputs[0]], ret) copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret] return [ret]
if isinstance(u.owner.op, Rebroadcast): if isinstance(u.owner.op, Unbroadcast):
# make sure that Rebroadcast has only 1 input
assert len(u.owner.inputs) == 1
# Subtensor might reduce dim., adapt broadcast pattern accordingly # Subtensor might reduce dim., adapt broadcast pattern accordingly
new_axis = [] old_axes = u.owner.op.axes
new_axes = []
# loop through indices being subtensor-ed # loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor # i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor # j indexes broadcastable pattern after subtensor
j = 0 j = 0
for (i, x) in enumerate(node.op.idx_list): for (i, x) in enumerate(node.op.idx_list):
# if its not a slice, it will reduce the dimension, should # if it is not a slice, it will reduce the dimension, should
# not appear in the broascastable dimensions # not appear in the broascastable dimensions
if isinstance(x, slice): if isinstance(x, slice):
new_axis += [(j, u.broadcastable[i])] if i in old_axes:
new_axes.append(j)
j += 1 j += 1
# now keep the broadcastable pattern of all # now keep the broadcastable pattern of all
# items not appearing in subtensor list # items not appearing in subtensor list
for i in range(len(node.op.idx_list), len(u.broadcastable)): for i in range(len(node.op.idx_list), len(u.broadcastable)):
new_axis += [(j, u.broadcastable[i])] if i in old_axes:
new_axes.append(j)
j += 1 j += 1
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:]) subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
# Copy over previous output stacktrace # Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], subt_x) copy_stack_trace(node.outputs[0], subt_x)
rbcast_subt_x = Rebroadcast(*new_axis)(subt_x) rbcast_subt_x = unbroadcast(subt_x, *new_axes)
# Copy over previous output stacktrace # Copy over previous output stacktrace
# and stacktrace from previous unary operation # and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x) copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
......
...@@ -39,7 +39,7 @@ from aesara.tensor.math import sum as at_sum ...@@ -39,7 +39,7 @@ from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet.basic import SoftmaxGrad from aesara.tensor.nnet.basic import SoftmaxGrad
from aesara.tensor.random.basic import RandomVariable, normal from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape
from aesara.tensor.type import ( from aesara.tensor.type import (
dscalar, dscalar,
dvector, dvector,
...@@ -201,20 +201,11 @@ def test_jax_compile_ops(): ...@@ -201,20 +201,11 @@ def test_jax_compile_ops():
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1)) x_np = np.zeros((20, 1, 1))
x = at.Rebroadcast((0, False), (1, True), (2, False))(at.as_tensor_variable(x_np)) x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, []) compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"):
x = at.Rebroadcast((0, True), (1, False), (2, False))(
at.as_tensor_variable(x_np)
)
x_fg = FunctionGraph([], [x])
with pytest.raises(ValueError):
compare_jax_and_py(x_fg, [])
x = ViewOp()(at.as_tensor_variable(x_np)) x = ViewOp()(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
......
...@@ -40,7 +40,7 @@ from aesara.tensor import extra_ops, nlinalg, slinalg ...@@ -40,7 +40,7 @@ from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as at_subtensor from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
class MyType(Type): class MyType(Type):
...@@ -769,31 +769,10 @@ def test_ScalarFromTensor(v): ...@@ -769,31 +769,10 @@ def test_ScalarFromTensor(v):
) )
@pytest.mark.parametrize( def test_Unbroadcast():
"v, axis, fails", v = set_test_value(at.row(), np.array([[1.0, 2.0]], dtype=config.floatX))
[ g = Unbroadcast(0)(v)
(
set_test_value(at.matrix(), np.array([[1.0]], dtype=config.floatX)),
[(0, True), (1, True)],
False,
),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
[(0, True), (1, False)],
False,
),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
[(0, True), (1, True)],
True,
),
],
)
def test_Rebroadcast(v, axis, fails):
g = atb.Rebroadcast(*axis)(v)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(ValueError)
with cm:
compare_numba_and_py( compare_numba_and_py(
g_fg, g_fg,
[ [
......
...@@ -36,7 +36,7 @@ def test_debugprint_sitsot(): ...@@ -36,7 +36,7 @@ def test_debugprint_sitsot():
| | | | | |k [id D] | | | | | |k [id D]
| | | | | |Subtensor{int64} [id H] | | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I] | | | | | |Shape [id I]
| | | | | | |Rebroadcast{(0, False)} [id J] | | | | | | |Unbroadcast{0} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K] | | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L] | | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M] | | | | | | |A [id M]
...@@ -45,9 +45,9 @@ def test_debugprint_sitsot(): ...@@ -45,9 +45,9 @@ def test_debugprint_sitsot():
| | | | | |ScalarConstant{0} [id P] | | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q] | | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R] | | | | |Shape [id R]
| | | | | |Rebroadcast{(0, False)} [id J] | | | | | |Unbroadcast{0} [id J]
| | | | |ScalarConstant{1} [id S] | | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{(0, False)} [id J] | | | |Unbroadcast{0} [id J]
| | | |ScalarFromTensor [id T] | | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H] | | | |Subtensor{int64} [id H]
| | |A [id M] (outer_in_non_seqs-0) | | |A [id M] (outer_in_non_seqs-0)
...@@ -91,7 +91,7 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -91,7 +91,7 @@ def test_debugprint_sitsot_no_extra_info():
| | | | | |k [id D] | | | | | |k [id D]
| | | | | |Subtensor{int64} [id H] | | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I] | | | | | |Shape [id I]
| | | | | | |Rebroadcast{(0, False)} [id J] | | | | | | |Unbroadcast{0} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K] | | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L] | | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M] | | | | | | |A [id M]
...@@ -100,9 +100,9 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -100,9 +100,9 @@ def test_debugprint_sitsot_no_extra_info():
| | | | | |ScalarConstant{0} [id P] | | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q] | | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R] | | | | |Shape [id R]
| | | | | |Rebroadcast{(0, False)} [id J] | | | | | |Unbroadcast{0} [id J]
| | | | |ScalarConstant{1} [id S] | | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{(0, False)} [id J] | | | |Unbroadcast{0} [id J]
| | | |ScalarFromTensor [id T] | | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H] | | | |Subtensor{int64} [id H]
| | |A [id M] | | |A [id M]
...@@ -261,7 +261,7 @@ def test_debugprint_nested_scans(): ...@@ -261,7 +261,7 @@ def test_debugprint_nested_scans():
> | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) > | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ] > | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK] > | | | | | | |Shape [id BK]
> | | | | | | | |Rebroadcast{(0, False)} [id BL] > | | | | | | | |Unbroadcast{0} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM] > | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN] > | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) > | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
...@@ -270,9 +270,9 @@ def test_debugprint_nested_scans(): ...@@ -270,9 +270,9 @@ def test_debugprint_nested_scans():
> | | | | | | |ScalarConstant{0} [id BR] > | | | | | | |ScalarConstant{0} [id BR]
> | | | | | |Subtensor{int64} [id BS] > | | | | | |Subtensor{int64} [id BS]
> | | | | | |Shape [id BT] > | | | | | |Shape [id BT]
> | | | | | | |Rebroadcast{(0, False)} [id BL] > | | | | | | |Unbroadcast{0} [id BL]
> | | | | | |ScalarConstant{1} [id BU] > | | | | | |ScalarConstant{1} [id BU]
> | | | | |Rebroadcast{(0, False)} [id BL] > | | | | |Unbroadcast{0} [id BL]
> | | | | |ScalarFromTensor [id BV] > | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ] > | | | | |Subtensor{int64} [id BJ]
> | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) > | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
...@@ -350,7 +350,7 @@ def test_debugprint_nested_scans(): ...@@ -350,7 +350,7 @@ def test_debugprint_nested_scans():
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) > | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BL] > | | | | | | |Subtensor{int64} [id BL]
> | | | | | | |Shape [id BM] > | | | | | | |Shape [id BM]
> | | | | | | | |Rebroadcast{(0, False)} [id BN] > | | | | | | | |Unbroadcast{0} [id BN]
> | | | | | | | |InplaceDimShuffle{x,0} [id BO] > | | | | | | | |InplaceDimShuffle{x,0} [id BO]
> | | | | | | | |Elemwise{second,no_inplace} [id BP] > | | | | | | | |Elemwise{second,no_inplace} [id BP]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) > | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
...@@ -359,9 +359,9 @@ def test_debugprint_nested_scans(): ...@@ -359,9 +359,9 @@ def test_debugprint_nested_scans():
> | | | | | | |ScalarConstant{0} [id BS] > | | | | | | |ScalarConstant{0} [id BS]
> | | | | | |Subtensor{int64} [id BT] > | | | | | |Subtensor{int64} [id BT]
> | | | | | |Shape [id BU] > | | | | | |Shape [id BU]
> | | | | | | |Rebroadcast{(0, False)} [id BN] > | | | | | | |Unbroadcast{0} [id BN]
> | | | | | |ScalarConstant{1} [id BV] > | | | | | |ScalarConstant{1} [id BV]
> | | | | |Rebroadcast{(0, False)} [id BN] > | | | | |Unbroadcast{0} [id BN]
> | | | | |ScalarFromTensor [id BW] > | | | | |ScalarFromTensor [id BW]
> | | | | |Subtensor{int64} [id BL] > | | | | |Subtensor{int64} [id BL]
> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) > | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
...@@ -487,7 +487,7 @@ def test_debugprint_mitmot(): ...@@ -487,7 +487,7 @@ def test_debugprint_mitmot():
| | | | | | | |k [id G] | | | | | | | |k [id G]
| | | | | | | |Subtensor{int64} [id K] | | | | | | | |Subtensor{int64} [id K]
| | | | | | | |Shape [id L] | | | | | | | |Shape [id L]
| | | | | | | | |Rebroadcast{(0, False)} [id M] | | | | | | | | |Unbroadcast{0} [id M]
| | | | | | | | |InplaceDimShuffle{x,0} [id N] | | | | | | | | |InplaceDimShuffle{x,0} [id N]
| | | | | | | | |Elemwise{second,no_inplace} [id O] | | | | | | | | |Elemwise{second,no_inplace} [id O]
| | | | | | | | |A [id P] | | | | | | | | |A [id P]
...@@ -496,9 +496,9 @@ def test_debugprint_mitmot(): ...@@ -496,9 +496,9 @@ def test_debugprint_mitmot():
| | | | | | | |ScalarConstant{0} [id S] | | | | | | | |ScalarConstant{0} [id S]
| | | | | | |Subtensor{int64} [id T] | | | | | | |Subtensor{int64} [id T]
| | | | | | |Shape [id U] | | | | | | |Shape [id U]
| | | | | | | |Rebroadcast{(0, False)} [id M] | | | | | | | |Unbroadcast{0} [id M]
| | | | | | |ScalarConstant{1} [id V] | | | | | | |ScalarConstant{1} [id V]
| | | | | |Rebroadcast{(0, False)} [id M] | | | | | |Unbroadcast{0} [id M]
| | | | | |ScalarFromTensor [id W] | | | | | |ScalarFromTensor [id W]
| | | | | |Subtensor{int64} [id K] | | | | | |Subtensor{int64} [id K]
| | | | |A [id P] (outer_in_non_seqs-0) | | | | |A [id P] (outer_in_non_seqs-0)
......
...@@ -34,7 +34,6 @@ from aesara.tensor.basic import ( ...@@ -34,7 +34,6 @@ from aesara.tensor.basic import (
Join, Join,
MakeVector, MakeVector,
PermuteRowElements, PermuteRowElements,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
...@@ -86,7 +85,6 @@ from aesara.tensor.basic import ( ...@@ -86,7 +85,6 @@ from aesara.tensor.basic import (
triu, triu,
triu_indices, triu_indices,
triu_indices_from, triu_indices_from,
unbroadcast,
vertical_stack, vertical_stack,
zeros_like, zeros_like,
) )
...@@ -104,7 +102,6 @@ from aesara.tensor.type import ( ...@@ -104,7 +102,6 @@ from aesara.tensor.type import (
dscalar, dscalar,
dscalars, dscalars,
dtensor3, dtensor3,
dtensor4,
dvector, dvector,
fmatrix, fmatrix,
fscalar, fscalar,
...@@ -337,7 +334,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester( ...@@ -337,7 +334,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
) )
# Partial un broadcast of a dimshuffled input # Partial unbroadcast of a dimshuffled input
TestAllocDimshuffleGradBroadcast = makeBroadcastTester( TestAllocDimshuffleGradBroadcast = makeBroadcastTester(
name="Allocb4GradTester", name="Allocb4GradTester",
op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3), op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3),
...@@ -3223,80 +3220,6 @@ class TestLongTensor: ...@@ -3223,80 +3220,6 @@ class TestLongTensor:
constant()[[val, val]] constant()[[val, val]]
class TestBroadcast:
def test_unbroadcast(self):
# test that the unbroadcast fct don't insert not needed broadcast
# and fuse consecutive Rebroadcast op
x = matrix()
assert unbroadcast(x, 0) is x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
# The first broadcast is remove the broadcast, so the second
# should not make one
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
# Test that consecutive Rebroadcast op are fused
x = TensorType(dtype="float64", shape=(True, True))()
assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
def test_infer_shape(self):
x = matrix()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, MakeVector)
x = row()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
class TestRebroadcast(utt.InferShapeTester):
def test_rebroadcast(self):
rng = np.random.default_rng(3453)
# Rebroadcast
adtens4 = dtensor4()
adict = [(0, False), (1, True), (2, False), (3, True)]
adtens4_val = rng.random((2, 1, 3, 1)).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Rebroadcast(*adict)(adtens4)],
[adtens4_val],
Rebroadcast,
warn=False,
)
adtens4_bro = TensorType("float64", (True, True, True, False))()
bdict = [(0, True), (1, False), (2, False), (3, False)]
adtens4_bro_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check(
[adtens4_bro],
[Rebroadcast(*bdict)(adtens4_bro)],
[adtens4_bro_val],
Rebroadcast,
)
def test_len(): def test_len():
for shape_ in [(5,), (3, 4), (7, 4, 6)]: for shape_ in [(5,), (3, 4), (7, 4, 6)]:
x = tensor(dtype="floatX", shape=(False,) * len(shape_)) x = tensor(dtype="floatX", shape=(False,) * len(shape_))
......
...@@ -28,7 +28,6 @@ from aesara.tensor.basic import ( ...@@ -28,7 +28,6 @@ from aesara.tensor.basic import (
Alloc, Alloc,
Join, Join,
MakeVector, MakeVector,
Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
...@@ -40,7 +39,6 @@ from aesara.tensor.basic import ( ...@@ -40,7 +39,6 @@ from aesara.tensor.basic import (
) )
from aesara.tensor.basic_opt import ( from aesara.tensor.basic_opt import (
ShapeFeature, ShapeFeature,
apply_rebroadcast_opt,
assert_op, assert_op,
local_alloc_sink_dimshuffle, local_alloc_sink_dimshuffle,
local_dimshuffle_lift, local_dimshuffle_lift,
...@@ -92,9 +90,11 @@ from aesara.tensor.shape import ( ...@@ -92,9 +90,11 @@ from aesara.tensor.shape import (
Reshape, Reshape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast,
reshape, reshape,
shape, shape,
specify_shape, specify_shape,
unbroadcast,
) )
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -1898,18 +1898,46 @@ class TestTile: ...@@ -1898,18 +1898,46 @@ class TestTile:
f(data) f(data)
class TestRebroadcast: class TestUnbroadcast:
def test_local_useless_rebroadcast(self): def setup_method(self):
mode = get_default_mode().including("canonicalize") self.mode = get_default_mode().including("canonicalize")
v1 = vector()
v2 = vector()
j = at.join(0, v1, v2)
f = function([v1, v2], j, mode=mode)
f([1, 2], [3, 4, 5])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Rebroadcast)]) == 0
assert check_stack_trace(f, ops_to_check="all") def test_local_useless_unbroadcast(self):
x1 = tensor("float64", shape=(1, 2))
x2 = tensor("float64", shape=(2, 1))
unbroadcast_op = Unbroadcast(0)
f = function([x1], unbroadcast_op(x1), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
f = function([x2], unbroadcast_op(x2), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 0
)
def test_local_unbroadcast_lift(self):
x = tensor("float64", shape=(1, 1))
y = unbroadcast(at.exp(unbroadcast(x, 0)), 1)
assert (
sum(
isinstance(node.op, Unbroadcast)
for node in FunctionGraph([x], [y], copy_inputs=False).toposort()
)
== 2
)
f = function([x], y, mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
np.testing.assert_almost_equal(f([[1]]), np.exp([[1]]))
class TestUselessElemwise: class TestUselessElemwise:
...@@ -3167,21 +3195,6 @@ def test_local_useless_alloc(): ...@@ -3167,21 +3195,6 @@ def test_local_useless_alloc():
assert isinstance(topo[-1].op, Alloc) assert isinstance(topo[-1].op, Alloc)
def test_apply_rebroadcast_opt():
# Test the `Elemwise` case in `local_rebroadcast_lift` with `fgraph=None`.
# This is called by in `apply_rebroadcast_opt`.
a = vector(dtype="float32")
b = tensor("float64", [True])
x = b.astype(a.dtype)
broadcastable = (False,)
axis = [(i, broadcastable[i]) for i in range(len(broadcastable))]
rval = Rebroadcast(*axis)(x)
res = apply_rebroadcast_opt(rval)
assert res is rval
@pytest.mark.parametrize("return_index", [False]) @pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False]) @pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False]) @pytest.mark.parametrize("return_inverse", [False])
......
...@@ -17,12 +17,14 @@ from aesara.tensor.shape import ( ...@@ -17,12 +17,14 @@ from aesara.tensor.shape import (
Reshape, Reshape,
Shape_i, Shape_i,
SpecifyShape, SpecifyShape,
Unbroadcast,
_specify_shape, _specify_shape,
reshape, reshape,
shape, shape,
shape_i, shape_i,
specify_broadcastable, specify_broadcastable,
specify_shape, specify_shape,
unbroadcast,
) )
from aesara.tensor.subtensor import Subtensor from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -36,6 +38,7 @@ from aesara.tensor.type import ( ...@@ -36,6 +38,7 @@ from aesara.tensor.type import (
lscalar, lscalar,
matrix, matrix,
scalar, scalar,
tensor,
tensor3, tensor3,
vector, vector,
) )
...@@ -594,3 +597,63 @@ def test_get_vector_length(): ...@@ -594,3 +597,63 @@ def test_get_vector_length():
# Test `SpecifyShape` # Test `SpecifyShape`
x = specify_shape(ivector(), (10,)) x = specify_shape(ivector(), (10,))
assert get_vector_length(x) == 10 assert get_vector_length(x) == 10
class TestUnbroadcast:
def test_basic(self):
x = matrix()
assert unbroadcast(x, 0) is x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
def test_infer_shape(self):
x = matrix()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, MakeVector)
x = row()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
def test_error_checks(self):
with pytest.raises(TypeError, match="needs integer axes"):
Unbroadcast(0.0)
with pytest.raises(ValueError, match="^Trying to unbroadcast"):
Unbroadcast(1)(vector())
class TestUnbroadcastInferShape(utt.InferShapeTester):
def test_basic(self):
rng = np.random.default_rng(3453)
adtens4 = tensor("float64", shape=(1, 1, 1, None))
adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Unbroadcast(0, 2)(adtens4)],
[adtens4_val],
Unbroadcast,
warn=False,
)
...@@ -16,16 +16,10 @@ from aesara.graph.optdb import OptimizationQuery ...@@ -16,16 +16,10 @@ from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import ( from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
Alloc,
MakeVector,
Rebroadcast,
_convert_to_int8,
make_vector,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Dot, add, dot, exp, sqr from aesara.tensor.math import Dot, add, dot, exp, sqr
from aesara.tensor.shape import SpecifyShape, _shape, shape, specify_shape from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -843,61 +837,61 @@ class TestLocalSubtensorLift: ...@@ -843,61 +837,61 @@ class TestLocalSubtensorLift:
f([1, 2, 3], 4) # let debugmode test something f([1, 2, 3], 4) # let debugmode test something
def test_basic_8(self): def test_basic_8(self):
# Test that Subtensor(Rebroadcast(x)) gets optimized into # Test that Subtensor(Unbroadcast(x)) gets optimized into
# Rebroadcast(Subtensor(x)). # Unbroadcast(Subtensor(x)).
# test basic case # test basic case
x = matrix("x") x = row("x")
xval = np.random.random((1, 10)).astype(config.floatX) xval = np.random.random((1, 10)).astype(config.floatX)
assert x.broadcastable == (False, False) assert x.broadcastable == (True, False)
newx = Rebroadcast((0, True), (1, False))(x) newx = Unbroadcast(0)(x)
assert newx.broadcastable == (True, False) assert newx.broadcastable == (False, False)
f1 = function([x], newx[:2, :5], mode=mode_opt) f1 = function([x], newx[:2, :5], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check=[Subtensor, Rebroadcast]) assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast])
prog = f1.maker.fgraph.toposort() prog = f1.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor) assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast) assert isinstance(prog[1].op, Unbroadcast)
assert (f1(xval) == xval[:2, :5]).all() assert (f1(xval) == xval[:2, :5]).all()
# corner case 1: rebroadcast changes dims which are dropped through subtensor # corner case 1: Unbroadcast changes dims which are dropped through subtensor
y = tensor4("x") y = tensor("float64", shape=(1, 10, 1, 3), name="x")
yval = np.random.random((1, 10, 1, 3)).astype(config.floatX) yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
assert y.broadcastable == (False, False, False, False) assert y.broadcastable == (True, False, True, False)
newy = Rebroadcast((0, True), (2, True))(y) newy = Unbroadcast(0, 2)(y)
assert newy.broadcastable == (True, False, True, False) assert newy.broadcastable == (False, False, False, False)
f2 = function([y], newy[:, 3, 0, :], mode=mode_opt) f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check=[Subtensor, Rebroadcast]) assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast])
prog = f2.maker.fgraph.toposort() prog = f2.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor) assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast) assert isinstance(prog[1].op, Unbroadcast)
assert (f2(yval) == yval[:, 3, 0, :]).all() assert (f2(yval) == yval[:, 3, 0, :]).all()
# corner case 2: subtensor idx_list is shorter than resulting broadcast pattern # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
f3 = function([y], newy[:, 3, 0], mode=mode_opt) f3 = function([y], newy[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f3, ops_to_check=[Subtensor, Rebroadcast]) assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast])
prog = f3.maker.fgraph.toposort() prog = f3.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor) assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast) assert isinstance(prog[1].op, Unbroadcast)
assert (f3(yval) == yval[:, 3, 0]).all() assert (f3(yval) == yval[:, 3, 0]).all()
# corner case 3: subtensor idx_list is shorter than rebroadcast.axis # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
z = tensor4("x") z = tensor("float64", shape=(4, 10, 3, 1), name="x")
zval = np.random.random((4, 10, 3, 1)).astype(config.floatX) zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
assert z.broadcastable == (False, False, False, False) assert z.broadcastable == (False, False, False, True)
newz = Rebroadcast((3, True))(z) newz = Unbroadcast(3)(z)
assert newz.broadcastable == (False, False, False, True) assert newz.broadcastable == (False, False, False, False)
f4 = function([z], newz[:, 3, 0], mode=mode_opt) f4 = function([z], newz[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied # Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f4, ops_to_check=[Subtensor, Rebroadcast]) assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast])
prog = f4.maker.fgraph.toposort() prog = f4.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor) assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast) assert isinstance(prog[1].op, Unbroadcast)
assert (f4(zval) == zval[:, 3, 0]).all() assert (f4(zval) == zval[:, 3, 0]).all()
......
...@@ -26,6 +26,7 @@ from aesara.graph.op import Op ...@@ -26,6 +26,7 @@ from aesara.graph.op import Op
from aesara.tensor.math import argmax, dot from aesara.tensor.math import argmax, dot
from aesara.tensor.math import max as at_max from aesara.tensor.math import max as at_max
from aesara.tensor.nnet import conv, conv2d from aesara.tensor.nnet import conv, conv2d
from aesara.tensor.shape import unbroadcast
from aesara.tensor.signal.pool import Pool from aesara.tensor.signal.pool import Pool
from aesara.tensor.type import TensorType, matrix, vector from aesara.tensor.type import TensorType, matrix, vector
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -237,11 +238,11 @@ class TestRopLop(RopLopChecker): ...@@ -237,11 +238,11 @@ class TestRopLop(RopLopChecker):
# vector # vector
self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,)) self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,))
def test_rebroadcast(self): def test_unbroadcast(self):
# I need the sum, because the setup expects the output to be a # I need the sum, because the setup expects the output to be a
# vector # vector
self.check_rop_lop( self.check_rop_lop(
at.unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,) unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,)
) )
@pytest.mark.slow @pytest.mark.slow
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论