Commit 2066065d authored by Ricardo, committed by Brandon T. Willard

Deprecate addbroadcast and patternbroadcast in favor of specify_broadcastable

Parent: eb2b9afb
......@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}."
"{unbroadcast, specify_broadcastable}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
......
......@@ -45,7 +45,7 @@ from aesara.tensor.math import (
tanh,
trunc,
)
from aesara.tensor.shape import shape
from aesara.tensor.shape import shape, specify_broadcastable
from aesara.tensor.type import TensorType
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
......@@ -1136,7 +1136,9 @@ class SparseFromDense(Op):
(x,) = inputs
(gz,) = gout
gx = dense_from_sparse(gz)
gx = at.patternbroadcast(gx, x.broadcastable)
gx = specify_broadcastable(
gx, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
return (gx,)
def infer_shape(self, fgraph, node, shapes):
......@@ -1900,9 +1902,9 @@ class SpSum(Op):
else:
ones = at.ones_like(x)
if self.axis == 0:
r = at.addbroadcast(gz.dimshuffle("x", 0), 0) * ones
r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones
elif self.axis == 1:
r = at.addbroadcast(gz.dimshuffle(0, "x"), 1) * ones
r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones
else:
raise ValueError("Illegal value for self.axis.")
r = SparseFromDense(o_format)(r)
......
......@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import Dict, Iterable, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
from typing import cast as type_cast
import numpy as np
......@@ -49,6 +49,7 @@ from aesara.tensor.shape import (
shape_padleft,
shape_padright,
shape_tuple,
specify_broadcastable,
)
from aesara.tensor.type import (
TensorType,
......@@ -622,8 +623,6 @@ class Rebroadcast(COp):
See Also
--------
unbroadcast <aesara.tensor.unbroadcast>
addbroadcast <aesara.tensor.addbroadcast>
patternbroadcast <aesara.tensor.patternbroadcast>
Notes
-----
......@@ -2255,48 +2254,12 @@ class Split(COp):
)
def addbroadcast(x, *axes):
    """Make the input broadcastable along the specified axes.

    For example, ``addbroadcast(x, 0)`` makes the first dimension of `x`
    broadcastable.  At run time, if the length of `x` along any of these
    dimensions is not 1, a ``ValueError`` is raised.

    The rebroadcast optimization is applied immediately so the graph is
    not polluted with redundant `Rebroadcast` nodes.

    Parameters
    ----------
    x : tensor_like
        Input aesara tensor.
    axes : int
        The dimensions along which the tensor `x` should be broadcastable.
        If the length of `x` along these dimensions is not 1, a
        ``ValueError`` will be raised.

    Returns
    -------
    tensor
        An aesara tensor that is broadcastable along the specified
        dimensions.
    """
    x = as_tensor_variable(x)
    # A fully static shape already fixes the broadcast pattern; just make
    # sure the requested axes agree with it and return the input unchanged.
    if isinstance(x.type, TensorType) and all(s is not None for s in x.type.shape):
        fixed_broadcast_axes = {i for i, b in enumerate(x.broadcastable) if b}
        if not fixed_broadcast_axes.issuperset(axes):
            raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
        return x
    rebroadcasted = Rebroadcast(*((axis, True) for axis in axes))(x)
    # Apply the opt here not to pollute the graph.
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rebroadcasted)
def unbroadcast(x, *axes):
"""
Make the input impossible to broadcast in the specified axes.
For example, addbroadcast(x, 0) will make the first dimension
of x broadcastable. When performing the function, if the length
For example, unbroadcast(x, 0) will make the first dimension
of x not broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph
......@@ -2321,34 +2284,6 @@ def unbroadcast(x, *axes):
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
def patternbroadcast(
    x: TensorVariable, broadcastable: Iterable[Union[bool, int]]
) -> TensorVariable:
    """Make the input adopt a specific broadcasting pattern.

    For example, ``patternbroadcast(x, (True, False))`` will make the first
    dimension of `x` broadcastable and the second dimension not broadcastable,
    so `x` will now be a row.

    Parameters
    ----------
    x
        Input to re-broadcast.
    broadcastable
        Truthy values indicating whether or not a dimension should be
        broadcastable or not. If the length of `x` along these dimensions is
        not ``1``, a `ValueError` will be raised.
    """
    x = as_tensor_variable(x)
    # Materialize the pattern first: the parameter is typed `Iterable`, so a
    # generator is a valid argument, but the code below needs to index it and
    # compare it against the (tuple-valued) `x.broadcastable`.
    broadcastable = tuple(broadcastable)
    if x.broadcastable == broadcastable:
        return x
    rval = Rebroadcast(*((i, b) for i, b in enumerate(broadcastable)))(x)
    # Apply the opt here not to pollute the graph.
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
class Join(COp):
r"""
Concatenate multiple `TensorVariable`\s along some axis.
......@@ -2599,7 +2534,12 @@ class Join(COp):
# broadcast. As the grad need to keep the information,
# read it if needed.
split_gz = [
patternbroadcast(g, t.broadcastable) for t, g in zip(tens, split_gz)
g
if g.type.broadcastable == t.type.broadcastable
else specify_broadcastable(
g, *(ax for (ax, b) in enumerate(t.type.broadcastable) if b)
)
for t, g in zip(tens, split_gz)
]
rval = rval + split_gz
else:
......@@ -2822,7 +2762,7 @@ def stack(*tensors, **kwargs):
raise ValueError("No tensor arguments provided")
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
# It makes the graph simpler, by not adding DimShuffles and SpecifyShapes
# This should be an optimization!
# Doing it here make the graph less canonicalized
......@@ -2979,7 +2919,9 @@ def flatten(x, ndim=1):
bcast_kept_dims = _x.broadcastable[: ndim - 1]
bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :])
broadcastable = bcast_kept_dims + (bcast_new_dim,)
x_reshaped = addbroadcast(x_reshaped, *[i for i in range(ndim) if broadcastable[i]])
x_reshaped = specify_broadcastable(
x_reshaped, *[i for i in range(ndim) if broadcastable[i]]
)
return x_reshaped
......@@ -4253,9 +4195,7 @@ __all__ = [
"stack",
"roll",
"join",
"patternbroadcast",
"unbroadcast",
"addbroadcast",
"split",
"transpose",
"extract_constant",
......
......@@ -165,6 +165,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import (
DenseTensorType,
integer_dtypes,
......@@ -2552,9 +2553,13 @@ class BatchedDot(COp):
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = at.patternbroadcast(xgrad, x.broadcastable)
xgrad = specify_broadcastable(
xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
if ygrad.broadcastable != y.broadcastable:
ygrad = at.patternbroadcast(ygrad, y.broadcastable)
ygrad = specify_broadcastable(
ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
)
return xgrad, ygrad
......
......@@ -21,7 +21,6 @@ from aesara.tensor.basic import (
cast,
concatenate,
constant,
patternbroadcast,
stack,
switch,
)
......@@ -32,7 +31,7 @@ from aesara.tensor.elemwise import (
Elemwise,
scalar_elemwise,
)
from aesara.tensor.shape import shape
from aesara.tensor.shape import shape, specify_broadcastable
from aesara.tensor.type import (
DenseTensorType,
complex_dtypes,
......@@ -1961,9 +1960,13 @@ class Dot(Op):
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = patternbroadcast(xgrad, x.broadcastable)
xgrad = specify_broadcastable(
xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
if ygrad.broadcastable != y.broadcastable:
ygrad = patternbroadcast(ygrad, y.broadcastable)
ygrad = specify_broadcastable(
ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
)
rval = xgrad, ygrad
......@@ -2178,7 +2181,11 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
out = out_reshaped.reshape(outshape, outndim)
# Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes.
return patternbroadcast(out, outbcast)
if out.type.broadcastable != outbcast:
out = specify_broadcastable(
out, *(ax for (ax, b) in enumerate(outbcast) if b)
)
return out
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
......
......@@ -12,6 +12,7 @@ from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import mean, prod, reciprocal, sqrt
from aesara.tensor.math import sum as at_sum
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import TensorType
......@@ -241,8 +242,8 @@ def batch_normalization_train(
gamma = gamma.dimshuffle(params_dimshuffle_pattern)
beta = beta.dimshuffle(params_dimshuffle_pattern)
else:
gamma = at.addbroadcast(gamma, *axes)
beta = at.addbroadcast(beta, *axes)
gamma = specify_broadcastable(gamma, *axes)
beta = specify_broadcastable(beta, *axes)
batchnorm_op = AbstractBatchNormTrain(axes=axes)
......@@ -253,8 +254,8 @@ def batch_normalization_train(
running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
running_var = running_var.dimshuffle(params_dimshuffle_pattern)
else:
running_mean = at.addbroadcast(running_mean, *axes)
running_var = at.addbroadcast(running_var, *axes)
running_mean = specify_broadcastable(running_mean, *axes)
running_var = specify_broadcastable(running_var, *axes)
out, mean, invstd, new_running_mean, new_running_var = batchnorm_op(
inputs,
gamma,
......@@ -265,12 +266,14 @@ def batch_normalization_train(
running_var=running_var,
)
if new_running_mean.broadcastable != running_mean.broadcastable:
new_running_mean = at.patternbroadcast(
new_running_mean, running_mean.broadcastable
new_running_mean = specify_broadcastable(
new_running_mean,
*(ax for (ax, b) in enumerate(running_mean.type.broadcastable) if b),
)
if new_running_var.broadcastable != running_var.broadcastable:
new_running_var = at.patternbroadcast(
new_running_var, running_var.broadcastable
new_running_var = specify_broadcastable(
new_running_var,
*(ax for (ax, b) in enumerate(running_var.type.broadcastable) if b),
)
results = (out, mean, invstd, new_running_mean, new_running_var)
else:
......@@ -331,7 +334,7 @@ def batch_normalization_test(
axes = (0,)
# for spatial normalization
axes = (0,) + tuple(range(2, inputs.ndim))
gamma, beta, mean, var = (at.addbroadcast(t, *axes)
gamma, beta, mean, var = (at.specify_broadcastable(t, *axes)
for t in (gamma, beta, mean, var))
out = (inputs - mean) * gamma / at.sqrt(var + epsilon) + beta
"""
......@@ -377,10 +380,10 @@ def batch_normalization_test(
mean = mean.dimshuffle(params_dimshuffle_pattern)
var = var.dimshuffle(params_dimshuffle_pattern)
else:
gamma = at.addbroadcast(gamma, *axes)
beta = at.addbroadcast(beta, *axes)
mean = at.addbroadcast(mean, *axes)
var = at.addbroadcast(var, *axes)
gamma = specify_broadcastable(gamma, *axes)
beta = specify_broadcastable(beta, *axes)
mean = specify_broadcastable(mean, *axes)
var = specify_broadcastable(var, *axes)
batchnorm_op = AbstractBatchNormInference(axes=axes)
return batchnorm_op(inputs, gamma, beta, mean, var, epsilon=epsilon)
......@@ -609,7 +612,7 @@ class AbstractBatchNormInference(Op):
)
scale, bias, est_mean, est_var = (
at.addbroadcast(t, *axes) for t in (scale, bias, est_mean, est_var)
specify_broadcastable(t, *axes) for t in (scale, bias, est_mean, est_var)
)
# define helper expressions
......
......@@ -26,13 +26,10 @@ import aesara
from aesara.graph.basic import Apply
from aesara.link.c.op import OpenMPOp
from aesara.tensor import blas
from aesara.tensor.basic import (
as_tensor_variable,
get_scalar_constant_value,
patternbroadcast,
)
from aesara.tensor.basic import as_tensor_variable, get_scalar_constant_value
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.nnet.abstract_conv import get_conv_output_shape, get_conv_shape_1axis
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import discrete_dtypes, tensor
......@@ -1103,8 +1100,14 @@ class ConvOp(OpenMPOp):
# din and dw should have the same broadcasting pattern as the
# parameters they are the gradient of (resp. inputs and kerns).
din = patternbroadcast(din, inputs.broadcastable)
dw = patternbroadcast(dw, kerns.broadcastable)
if din.type.broadcastable != inputs.type.broadcastable:
din = specify_broadcastable(
din, *(ax for (ax, b) in enumerate(inputs.type.broadcastable) if b)
)
if dw.type.broadcastable != kerns.type.broadcastable:
dw = specify_broadcastable(
dw, *(ax for (ax, b) in enumerate(kerns.type.broadcastable) if b)
)
return [din, dw]
def c_headers(self, **kwargs):
......
......@@ -12,7 +12,7 @@ from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32
from aesara.tensor import _get_vector_length
from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
......@@ -891,3 +891,38 @@ register_shape_i_c_code(
""",
version=3,
)
def specify_broadcastable(x, *axes):
    """
    Specify the input as being broadcastable in the specified axes.

    For example, specify_broadcastable(x, 0) will make the first dimension of
    x broadcastable. When performing the function, if the length of
    x along that dimension is not 1, a ValueError will be raised.

    Parameters
    ----------
    x : tensor_like
        Input aesara tensor.
    axes : an int or an iterable object such as list or tuple of int values
        The dimension along which the tensor x should be broadcastable.
        If the length of x along these dimensions is not 1, a ValueError will
        be raised. Negative values count from the end, as in indexing.

    Returns
    -------
    tensor
        A aesara tensor, which is broadcastable along the specified dimensions.

    """
    x = as_tensor_variable(x)

    if not axes:
        return x

    ndim = x.type.ndim
    # Validate both bounds: previously only `max(axes) >= ndim` was checked,
    # so a negative axis slipped through and then silently matched nothing in
    # the membership test below, making the call a no-op.
    if any(axis >= ndim or axis < -ndim for axis in axes):
        raise ValueError("Trying to specify broadcastable of non-existent dimension")

    # Normalize negative axes to their non-negative equivalents.
    normalized_axes = {axis % ndim for axis in axes}
    shape_info = [1 if i in normalized_axes else None for i in range(ndim)]
    return specify_shape(x, shape_info)
......@@ -20,7 +20,7 @@ from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
from aesara.tensor.basic import alloc, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import (
AdvancedIndexingError,
......@@ -28,7 +28,7 @@ from aesara.tensor.exceptions import (
ShapeError,
)
from aesara.tensor.math import clip
from aesara.tensor.shape import Reshape
from aesara.tensor.shape import Reshape, specify_broadcastable
from aesara.tensor.type import (
TensorType,
bscalar,
......@@ -1322,8 +1322,8 @@ def inc_subtensor(
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a Rebroadcast Op to make sure it is the case.
y = addbroadcast(y, dim)
# We insert a SpecifyShape Op to make sure it is the case.
y = specify_broadcastable(y, dim)
if not x.owner:
raise TypeError("x must be the result of a subtensor operation")
......
......@@ -8,6 +8,7 @@ import aesara.tensor as at
from aesara.configdefaults import config
from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet import batchnorm
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import (
TensorType,
matrix,
......@@ -219,8 +220,8 @@ def test_batch_normalization_train():
x_mean2 = x.mean(axis=axes2, keepdims=True)
x_var2 = x.var(axis=axes2, keepdims=True)
x_invstd2 = at.reciprocal(at.sqrt(x_var2 + eps))
scale2 = at.addbroadcast(scale, *axes2)
bias2 = at.addbroadcast(bias, *axes2)
scale2 = specify_broadcastable(scale, *axes2)
bias2 = specify_broadcastable(bias, *axes2)
out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
m = at.cast(at.prod(x.shape) / at.prod(scale.shape), aesara.config.floatX)
out_running_mean2 = (
......@@ -597,7 +598,7 @@ def test_batch_normalization_test():
else:
axes2 = axes
scale2, bias2, mean2, var2 = (
at.addbroadcast(t, *axes2) for t in (scale, bias, mean, var)
specify_broadcastable(t, *axes2) for t in (scale, bias, mean, var)
)
out2 = (x - mean2) * (scale2 / at.sqrt(var2 + eps)) + bias2
# backward pass
......
......@@ -39,7 +39,6 @@ from aesara.tensor.basic import (
Split,
TensorFromScalar,
Tri,
addbroadcast,
alloc,
arange,
as_tensor_variable,
......@@ -69,7 +68,6 @@ from aesara.tensor.basic import (
nonzero_values,
ogrid,
ones_like,
patternbroadcast,
permute_row_elements,
roll,
scalar_from_tensor,
......@@ -3226,20 +3224,7 @@ class TestLongTensor:
class TestBroadcast:
def test_addbroadcast_validation(self):
x = as_tensor_variable(np.zeros((2, 3)))
with pytest.raises(ValueError, match=".*pattern does not.*"):
addbroadcast(x, 4)
def test_broadcast_bigdim(self):
def f():
x = matrix()
addbroadcast(x, 2)
with pytest.raises(ValueError):
f()
def test_unbroadcast_addbroadcast(self):
def test_unbroadcast(self):
# test that the unbroadcast fct don't insert not needed broadcast
# and fuse consecutive Rebroadcast op
......@@ -3249,26 +3234,12 @@ class TestBroadcast:
assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x
assert addbroadcast(x, 0) is not x
assert addbroadcast(x, 1) is not x
assert addbroadcast(x, 1, 0).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x, 0), 0) is x
assert addbroadcast(unbroadcast(x, 0), 0) is not x
x = row()
assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x
assert addbroadcast(x, 0) is x
assert addbroadcast(x, 1).owner.inputs[0] is x
assert addbroadcast(x, 1, 0).owner.inputs[0] is x
assert addbroadcast(x, 0, 1).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x, 1), 1) is x
assert addbroadcast(unbroadcast(x, 1), 1) is not x
# The first broadcast is remove the broadcast, so the second
# should not make one
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
......@@ -3276,29 +3247,8 @@ class TestBroadcast:
# Test that consecutive Rebroadcast op are fused
x = TensorType(dtype="float64", shape=(True, True))()
assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x, 0), 0) is x
def test_patternbroadcast(self):
# Test that patternbroadcast with an empty broadcasting pattern works
x = scalar("x")
m = matrix("m")
s = patternbroadcast(m, x.broadcastable)
assert s is m
x2 = patternbroadcast(x, x.broadcastable)
assert x2 is x
def test_infer_shape(self):
x = matrix()
y = addbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
x = matrix()
y = unbroadcast(x, 0)
f = aesara.function([x], y.shape)
......
......@@ -1911,18 +1911,6 @@ class TestRebroadcast:
assert check_stack_trace(f, ops_to_check="all")
def test_rebroadcast_rebroadcast(self):
mode = get_default_mode().including("canonicalize")
m = matrix()
s = at.addbroadcast(m, 0, 1)
v = at.unbroadcast(s, 1)
f = function([m], v, mode=mode)
f([[76]])
e = f.maker.fgraph.toposort()
rebroadcast_nodes = [n for n in e if isinstance(n.op, Rebroadcast)]
assert len(rebroadcast_nodes) == 1
assert rebroadcast_nodes[0].op.axis == {0: True}
class TestUselessElemwise:
def setup_method(self):
......
......@@ -1918,6 +1918,9 @@ class TestDot:
# These examples should all work. All dimensions of all results have
# size 1.
#
def is_super_shape(var1, var2):
    # True when var1's type is a superset of var2's type; dtype differences
    # are ignored by first recasting var2's type to var1's dtype.
    recast = var2.type.clone(dtype=var1.type.dtype)
    return var1.type.is_super(recast)
for dtype0 in ("float32", "float64", "complex64"):
for dtype1 in ("float32", "complex64", "complex128"):
......@@ -1944,9 +1947,9 @@ class TestDot:
if dtype0.startswith("float") and dtype1.startswith("float"):
g = grad(z.sum(), x)
assert g.broadcastable == x.broadcastable
assert is_super_shape(x, g)
g = grad(z.sum(), y)
assert g.broadcastable == y.broadcastable
assert is_super_shape(y, g)
class TestTensordot:
......
......@@ -19,7 +19,7 @@ from aesara.tensor.opt_uncanonicalize import (
local_dimshuffle_subtensor,
local_reshape_dimshuffle,
)
from aesara.tensor.shape import reshape
from aesara.tensor.shape import reshape, specify_shape
from aesara.tensor.type import dtensor4, iscalar, matrix, tensor, vector
from tests.link.test_link import make_function
......@@ -179,7 +179,7 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)
x = dtensor4("x")
x = at.patternbroadcast(x, (False, True, False, False))
x = specify_shape(x, (None, 1, None, None))
i = iscalar("i")
out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)
......@@ -213,7 +213,7 @@ def test_local_dimshuffle_subtensor():
# Test a corner case that had Aesara return a bug.
x = dtensor4("x")
x = at.patternbroadcast(x, (False, True, False, False))
x = specify_shape(x, (None, 1, None, None))
assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
{x: np.ones((5, 1, 6, 7))}
......
......@@ -9,7 +9,7 @@ from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.tensor import as_tensor_variable, get_vector_length
from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.elemwise import DimShuffle, Elemwise
......@@ -21,6 +21,7 @@ from aesara.tensor.shape import (
reshape,
shape,
shape_i,
specify_broadcastable,
specify_shape,
)
from aesara.tensor.subtensor import Subtensor
......@@ -518,6 +519,23 @@ class TestSpecifyShape(utt.InferShapeTester):
assert isinstance(z_grad.owner.op, SpecifyShape)
class TestSpecifyBroadcastable:
    def test_basic(self):
        """`specify_broadcastable` fixes the selected dims to static length 1."""
        x = matrix()
        assert specify_broadcastable(x, 0).type.shape == (1, None)
        assert specify_broadcastable(x, 1).type.shape == (None, 1)
        assert specify_broadcastable(x, 0, 1).type.shape == (1, 1)

        # A `row` is already broadcastable along axis 0, so specifying that
        # axis again is a no-op; axis 1 is not, so a new variable is returned.
        x = row()
        assert specify_broadcastable(x, 0) is x
        assert specify_broadcastable(x, 1) is not x

    def test_validation(self):
        """An out-of-range axis is rejected with a `ValueError`."""
        x = matrix()
        # NOTE: the previous pattern ``^Trying to specify broadcastable of*``
        # only meant "o" followed by zero-or-more "f"; match a meaningful
        # portion of the message instead.
        with pytest.raises(
            ValueError, match="Trying to specify broadcastable of non-existent"
        ):
            specify_broadcastable(x, 2)
class TestRopLop(RopLopChecker):
def test_shape(self):
self.check_nondiff_rop(self.x.shape[0])
......
......@@ -1501,11 +1501,11 @@ class TestIncSubtensor:
# This one should work
f(rng_randX(3, 1), rng_randX(1))
# These ones should not
with pytest.raises(ValueError):
with pytest.raises(AssertionError):
f(rng_randX(3, 1), rng_randX(2))
with pytest.raises(ValueError):
with pytest.raises(AssertionError):
f(rng_randX(3, 1), rng_randX(3))
with pytest.raises(ValueError):
with pytest.raises(AssertionError):
f(rng_randX(3, 1), rng_randX(0))
def test_simple_3d(self):
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment