Commit 7f8af9bc authored by Ricardo, committed by Brandon T. Willard

Deprecate remaining uses of Rebroadcast in favor of Unbroadcast

Parent ac52d689
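For orientation, a minimal usage sketch of the replacement API (not code from this commit; it assumes the post-commit `aesara.tensor.shape` layout shown in the hunks below). `Unbroadcast` is a runtime view, so only the variable's static type changes:

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.tensor.shape import unbroadcast

x = at.row("x")             # static shape (1, None)
y = unbroadcast(x, 0)       # static shape (None, None): axis 0 masked

f = aesara.function([x], y)
out = f(np.ones((1, 5), dtype=aesara.config.floatX))
print(out.shape)            # (1, 5); the runtime value passes through
```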
......@@ -147,7 +147,7 @@ from aesara.updates import OrderedUpdates
def get_scalar_constant_value(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs, rebroadcasts, cast
If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
this function digs through them.
If ``aesara.sparse`` is also there, we will look over CSM `Op`.
......
......@@ -204,8 +204,8 @@ def rebuild_collect_shared(
err_sug = (
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to remove broadcastable dimensions."
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)
raise TypeError(err_msg, err_sug)
......
......@@ -23,8 +23,7 @@ from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
from aesara.graph.op import _NoPythonOp
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.tensor import basic
from aesara.tensor.shape import Reshape, Shape, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
__docformat__ = "restructedtext en"
......@@ -451,7 +450,7 @@ acceptable_ops = (
Shape,
SpecifyShape,
Reshape,
basic.Rebroadcast,
Unbroadcast,
at.math.Dot,
at.math.MaxAndArgmax,
at.subtensor.Subtensor,
......
......@@ -29,7 +29,6 @@ from aesara.tensor.basic import (
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
......@@ -50,7 +49,7 @@ from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -347,20 +346,12 @@ def jax_funcify_SpecifyShape(op, **kwargs):
return specifyshape
@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op, **kwargs):
op_axis = op.axis
def rebroadcast(x):
for axis, value in op_axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
"Dimension %s in Rebroadcast's input was"
" supposed to be 1 (got %s instead)" % (axis, x.shape[axis])
)
@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return rebroadcast
return unbroadcast
@jax_funcify.register(ViewOp)
......
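Note the semantic shift in this backend: the old `jax_funcify_Rebroadcast` validated runtime dimensions and could raise, while `Unbroadcast` asserts nothing about values, so its JAX lowering is a pure identity. A hedged usage sketch (assuming a working `jax` install and Aesara's registered `"JAX"` mode):

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.tensor.shape import unbroadcast

x = at.row("x")
f = aesara.function([x], unbroadcast(x, 0), mode="JAX")
res = f(np.ones((1, 4), dtype=aesara.config.floatX))
print(np.asarray(res).shape)  # (1, 4): identity at runtime
```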
......@@ -14,10 +14,10 @@ from aesara.tensor.basic import (
Eye,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
from aesara.tensor.shape import Unbroadcast
@numba_funcify.register(AllocEmpty)
......@@ -195,22 +195,13 @@ def makevector({", ".join(input_names)}):
return numba_basic.numba_njit(makevector_fn)
@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
# Make sure op_axis only has ints. This way we can avoid literal_unroll
# which causes a segfault, see GH issue https://github.com/numba/numba/issues/8215
op_axis = tuple((axis, int(value)) for axis, value in op.axis.items())
@numba_funcify.register(Unbroadcast)
def numba_funcify_Unbroadcast(op, **kwargs):
@numba_basic.numba_njit
def rebroadcast(x):
for axis, value in op_axis:
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
)
def unbroadcast(x):
return x
return rebroadcast
return unbroadcast
@numba_funcify.register(TensorFromScalar)
......
......@@ -14,7 +14,7 @@ from aesara.scan.utils import expand_empty, safe_new, until
from aesara.tensor.basic import get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import minimum
from aesara.tensor.shape import shape_padleft
from aesara.tensor.shape import shape_padleft, unbroadcast
from aesara.tensor.type import TensorType, integer_dtypes
from aesara.updates import OrderedUpdates
......@@ -751,7 +751,7 @@ def scan(
# defined in scan utils
sit_sot_scan_inputs.append(
expand_empty(
at.unbroadcast(shape_padleft(actual_arg), 0),
unbroadcast(shape_padleft(actual_arg), 0),
actual_n_steps,
)
)
......@@ -881,7 +881,7 @@ def scan(
# this will represent only a slice and it will have one
# dimension less.
if isinstance(inner_out.type, TensorType) and return_steps.get(pos, 0) != 1:
outputs[pos] = at.unbroadcast(shape_padleft(inner_out), 0)
outputs[pos] = unbroadcast(shape_padleft(inner_out), 0)
if not return_list and len(outputs) == 1:
outputs = outputs[0]
......@@ -1010,7 +1010,7 @@ def scan(
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
expand_empty(
at.unbroadcast(shape_padleft(input.variable), 0),
unbroadcast(shape_padleft(input.variable), 0),
actual_n_steps,
)
)
......
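All three scan call sites above use the same idiom: `shape_padleft` prepends a length-1 leading (time) axis, and `unbroadcast(..., 0)` masks it so the output buffer created by `expand_empty` can grow along that axis. The type-level effect, as a small sketch:

```python
import aesara.tensor as at
from aesara.tensor.shape import shape_padleft, unbroadcast

v = at.vector("v")              # static shape (None,)
padded = shape_padleft(v)       # static shape (1, None): new time axis
seq = unbroadcast(padded, 0)    # static shape (None, None)

assert padded.type.shape == (1, None)
assert seq.type.shape == (None, None)
```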
......@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Tuple, Union
from typing import cast as type_cast
import numpy as np
......@@ -44,6 +44,7 @@ from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.shape import (
Shape,
Shape_i,
Unbroadcast,
shape,
shape_padaxis,
shape_padleft,
......@@ -254,7 +255,7 @@ def get_scalar_constant_value(
):
"""Return the constant scalar(0-D) value underlying variable `v`.
If `v` is the output of dimshuffles, fills, allocs, rebroadcasts,
If `v` is the output of dimshuffles, fills, allocs,
cast, OutputGuard, DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise
and some pattern with Subtensor, this function digs through them.
......@@ -323,7 +324,7 @@ def get_scalar_constant_value(
(
Alloc,
DimShuffle,
Rebroadcast,
Unbroadcast,
# outputguard is only used in debugmode but we
# keep it here to avoid problems with old pickels.
compile.ops.OutputGuard,
......@@ -495,7 +496,7 @@ def get_scalar_constant_value(
gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim
if grandparent.owner and isinstance(
grandparent.owner.op, Rebroadcast
grandparent.owner.op, Unbroadcast
):
ggp_broadcastable = grandparent.owner.inputs[0].broadcastable
l = [
......@@ -616,185 +617,6 @@ class ScalarFromTensor(COp):
scalar_from_tensor = ScalarFromTensor()
class Rebroadcast(COp):
"""
Change the input's broadcastable fields in some predetermined way.
See Also
--------
unbroadcast <aesara.tensor.unbroadcast>
Notes
-----
Works inplace and works for CudaNdarrayType.
Examples
--------
``Rebroadcast((0, True), (1, False))(x)`` would make `x` broadcastable in
axis 0 and not broadcastable in axis 1.
"""
view_map = {0: [0]}
_f16_ok = True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version: Dict = {}
check_input = False
__props__ = ("axis",)
_f16_ok = True
def __init__(self, *axis):
# Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = dict(items)
for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")
if not isinstance(broad, (np.bool_, bool)):
raise TypeError(
f"Rebroadcast needs bool for new broadcast pattern. Got {broad}"
)
def __hash__(self):
# Need special __hash__ as dict aren't hashable.
# no ambiguity because each item key is unique
items = sorted(self.axis.items())
return hash((type(self), tuple(items)))
def __str__(self):
return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}"
def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
raise ValueError("Trying to rebroadcast non-existent dimension")
t = x.type.clone(
shape=[self.axis.get(i, b) for i, b in enumerate(x.type.broadcastable)]
)
return Apply(self, [x], [t()])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
for axis, value in self.axis.items():
if value and x.shape[axis] != 1:
raise ValueError(
f"Dimension {axis} in Rebroadcast's input was"
f" supposed to be 1 (got {x.shape[axis]} instead)"
)
out[0] = x
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
# restore the broadcasting pattern of the input
return (
Rebroadcast(
*[
(axis, x.type.broadcastable[axis])
for axis, value in self.axis.items()
]
)(gz),
)
def infer_shape(self, fgraph, node, ishapes):
assert len(ishapes) == 1
l = []
one = aesara.tensor.basic.constant(1)
for ax in range(len(ishapes[0])):
if self.axis.get(ax, False):
l.append(one)
else:
l.append(ishapes[0][ax])
return [tuple(l)]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, return_list=True)
def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
fail = sub["fail"]
itype = node.inputs[0].type.__class__
if itype in self.c_code_and_version:
code, version = self.c_code_and_version[itype]
final_code = ""
for axis, value in self.axis.items():
if value:
final_code += code % locals()
return (
final_code
+ f"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
)
raise NotImplementedError()
def c_code_cache_version(self):
version = []
# If any of the c code is unversioned, we have to return ()
# Else, we will return a list of (type name, version) pairs.
for t, (c, v) in sorted(
self.c_code_and_version.items(), key=lambda pair: str(pair[0])
):
if not v:
warnings.warn(
f"Type {t} has C code for Rebroadcast, but it "
"has no version. You should add a 'version' "
"keyword arg when calling "
"register_rebroadcast_c_code.",
stacklevel=2,
)
return ()
version.append((str(t), v))
if version:
version.append(1)
return tuple(version)
def register_rebroadcast_c_code(typ, code, version=()):
"""
Tell Rebroadcast how to generate C code for an Aesara Type.
typ : Aesara type
It must be the Aesara class itself and not an instance of the class.
code : C code
That checks if the dimension %(axis)s is of shape 1 for the Aesara type
'typ'. Use %(iname)s and %(oname)s for the input and output C variable
names respectively, and %(axis)s for the axis that we need to check.
This code is put in a loop for all axes.
version
A number indicating the version of the code, for cache.
"""
Rebroadcast.c_code_and_version[typ] = (code, version)
register_rebroadcast_c_code(
TensorType,
"""
if(PyArray_DIMS(%(iname)s)[%(axis)s] != 1){
PyErr_Format(PyExc_ValueError,
"Dimension %(axis)s in Rebroadcast's input was"
" supposed to be 1 (got %%d instead)",
PyArray_DIMS(%(iname)s)[%(axis)s]);
%(fail)s
}
""",
version=1,
)
# to be removed as we get the epydoc routine-documenting thing going
# -JB 20080924
def _conversion(real_value: Op, name: str) -> Op:
......@@ -2254,36 +2076,6 @@ class Split(COp):
)
def unbroadcast(x, *axes):
"""
Make the input impossible to broadcast in the specified axes.
For example, unbroadcast(x, 0) will make the first dimension
of x not broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
Parameters
----------
x : tensor_like
Input aesara tensor.
axis : an int or an iterable object such as list or tuple of int values
The dimension along which the tensor x should be unbroadcastable.
If the length of x along these dimensions is not 1, a ValueError will
be raised.
Returns
-------
tensor
A aesara tensor, which is unbroadcastable along the specified dimensions.
"""
x = as_tensor_variable(x)
rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
class Join(COp):
r"""
Concatenate multiple `TensorVariable`\s along some axis.
......@@ -4195,7 +3987,6 @@ __all__ = [
"stack",
"roll",
"join",
"unbroadcast",
"split",
"transpose",
"extract_constant",
......
......@@ -48,7 +48,6 @@ from aesara.tensor.basic import (
AllocEmpty,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
Split,
TensorFromScalar,
......@@ -77,9 +76,11 @@ from aesara.tensor.shape import (
Shape,
Shape_i,
SpecifyShape,
Unbroadcast,
shape_i,
shape_padleft,
specify_shape,
unbroadcast,
)
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
......@@ -2226,10 +2227,13 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
@register_useless
@register_canonicalize
@register_specialize
@local_optimizer([Rebroadcast])
def local_useless_rebroadcast(fgraph, node):
"""Remove `Rebroadcast` if it does not actually change the broadcasting pattern."""
if isinstance(node.op, Rebroadcast):
@local_optimizer([Unbroadcast])
def local_useless_unbroadcast(fgraph, node):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern.
TODO: Implement equivalent rewrite for SpecifyShape
"""
if isinstance(node.op, Unbroadcast):
x = node.inputs[0]
if x.broadcastable == node.outputs[0].broadcastable:
# No broadcastable flag was modified
......@@ -2238,15 +2242,12 @@ def local_useless_rebroadcast(fgraph, node):
return [x]
else:
# Keep the flags that modify something
new_axis = {}
for dim, bc in node.op.axis.items():
if x.broadcastable[dim] != bc:
new_axis[dim] = bc
if new_axis == node.op.axis:
new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
if new_axes == node.op.axes:
# All flags are useful
return
return None
else:
r = Rebroadcast(*new_axis.items())(x)
r = unbroadcast(x, *new_axes)
# Copy over stacktrace from previous output
copy_stack_trace(node.outputs, r)
return [r]
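Mirroring the tests added later in this diff: when the requested axis is not statically broadcastable, the `Unbroadcast` changes nothing and this rewrite removes it entirely:

```python
import aesara
import aesara.tensor as at
from aesara.compile.mode import get_default_mode
from aesara.tensor.shape import Unbroadcast

mode = get_default_mode().including("canonicalize")

x = at.tensor("float64", shape=(2, 1))   # axis 0 is not broadcastable
f = aesara.function([x], Unbroadcast(0)(x), mode=mode)
# The useless Unbroadcast is gone from the compiled graph
assert not any(
    isinstance(n.op, Unbroadcast) for n in f.maker.fgraph.toposort()
)
```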
......@@ -2254,93 +2255,49 @@ def local_useless_rebroadcast(fgraph, node):
@register_canonicalize
@register_specialize
@local_optimizer([Rebroadcast])
def local_rebroadcast_lift(fgraph, node):
@local_optimizer([Unbroadcast])
def local_unbroadcast_lift(fgraph, node):
"""
Lifts Rebroadcast through unary Elemwise operations,
and merges consecutive Rebroadcasts.
Lifts `Unbroadcast` through unary Elemwise operations,
and merges consecutive `Unbroadcast`s.
Rebroadcast(Elemwise(x)) => Elemwise(Rebroadcast(x))
Rebroadcast(Rebroadcast(x)) => Rebroadcast(x)
Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
TODO: Implement equivalent Elemwise lift for SpecifyShape
"""
op = node.op
if not isinstance(op, Rebroadcast):
if not isinstance(op, Unbroadcast):
return False
inp = node.inputs[0]
inode = inp.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
# It may happen that `input` has no client because this optimization
# is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function
# compilation phase.
if len(fgraph.clients.get(inp, ())) == 1:
rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0])
# Copy over stacktrace from previous output (after rebroadcasting)
# to new output, because an error in the new graph right after
# rebroadcasting must have been caused by the previous rebroadcasting.
copy_stack_trace(node.outputs, rebroadcasted)
unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
copy_stack_trace(node.outputs, unbroadcasted)
rval = inode.op.make_node(rebroadcasted).outputs
rval = inode.op.make_node(unbroadcasted).outputs
# Copy over stacktrace from previous output (after rebroadcasting)
# Copy over stacktrace from previous output (after unbroadcasting)
# and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the
# two ops.
copy_stack_trace(node.outputs + node.inputs, rval)
return rval
if inode and isinstance(inode.op, Rebroadcast):
# the "axis" specification in the outer Rebroadcast overrides
# the axis of the inner one
axis = inode.op.axis.copy()
axis.update(op.axis)
iinput = inode.inputs[0]
rval = [Rebroadcast(*list(axis.items()))(iinput)]
# Copy over stacktrace from previous output (after second rebroadcast)
# and from previous input (after first rebroadcast op) because an error in
# the new graph could have been caused by either of the two
# rebroadcast ops.
if inode and isinstance(inode.op, Unbroadcast):
# Merge axis of each unbroadcast
axis = tuple(set(inode.op.axes).union(set(op.axes)))
iinput = inode.inputs[0]
rval = [unbroadcast(iinput, *axis)]
# Copy over stacktrace from previous output (after second unbroadcasting)
# and from previous input (after first unbroadcasting) because an error in
# the new graph could have been caused by either of the two Unbroadcast ops.
copy_stack_trace(node.outputs + node.inputs, rval)
return rval
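Mirroring `test_local_unbroadcast_lift` added below: nested `Unbroadcast`s around a unary `Elemwise` are lifted through it and merged into a single op:

```python
import numpy as np
import aesara
import aesara.tensor as at
from aesara.compile.mode import get_default_mode
from aesara.tensor.shape import Unbroadcast, unbroadcast

x = at.tensor("float64", shape=(1, 1))
y = unbroadcast(at.exp(unbroadcast(x, 0)), 1)   # two Unbroadcast ops

f = aesara.function([x], y, mode=get_default_mode().including("canonicalize"))
# After lifting and merging, a single Unbroadcast{0,1} remains
assert (
    sum(isinstance(n.op, Unbroadcast) for n in f.maker.fgraph.toposort()) == 1
)
np.testing.assert_almost_equal(f([[1.0]]), np.exp([[1.0]]))
```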
def apply_rebroadcast_opt(rval):
"""
Apply as many times as required the optimization local_useless_rebroadcast
and local_rebroadcast_lift.
Parameters
----------
rval: a Variable
Returns
-------
A Variable (the same if no optimization can be applied)
"""
fg = FunctionGraph([], [])
changed = True
while changed and rval.owner:
changed = False
rval2 = local_useless_rebroadcast.transform(fg, rval.owner)
if rval2:
assert len(rval2) == 1
rval = rval2[0]
changed = True
if rval.owner:
rval2 = local_rebroadcast_lift.transform(fg, rval.owner)
if rval2:
assert len(rval2) == 1
rval = rval2[0]
changed = True
return rval
@register_specialize
@register_canonicalize
@register_useless
......
......@@ -926,3 +926,108 @@ def specify_broadcastable(x, *axes):
shape_info = [1 if i in axes else None for i in range(len(x.type.shape))]
return specify_shape(x, shape_info)
class Unbroadcast(COp):
"""
Mask static broadcastable dimensions of input as `None`
See Also
--------
unbroadcast <aesara.tensor.shape.unbroadcast>
Examples
--------
``Unbroadcast(1)(x)`` would make `x`'s second static dimension be `None`
"""
view_map = {0: [0]}
_f16_ok = True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version: Dict = {}
check_input = False
__props__ = ("axes",)
_f16_ok = True
def __init__(self, *axis):
# Sort them to make sure we merge all possible cases.
items = tuple(sorted(axis))
self.axes = items
for axis in self.axes:
if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Unbroadcast needs integer axes. Got {axis}")
def __str__(self):
return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}"
def make_node(self, x):
x = as_tensor_variable(x)
if x.type.ndim <= max(self.axes):
raise ValueError("Trying to unbroadcast of non-existent dimension")
shape = [
None if (sh == 1 and i in self.axes) else sh
for i, sh in enumerate(x.type.shape)
]
return Apply(self, [x], [x.type.clone(shape=shape)()])
def perform(self, node, inp, out_):
(x,) = inp
(out,) = out_
out[0] = x
def grad(self, inp, grads):
(x,) = inp
(gz,) = grads
# restore the broadcasting pattern of the input
return [specify_shape(gz, x.type.shape)]
def infer_shape(self, fgraph, node, ishapes):
assert len(ishapes) == 1
return [tuple(ishapes[0])]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self(*eval_points, return_list=True)
def c_code(self, node, nodename, inp, out, sub):
(iname,) = inp
(oname,) = out
return f"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
def c_code_cache_version(self):
return (3,)
def unbroadcast(x, *axes):
"""
Mask static broadcastable dimensions of input as `None`
Parameters
----------
x : tensor_like
Input aesara tensor.
axes : one or more ints
The broadcastable dimensions of x that should be unbroadcasted.
Returns
-------
tensor
An Aesara tensor, with the specified static broadcastable dimensions masked as `None`
"""
x = as_tensor_variable(x)
unbroadcasted_axes = [axis for axis in axes if x.type.shape[axis] == 1]
if not unbroadcasted_axes:
return x
return Unbroadcast(*unbroadcasted_axes)(x)
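Because the helper filters its axes against the static shape up front, it is a graph-level no-op whenever nothing is actually broadcastable, as the new tests below verify:

```python
import aesara.tensor as at
from aesara.tensor.shape import unbroadcast

x = at.matrix()                 # static shape (None, None): nothing to mask
assert unbroadcast(x, 0, 1) is x

r = at.row()                    # static shape (1, None)
y = unbroadcast(r, 0)
assert y is not r
assert y.type.shape == (None, None)
assert unbroadcast(r, 1) is r   # axis 1 is already non-broadcastable
```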
......@@ -14,7 +14,6 @@ from aesara.tensor.basic import (
ARange,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
alloc,
......@@ -50,9 +49,11 @@ from aesara.tensor.math import (
from aesara.tensor.shape import (
Shape,
SpecifyShape,
Unbroadcast,
shape_padleft,
shape_tuple,
specify_shape,
unbroadcast,
)
from aesara.tensor.sharedvar import TensorSharedVariable
from aesara.tensor.subtensor import (
......@@ -370,7 +371,7 @@ def local_subtensor_lift(fgraph, node):
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
rebroadcast(x)[idx] => rebroadcast(x[idx])
Unbroadcast(x)[idx] => Unbroadcast(x[idx])
"""
if isinstance(node.op, Subtensor):
......@@ -429,34 +430,34 @@ def local_subtensor_lift(fgraph, node):
copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
return [ret]
if isinstance(u.owner.op, Rebroadcast):
# make sure that Rebroadcast has only 1 input
assert len(u.owner.inputs) == 1
if isinstance(u.owner.op, Unbroadcast):
# Subtensor might reduce dim., adapt broadcast pattern accordingly
new_axis = []
old_axes = u.owner.op.axes
new_axes = []
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
j = 0
for (i, x) in enumerate(node.op.idx_list):
# if its not a slice, it will reduce the dimension, should
# if it is not a slice, it will reduce the dimension, should
# not appear in the broadcastable dimensions
if isinstance(x, slice):
new_axis += [(j, u.broadcastable[i])]
if i in old_axes:
new_axes.append(j)
j += 1
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
for i in range(len(node.op.idx_list), len(u.broadcastable)):
new_axis += [(j, u.broadcastable[i])]
if i in old_axes:
new_axes.append(j)
j += 1
subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
# Copy over previous output stacktrace
copy_stack_trace(node.outputs[0], subt_x)
rbcast_subt_x = Rebroadcast(*new_axis)(subt_x)
rbcast_subt_x = unbroadcast(subt_x, *new_axes)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
......
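A sketch of this lift in action, adapted from `test_basic_8` below (the exact toposort ordering assumes the default rewrites are active):

```python
import numpy as np
import aesara
from aesara.tensor import row
from aesara.tensor.shape import Unbroadcast
from aesara.tensor.subtensor import Subtensor

x = row("x")                                # static shape (1, None)
f = aesara.function([x], Unbroadcast(0)(x)[:2, :5])

prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)    # indexing now happens first
assert isinstance(prog[1].op, Unbroadcast)

xval = np.random.random((1, 10)).astype(aesara.config.floatX)
assert (f(xval) == xval[:2, :5]).all()
```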
......@@ -39,7 +39,7 @@ from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet.basic import SoftmaxGrad
from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape
from aesara.tensor.type import (
dscalar,
dvector,
......@@ -201,20 +201,11 @@ def test_jax_compile_ops():
compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1))
x = at.Rebroadcast((0, False), (1, True), (2, False))(at.as_tensor_variable(x_np))
x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"):
x = at.Rebroadcast((0, True), (1, False), (2, False))(
at.as_tensor_variable(x_np)
)
x_fg = FunctionGraph([], [x])
with pytest.raises(ValueError):
compare_jax_and_py(x_fg, [])
x = ViewOp()(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
......
......@@ -40,7 +40,7 @@ from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
class MyType(Type):
......@@ -769,39 +769,18 @@ def test_ScalarFromTensor(v):
)
@pytest.mark.parametrize(
"v, axis, fails",
[
(
set_test_value(at.matrix(), np.array([[1.0]], dtype=config.floatX)),
[(0, True), (1, True)],
False,
),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
[(0, True), (1, False)],
False,
),
(
set_test_value(at.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
[(0, True), (1, True)],
True,
),
],
)
def test_Rebroadcast(v, axis, fails):
g = atb.Rebroadcast(*axis)(v)
def test_Unbroadcast():
v = set_test_value(at.row(), np.array([[1.0, 2.0]], dtype=config.floatX))
g = Unbroadcast(0)(v)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(ValueError)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
......
......@@ -36,7 +36,7 @@ def test_debugprint_sitsot():
| | | | | |k [id D]
| | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I]
| | | | | | |Rebroadcast{(0, False)} [id J]
| | | | | | |Unbroadcast{0} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M]
......@@ -45,9 +45,9 @@ def test_debugprint_sitsot():
| | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R]
| | | | | |Rebroadcast{(0, False)} [id J]
| | | | | |Unbroadcast{0} [id J]
| | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{(0, False)} [id J]
| | | |Unbroadcast{0} [id J]
| | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H]
| | |A [id M] (outer_in_non_seqs-0)
......@@ -91,7 +91,7 @@ def test_debugprint_sitsot_no_extra_info():
| | | | | |k [id D]
| | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I]
| | | | | | |Rebroadcast{(0, False)} [id J]
| | | | | | |Unbroadcast{0} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M]
......@@ -100,9 +100,9 @@ def test_debugprint_sitsot_no_extra_info():
| | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R]
| | | | | |Rebroadcast{(0, False)} [id J]
| | | | | |Unbroadcast{0} [id J]
| | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{(0, False)} [id J]
| | | |Unbroadcast{0} [id J]
| | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H]
| | |A [id M]
......@@ -261,7 +261,7 @@ def test_debugprint_nested_scans():
> | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK]
> | | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | | |Unbroadcast{0} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
......@@ -270,9 +270,9 @@ def test_debugprint_nested_scans():
> | | | | | | |ScalarConstant{0} [id BR]
> | | | | | |Subtensor{int64} [id BS]
> | | | | | |Shape [id BT]
> | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | |Unbroadcast{0} [id BL]
> | | | | | |ScalarConstant{1} [id BU]
> | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | |Unbroadcast{0} [id BL]
> | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ]
> | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
......@@ -350,7 +350,7 @@ def test_debugprint_nested_scans():
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BL]
> | | | | | | |Shape [id BM]
> | | | | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | | | | |Unbroadcast{0} [id BN]
> | | | | | | | |InplaceDimShuffle{x,0} [id BO]
> | | | | | | | |Elemwise{second,no_inplace} [id BP]
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0)
......@@ -359,9 +359,9 @@ def test_debugprint_nested_scans():
> | | | | | | |ScalarConstant{0} [id BS]
> | | | | | |Subtensor{int64} [id BT]
> | | | | | |Shape [id BU]
> | | | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | | | |Unbroadcast{0} [id BN]
> | | | | | |ScalarConstant{1} [id BV]
> | | | | |Rebroadcast{(0, False)} [id BN]
> | | | | |Unbroadcast{0} [id BN]
> | | | | |ScalarFromTensor [id BW]
> | | | | |Subtensor{int64} [id BL]
> | | | |*2-<TensorType(float64, (None,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
......@@ -487,7 +487,7 @@ def test_debugprint_mitmot():
| | | | | | | |k [id G]
| | | | | | | |Subtensor{int64} [id K]
| | | | | | | |Shape [id L]
| | | | | | | | |Rebroadcast{(0, False)} [id M]
| | | | | | | | |Unbroadcast{0} [id M]
| | | | | | | | |InplaceDimShuffle{x,0} [id N]
| | | | | | | | |Elemwise{second,no_inplace} [id O]
| | | | | | | | |A [id P]
......@@ -496,9 +496,9 @@ def test_debugprint_mitmot():
| | | | | | | |ScalarConstant{0} [id S]
| | | | | | |Subtensor{int64} [id T]
| | | | | | |Shape [id U]
| | | | | | | |Rebroadcast{(0, False)} [id M]
| | | | | | | |Unbroadcast{0} [id M]
| | | | | | |ScalarConstant{1} [id V]
| | | | | |Rebroadcast{(0, False)} [id M]
| | | | | |Unbroadcast{0} [id M]
| | | | | |ScalarFromTensor [id W]
| | | | | |Subtensor{int64} [id K]
| | | | |A [id P] (outer_in_non_seqs-0)
......
......@@ -34,7 +34,6 @@ from aesara.tensor.basic import (
Join,
MakeVector,
PermuteRowElements,
Rebroadcast,
ScalarFromTensor,
Split,
TensorFromScalar,
......@@ -86,7 +85,6 @@ from aesara.tensor.basic import (
triu,
triu_indices,
triu_indices_from,
unbroadcast,
vertical_stack,
zeros_like,
)
......@@ -104,7 +102,6 @@ from aesara.tensor.type import (
dscalar,
dscalars,
dtensor3,
dtensor4,
dvector,
fmatrix,
fscalar,
......@@ -337,7 +334,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
)
# Partial un broadcast of a dimshuffled input
# Partial unbroadcast of a dimshuffled input
TestAllocDimshuffleGradBroadcast = makeBroadcastTester(
name="Allocb4GradTester",
op=lambda x: alloc(x.dimshuffle("x", "x", 0), 1, s2, s3),
......@@ -3223,80 +3220,6 @@ class TestLongTensor:
constant()[[val, val]]
class TestBroadcast:
def test_unbroadcast(self):
# test that the unbroadcast fct don't insert not needed broadcast
# and fuse consecutive Rebroadcast op
x = matrix()
assert unbroadcast(x, 0) is x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
# The first broadcast is remove the broadcast, so the second
# should not make one
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
# Test that consecutive Rebroadcast op are fused
x = TensorType(dtype="float64", shape=(True, True))()
assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
def test_infer_shape(self):
x = matrix()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, MakeVector)
x = row()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
class TestRebroadcast(utt.InferShapeTester):
def test_rebroadcast(self):
rng = np.random.default_rng(3453)
# Rebroadcast
adtens4 = dtensor4()
adict = [(0, False), (1, True), (2, False), (3, True)]
adtens4_val = rng.random((2, 1, 3, 1)).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Rebroadcast(*adict)(adtens4)],
[adtens4_val],
Rebroadcast,
warn=False,
)
adtens4_bro = TensorType("float64", (True, True, True, False))()
bdict = [(0, True), (1, False), (2, False), (3, False)]
adtens4_bro_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check(
[adtens4_bro],
[Rebroadcast(*bdict)(adtens4_bro)],
[adtens4_bro_val],
Rebroadcast,
)
def test_len():
for shape_ in [(5,), (3, 4), (7, 4, 6)]:
x = tensor(dtype="floatX", shape=(False,) * len(shape_))
......
......@@ -28,7 +28,6 @@ from aesara.tensor.basic import (
Alloc,
Join,
MakeVector,
Rebroadcast,
ScalarFromTensor,
Split,
TensorFromScalar,
......@@ -40,7 +39,6 @@ from aesara.tensor.basic import (
)
from aesara.tensor.basic_opt import (
ShapeFeature,
apply_rebroadcast_opt,
assert_op,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
......@@ -92,9 +90,11 @@ from aesara.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
Unbroadcast,
reshape,
shape,
specify_shape,
unbroadcast,
)
from aesara.tensor.subtensor import (
AdvancedIncSubtensor1,
......@@ -1898,18 +1898,46 @@ class TestTile:
f(data)
class TestRebroadcast:
def test_local_useless_rebroadcast(self):
mode = get_default_mode().including("canonicalize")
v1 = vector()
v2 = vector()
j = at.join(0, v1, v2)
f = function([v1, v2], j, mode=mode)
f([1, 2], [3, 4, 5])
e = f.maker.fgraph.toposort()
assert len([n for n in e if isinstance(n.op, Rebroadcast)]) == 0
class TestUnbroadcast:
def setup_method(self):
self.mode = get_default_mode().including("canonicalize")
def test_local_useless_unbroadcast(self):
x1 = tensor("float64", shape=(1, 2))
x2 = tensor("float64", shape=(2, 1))
unbroadcast_op = Unbroadcast(0)
f = function([x1], unbroadcast_op(x1), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
f = function([x2], unbroadcast_op(x2), mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 0
)
def test_local_unbroadcast_lift(self):
x = tensor("float64", shape=(1, 1))
y = unbroadcast(at.exp(unbroadcast(x, 0)), 1)
assert (
sum(
isinstance(node.op, Unbroadcast)
for node in FunctionGraph([x], [y], copy_inputs=False).toposort()
)
== 2
)
f = function([x], y, mode=self.mode)
assert (
sum(isinstance(node.op, Unbroadcast) for node in f.maker.fgraph.toposort())
== 1
)
assert check_stack_trace(f, ops_to_check="all")
np.testing.assert_almost_equal(f([[1]]), np.exp([[1]]))
class TestUselessElemwise:
......@@ -3167,21 +3195,6 @@ def test_local_useless_alloc():
assert isinstance(topo[-1].op, Alloc)
def test_apply_rebroadcast_opt():
# Test the `Elemwise` case in `local_rebroadcast_lift` with `fgraph=None`.
# This is called by in `apply_rebroadcast_opt`.
a = vector(dtype="float32")
b = tensor("float64", [True])
x = b.astype(a.dtype)
broadcastable = (False,)
axis = [(i, broadcastable[i]) for i in range(len(broadcastable))]
rval = Rebroadcast(*axis)(x)
res = apply_rebroadcast_opt(rval)
assert res is rval
@pytest.mark.parametrize("return_index", [False])
@pytest.mark.parametrize("return_counts", [False])
@pytest.mark.parametrize("return_inverse", [False])
......
......@@ -17,12 +17,14 @@ from aesara.tensor.shape import (
Reshape,
Shape_i,
SpecifyShape,
Unbroadcast,
_specify_shape,
reshape,
shape,
shape_i,
specify_broadcastable,
specify_shape,
unbroadcast,
)
from aesara.tensor.subtensor import Subtensor
from aesara.tensor.type import (
......@@ -36,6 +38,7 @@ from aesara.tensor.type import (
lscalar,
matrix,
scalar,
tensor,
tensor3,
vector,
)
......@@ -594,3 +597,63 @@ def test_get_vector_length():
# Test `SpecifyShape`
x = specify_shape(ivector(), (10,))
assert get_vector_length(x) == 10
class TestUnbroadcast:
def test_basic(self):
x = matrix()
assert unbroadcast(x, 0) is x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
def test_infer_shape(self):
x = matrix()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 3
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, Shape_i)
assert isinstance(topo[2].op, MakeVector)
x = row()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
def test_error_checks(self):
with pytest.raises(TypeError, match="needs integer axes"):
Unbroadcast(0.0)
with pytest.raises(ValueError, match="^Trying to unbroadcast"):
Unbroadcast(1)(vector())
class TestUnbroadcastInferShape(utt.InferShapeTester):
def test_basic(self):
rng = np.random.default_rng(3453)
adtens4 = tensor("float64", shape=(1, 1, 1, None))
adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
self._compile_and_check(
[adtens4],
[Unbroadcast(0, 2)(adtens4)],
[adtens4_val],
Unbroadcast,
warn=False,
)
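One property of the `specify_shape`-based gradient that these tests do not exercise directly: it restores the input's static shape. A hedged sketch (assuming `SpecifyShape` propagates constant shape information into its output type, as elsewhere in this diff):

```python
import aesara
import aesara.tensor as at
from aesara.tensor.shape import unbroadcast

x = at.row("x")          # static shape (1, None)
y = unbroadcast(x, 0)    # static shape (None, None)
g = aesara.grad(y.sum(), x)
# Unbroadcast.grad wraps the incoming gradient in
# specify_shape(gz, x.type.shape), recovering the (1, None) type
assert g.type.shape == (1, None)
```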
......@@ -16,16 +16,10 @@ from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.raise_op import Assert
from aesara.tensor import inplace
from aesara.tensor.basic import (
Alloc,
MakeVector,
Rebroadcast,
_convert_to_int8,
make_vector,
)
from aesara.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.math import Dot, add, dot, exp, sqr
from aesara.tensor.shape import SpecifyShape, _shape, shape, specify_shape
from aesara.tensor.shape import SpecifyShape, Unbroadcast, _shape, shape, specify_shape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -843,61 +837,61 @@ class TestLocalSubtensorLift:
f([1, 2, 3], 4) # let debugmode test something
def test_basic_8(self):
# Test that Subtensor(Rebroadcast(x)) gets optimized into
# Rebroadcast(Subtensor(x)).
# Test that Subtensor(Unbroadcast(x)) gets optimized into
# Unbroadcast(Subtensor(x)).
# test basic case
x = matrix("x")
x = row("x")
xval = np.random.random((1, 10)).astype(config.floatX)
assert x.broadcastable == (False, False)
newx = Rebroadcast((0, True), (1, False))(x)
assert newx.broadcastable == (True, False)
assert x.broadcastable == (True, False)
newx = Unbroadcast(0)(x)
assert newx.broadcastable == (False, False)
f1 = function([x], newx[:2, :5], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f1, ops_to_check=[Subtensor, Rebroadcast])
assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast])
prog = f1.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast)
assert isinstance(prog[1].op, Unbroadcast)
assert (f1(xval) == xval[:2, :5]).all()
# corner case 1: rebroadcast changes dims which are dropped through subtensor
y = tensor4("x")
# corner case 1: Unbroadcast changes dims which are dropped through subtensor
y = tensor("float64", shape=(1, 10, 1, 3), name="x")
yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
assert y.broadcastable == (False, False, False, False)
newy = Rebroadcast((0, True), (2, True))(y)
assert newy.broadcastable == (True, False, True, False)
assert y.broadcastable == (True, False, True, False)
newy = Unbroadcast(0, 2)(y)
assert newy.broadcastable == (False, False, False, False)
f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f2, ops_to_check=[Subtensor, Rebroadcast])
assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast])
prog = f2.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast)
assert isinstance(prog[1].op, Unbroadcast)
assert (f2(yval) == yval[:, 3, 0, :]).all()
# corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
f3 = function([y], newy[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f3, ops_to_check=[Subtensor, Rebroadcast])
assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast])
prog = f3.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast)
assert isinstance(prog[1].op, Unbroadcast)
assert (f3(yval) == yval[:, 3, 0]).all()
# corner case 3: subtensor idx_list is shorter than rebroadcast.axis
z = tensor4("x")
# corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
z = tensor("float64", shape=(4, 10, 3, 1), name="x")
zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
assert z.broadcastable == (False, False, False, False)
newz = Rebroadcast((3, True))(z)
assert newz.broadcastable == (False, False, False, True)
assert z.broadcastable == (False, False, False, True)
newz = Unbroadcast(3)(z)
assert newz.broadcastable == (False, False, False, False)
f4 = function([z], newz[:, 3, 0], mode=mode_opt)
# Check stacktrace was copied over correctly after opt was applied
assert check_stack_trace(f4, ops_to_check=[Subtensor, Rebroadcast])
assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast])
prog = f4.maker.fgraph.toposort()
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Rebroadcast)
assert isinstance(prog[1].op, Unbroadcast)
assert (f4(zval) == zval[:, 3, 0]).all()
......
......@@ -26,6 +26,7 @@ from aesara.graph.op import Op
from aesara.tensor.math import argmax, dot
from aesara.tensor.math import max as at_max
from aesara.tensor.nnet import conv, conv2d
from aesara.tensor.shape import unbroadcast
from aesara.tensor.signal.pool import Pool
from aesara.tensor.type import TensorType, matrix, vector
from tests import unittest_tools as utt
......@@ -237,11 +238,11 @@ class TestRopLop(RopLopChecker):
# vector
self.check_rop_lop(self.x[:4].dimshuffle("x", 0).sum(axis=0), (4,))
def test_rebroadcast(self):
def test_unbroadcast(self):
# I need the sum, because the setup expects the output to be a
# vector
self.check_rop_lop(
at.unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,)
unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,)
)
@pytest.mark.slow
......