提交 2066065d authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Deprecate addbroadcast and patternbroadcast in favor of specify_broadcastable

上级 eb2b9afb
...@@ -171,7 +171,7 @@ def check_broadcast(v1, v2): ...@@ -171,7 +171,7 @@ def check_broadcast(v1, v2):
"dimension is fixed to 1 in the input, while it is still " "dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make " "variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor." "them consistent, e.g. using aesara.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}." "{unbroadcast, specify_broadcastable}."
) )
size = min(len(v1.broadcastable), len(v2.broadcastable)) size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate( for n, (b1, b2) in enumerate(
......
...@@ -45,7 +45,7 @@ from aesara.tensor.math import ( ...@@ -45,7 +45,7 @@ from aesara.tensor.math import (
tanh, tanh,
trunc, trunc,
) )
from aesara.tensor.shape import shape from aesara.tensor.shape import shape, specify_broadcastable
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes from aesara.tensor.type import continuous_dtypes as tensor_continuous_dtypes
from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes from aesara.tensor.type import discrete_dtypes as tensor_discrete_dtypes
...@@ -1136,7 +1136,9 @@ class SparseFromDense(Op): ...@@ -1136,7 +1136,9 @@ class SparseFromDense(Op):
(x,) = inputs (x,) = inputs
(gz,) = gout (gz,) = gout
gx = dense_from_sparse(gz) gx = dense_from_sparse(gz)
gx = at.patternbroadcast(gx, x.broadcastable) gx = specify_broadcastable(
gx, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
return (gx,) return (gx,)
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
...@@ -1900,9 +1902,9 @@ class SpSum(Op): ...@@ -1900,9 +1902,9 @@ class SpSum(Op):
else: else:
ones = at.ones_like(x) ones = at.ones_like(x)
if self.axis == 0: if self.axis == 0:
r = at.addbroadcast(gz.dimshuffle("x", 0), 0) * ones r = specify_broadcastable(gz.dimshuffle("x", 0), 0) * ones
elif self.axis == 1: elif self.axis == 1:
r = at.addbroadcast(gz.dimshuffle(0, "x"), 1) * ones r = specify_broadcastable(gz.dimshuffle(0, "x"), 1) * ones
else: else:
raise ValueError("Illegal value for self.axis.") raise ValueError("Illegal value for self.axis.")
r = SparseFromDense(o_format)(r) r = SparseFromDense(o_format)(r)
......
...@@ -10,7 +10,7 @@ import warnings ...@@ -10,7 +10,7 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Dict, Iterable, Optional, Tuple, Union from typing import Dict, Optional, Tuple, Union
from typing import cast as type_cast from typing import cast as type_cast
import numpy as np import numpy as np
...@@ -49,6 +49,7 @@ from aesara.tensor.shape import ( ...@@ -49,6 +49,7 @@ from aesara.tensor.shape import (
shape_padleft, shape_padleft,
shape_padright, shape_padright,
shape_tuple, shape_tuple,
specify_broadcastable,
) )
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
...@@ -622,8 +623,6 @@ class Rebroadcast(COp): ...@@ -622,8 +623,6 @@ class Rebroadcast(COp):
See Also See Also
-------- --------
unbroadcast <aesara.tensor.unbroadcast> unbroadcast <aesara.tensor.unbroadcast>
addbroadcast <aesara.tensor.addbroadcast>
patternbroadcast <aesara.tensor.patternbroadcast>
Notes Notes
----- -----
...@@ -2255,48 +2254,12 @@ class Split(COp): ...@@ -2255,48 +2254,12 @@ class Split(COp):
) )
def addbroadcast(x, *axes):
    """
    Mark the given axes of `x` as broadcastable.

    For example, ``addbroadcast(x, 0)`` makes the first dimension of `x`
    broadcastable.  As originally documented, a `ValueError` is raised when
    the function is performed if the length of `x` along such a dimension
    is not 1.

    Parameters
    ----------
    x : tensor_like
        Input; it is converted with `as_tensor_variable`.
    *axes : int
        The dimensions along which `x` should be broadcastable.

    Returns
    -------
    tensor
        `x` itself when its static shape is fully known and already
        broadcastable along `axes`; otherwise the (optimized) output of a
        `Rebroadcast` applied to `x`.

    Raises
    ------
    ValueError
        If `x` has a `TensorType` with a fully known static shape and one of
        `axes` is not among its broadcastable dimensions.

    """
    x = as_tensor_variable(x)

    has_static_shape = isinstance(x.type, TensorType) and all(
        s is not None for s in x.type.shape
    )
    if has_static_shape:
        # Nothing can be added to a fully specified shape: either the
        # requested axes are already broadcastable or the request is invalid.
        bcast_dims = {i for i, b in enumerate(x.broadcastable) if b}
        if not bcast_dims.issuperset(axes):
            raise ValueError(f"{x}'s fixed broadcast pattern does not match {axes}")
        return x

    # The rebroadcast optimization is applied right away so that the graph is
    # not polluted with redundant `Rebroadcast` nodes.
    rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
def unbroadcast(x, *axes): def unbroadcast(x, *axes):
""" """
Make the input impossible to broadcast in the specified axes. Make the input impossible to broadcast in the specified axes.
For example, addbroadcast(x, 0) will make the first dimension For example, unbroadcast(x, 0) will make the first dimension
of x broadcastable. When performing the function, if the length of x not broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised. of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph We apply the opt here not to pollute the graph
...@@ -2321,34 +2284,6 @@ def unbroadcast(x, *axes): ...@@ -2321,34 +2284,6 @@ def unbroadcast(x, *axes):
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval) return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
def patternbroadcast(
    x: TensorVariable, broadcastable: Iterable[Union[bool, int]]
) -> TensorVariable:
    """Make the input adopt a specific broadcasting pattern.

    For example, ``patternbroadcast(x, (True, False))`` will make the first
    dimension of `x` broadcastable and the second dimension not broadcastable,
    so `x` will now be a row.

    Parameters
    ----------
    x
        Input to re-broadcast; it is converted with `as_tensor_variable`.
    broadcastable
        Truthy values indicating whether or not a dimension should be
        broadcastable or not. Any iterable is accepted (tuple, list,
        generator, ...). As originally documented, if the length of `x`
        along these dimensions is not ``1``, a `ValueError` will be raised.

    Returns
    -------
    TensorVariable
        `x` itself when it already has the requested pattern; otherwise the
        (optimized) output of a `Rebroadcast` applied to `x`.

    """
    x = as_tensor_variable(x)

    # Materialize the pattern once: the annotation allows any iterable, but
    # the equality shortcut below only succeeds against a tuple and the
    # `Rebroadcast` arguments need positional access to every entry.
    pattern = tuple(broadcastable)

    if x.broadcastable == pattern:
        return x

    rval = Rebroadcast(*enumerate(pattern))(x)
    return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
class Join(COp): class Join(COp):
r""" r"""
Concatenate multiple `TensorVariable`\s along some axis. Concatenate multiple `TensorVariable`\s along some axis.
...@@ -2599,7 +2534,12 @@ class Join(COp): ...@@ -2599,7 +2534,12 @@ class Join(COp):
# broadcast. As the grad need to keep the information, # broadcast. As the grad need to keep the information,
# read it if needed. # read it if needed.
split_gz = [ split_gz = [
patternbroadcast(g, t.broadcastable) for t, g in zip(tens, split_gz) g
if g.type.broadcastable == t.type.broadcastable
else specify_broadcastable(
g, *(ax for (ax, b) in enumerate(t.type.broadcastable) if b)
)
for t, g in zip(tens, split_gz)
] ]
rval = rval + split_gz rval = rval + split_gz
else: else:
...@@ -2822,7 +2762,7 @@ def stack(*tensors, **kwargs): ...@@ -2822,7 +2762,7 @@ def stack(*tensors, **kwargs):
raise ValueError("No tensor arguments provided") raise ValueError("No tensor arguments provided")
# If all tensors are scalars of the same type, call make_vector. # If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts # It makes the graph simpler, by not adding DimShuffles and SpecifyShapes
# This should be an optimization! # This should be an optimization!
# Doing it here make the graph less canonicalized # Doing it here make the graph less canonicalized
...@@ -2979,7 +2919,9 @@ def flatten(x, ndim=1): ...@@ -2979,7 +2919,9 @@ def flatten(x, ndim=1):
bcast_kept_dims = _x.broadcastable[: ndim - 1] bcast_kept_dims = _x.broadcastable[: ndim - 1]
bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :]) bcast_new_dim = builtins.all(_x.broadcastable[ndim - 1 :])
broadcastable = bcast_kept_dims + (bcast_new_dim,) broadcastable = bcast_kept_dims + (bcast_new_dim,)
x_reshaped = addbroadcast(x_reshaped, *[i for i in range(ndim) if broadcastable[i]]) x_reshaped = specify_broadcastable(
x_reshaped, *[i for i in range(ndim) if broadcastable[i]]
)
return x_reshaped return x_reshaped
...@@ -4253,9 +4195,7 @@ __all__ = [ ...@@ -4253,9 +4195,7 @@ __all__ = [
"stack", "stack",
"roll", "roll",
"join", "join",
"patternbroadcast",
"unbroadcast", "unbroadcast",
"addbroadcast",
"split", "split",
"transpose", "transpose",
"extract_constant", "extract_constant",
......
...@@ -165,6 +165,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version ...@@ -165,6 +165,7 @@ from aesara.tensor.blas_headers import blas_header_text, blas_header_version
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import Dot, add, mul, neg, sub from aesara.tensor.math import Dot, add, mul, neg, sub
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import ( from aesara.tensor.type import (
DenseTensorType, DenseTensorType,
integer_dtypes, integer_dtypes,
...@@ -2552,9 +2553,13 @@ class BatchedDot(COp): ...@@ -2552,9 +2553,13 @@ class BatchedDot(COp):
# above code don't always return the right broadcast pattern. # above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461. # This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable: if xgrad.broadcastable != x.broadcastable:
xgrad = at.patternbroadcast(xgrad, x.broadcastable) xgrad = specify_broadcastable(
xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
if ygrad.broadcastable != y.broadcastable: if ygrad.broadcastable != y.broadcastable:
ygrad = at.patternbroadcast(ygrad, y.broadcastable) ygrad = specify_broadcastable(
ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
)
return xgrad, ygrad return xgrad, ygrad
......
...@@ -21,7 +21,6 @@ from aesara.tensor.basic import ( ...@@ -21,7 +21,6 @@ from aesara.tensor.basic import (
cast, cast,
concatenate, concatenate,
constant, constant,
patternbroadcast,
stack, stack,
switch, switch,
) )
...@@ -32,7 +31,7 @@ from aesara.tensor.elemwise import ( ...@@ -32,7 +31,7 @@ from aesara.tensor.elemwise import (
Elemwise, Elemwise,
scalar_elemwise, scalar_elemwise,
) )
from aesara.tensor.shape import shape from aesara.tensor.shape import shape, specify_broadcastable
from aesara.tensor.type import ( from aesara.tensor.type import (
DenseTensorType, DenseTensorType,
complex_dtypes, complex_dtypes,
...@@ -1961,9 +1960,13 @@ class Dot(Op): ...@@ -1961,9 +1960,13 @@ class Dot(Op):
# above code don't always return the right broadcast pattern. # above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461. # This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable: if xgrad.broadcastable != x.broadcastable:
xgrad = patternbroadcast(xgrad, x.broadcastable) xgrad = specify_broadcastable(
xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b)
)
if ygrad.broadcastable != y.broadcastable: if ygrad.broadcastable != y.broadcastable:
ygrad = patternbroadcast(ygrad, y.broadcastable) ygrad = specify_broadcastable(
ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b)
)
rval = xgrad, ygrad rval = xgrad, ygrad
...@@ -2178,7 +2181,11 @@ def _tensordot_as_dot(a, b, axes, dot, batched): ...@@ -2178,7 +2181,11 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
out = out_reshaped.reshape(outshape, outndim) out = out_reshaped.reshape(outshape, outndim)
# Make sure the broadcastable pattern of the result is correct, # Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes. # since some shape information can be lost in the reshapes.
return patternbroadcast(out, outbcast) if out.type.broadcastable != outbcast:
out = specify_broadcastable(
out, *(ax for (ax, b) in enumerate(outbcast) if b)
)
return out
# if 'axes' is a list, transpose a and b such that the summed axes of a # if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first. # are last and the summed axes of b are first.
......
...@@ -12,6 +12,7 @@ from aesara.tensor.basic_opt import register_specialize_device ...@@ -12,6 +12,7 @@ from aesara.tensor.basic_opt import register_specialize_device
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import mean, prod, reciprocal, sqrt from aesara.tensor.math import mean, prod, reciprocal, sqrt
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import TensorType from aesara.tensor.type import TensorType
...@@ -241,8 +242,8 @@ def batch_normalization_train( ...@@ -241,8 +242,8 @@ def batch_normalization_train(
gamma = gamma.dimshuffle(params_dimshuffle_pattern) gamma = gamma.dimshuffle(params_dimshuffle_pattern)
beta = beta.dimshuffle(params_dimshuffle_pattern) beta = beta.dimshuffle(params_dimshuffle_pattern)
else: else:
gamma = at.addbroadcast(gamma, *axes) gamma = specify_broadcastable(gamma, *axes)
beta = at.addbroadcast(beta, *axes) beta = specify_broadcastable(beta, *axes)
batchnorm_op = AbstractBatchNormTrain(axes=axes) batchnorm_op = AbstractBatchNormTrain(axes=axes)
...@@ -253,8 +254,8 @@ def batch_normalization_train( ...@@ -253,8 +254,8 @@ def batch_normalization_train(
running_mean = running_mean.dimshuffle(params_dimshuffle_pattern) running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
running_var = running_var.dimshuffle(params_dimshuffle_pattern) running_var = running_var.dimshuffle(params_dimshuffle_pattern)
else: else:
running_mean = at.addbroadcast(running_mean, *axes) running_mean = specify_broadcastable(running_mean, *axes)
running_var = at.addbroadcast(running_var, *axes) running_var = specify_broadcastable(running_var, *axes)
out, mean, invstd, new_running_mean, new_running_var = batchnorm_op( out, mean, invstd, new_running_mean, new_running_var = batchnorm_op(
inputs, inputs,
gamma, gamma,
...@@ -265,12 +266,14 @@ def batch_normalization_train( ...@@ -265,12 +266,14 @@ def batch_normalization_train(
running_var=running_var, running_var=running_var,
) )
if new_running_mean.broadcastable != running_mean.broadcastable: if new_running_mean.broadcastable != running_mean.broadcastable:
new_running_mean = at.patternbroadcast( new_running_mean = specify_broadcastable(
new_running_mean, running_mean.broadcastable new_running_mean,
*(ax for (ax, b) in enumerate(running_mean.type.broadcastable) if b),
) )
if new_running_var.broadcastable != running_var.broadcastable: if new_running_var.broadcastable != running_var.broadcastable:
new_running_var = at.patternbroadcast( new_running_var = specify_broadcastable(
new_running_var, running_var.broadcastable new_running_var,
*(ax for (ax, b) in enumerate(running_var.type.broadcastable) if b),
) )
results = (out, mean, invstd, new_running_mean, new_running_var) results = (out, mean, invstd, new_running_mean, new_running_var)
else: else:
...@@ -331,7 +334,7 @@ def batch_normalization_test( ...@@ -331,7 +334,7 @@ def batch_normalization_test(
axes = (0,) axes = (0,)
# for spatial normalization # for spatial normalization
axes = (0,) + tuple(range(2, inputs.ndim)) axes = (0,) + tuple(range(2, inputs.ndim))
gamma, beta, mean, var = (at.addbroadcast(t, *axes) gamma, beta, mean, var = (at.specify_broadcastable(t, *axes)
for t in (gamma, beta, mean, var)) for t in (gamma, beta, mean, var))
out = (inputs - mean) * gamma / at.sqrt(var + epsilon) + beta out = (inputs - mean) * gamma / at.sqrt(var + epsilon) + beta
""" """
...@@ -377,10 +380,10 @@ def batch_normalization_test( ...@@ -377,10 +380,10 @@ def batch_normalization_test(
mean = mean.dimshuffle(params_dimshuffle_pattern) mean = mean.dimshuffle(params_dimshuffle_pattern)
var = var.dimshuffle(params_dimshuffle_pattern) var = var.dimshuffle(params_dimshuffle_pattern)
else: else:
gamma = at.addbroadcast(gamma, *axes) gamma = specify_broadcastable(gamma, *axes)
beta = at.addbroadcast(beta, *axes) beta = specify_broadcastable(beta, *axes)
mean = at.addbroadcast(mean, *axes) mean = specify_broadcastable(mean, *axes)
var = at.addbroadcast(var, *axes) var = specify_broadcastable(var, *axes)
batchnorm_op = AbstractBatchNormInference(axes=axes) batchnorm_op = AbstractBatchNormInference(axes=axes)
return batchnorm_op(inputs, gamma, beta, mean, var, epsilon=epsilon) return batchnorm_op(inputs, gamma, beta, mean, var, epsilon=epsilon)
...@@ -609,7 +612,7 @@ class AbstractBatchNormInference(Op): ...@@ -609,7 +612,7 @@ class AbstractBatchNormInference(Op):
) )
scale, bias, est_mean, est_var = ( scale, bias, est_mean, est_var = (
at.addbroadcast(t, *axes) for t in (scale, bias, est_mean, est_var) specify_broadcastable(t, *axes) for t in (scale, bias, est_mean, est_var)
) )
# define helper expressions # define helper expressions
......
...@@ -26,13 +26,10 @@ import aesara ...@@ -26,13 +26,10 @@ import aesara
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.link.c.op import OpenMPOp from aesara.link.c.op import OpenMPOp
from aesara.tensor import blas from aesara.tensor import blas
from aesara.tensor.basic import ( from aesara.tensor.basic import as_tensor_variable, get_scalar_constant_value
as_tensor_variable,
get_scalar_constant_value,
patternbroadcast,
)
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.nnet.abstract_conv import get_conv_output_shape, get_conv_shape_1axis from aesara.tensor.nnet.abstract_conv import get_conv_output_shape, get_conv_shape_1axis
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import discrete_dtypes, tensor from aesara.tensor.type import discrete_dtypes, tensor
...@@ -1103,8 +1100,14 @@ class ConvOp(OpenMPOp): ...@@ -1103,8 +1100,14 @@ class ConvOp(OpenMPOp):
# din and dw should have the same broadcasting pattern as the # din and dw should have the same broadcasting pattern as the
# parameters they are the gradient of (resp. inputs and kerns). # parameters they are the gradient of (resp. inputs and kerns).
din = patternbroadcast(din, inputs.broadcastable) if din.type.broadcastable != inputs.type.broadcastable:
dw = patternbroadcast(dw, kerns.broadcastable) din = specify_broadcastable(
din, *(ax for (ax, b) in enumerate(inputs.type.broadcastable) if b)
)
if dw.type.broadcastable != kerns.type.broadcastable:
dw = specify_broadcastable(
dw, *(ax for (ax, b) in enumerate(kerns.type.broadcastable) if b)
)
return [din, dw] return [din, dw]
def c_headers(self, **kwargs): def c_headers(self, **kwargs):
......
...@@ -12,7 +12,7 @@ from aesara.link.c.op import COp ...@@ -12,7 +12,7 @@ from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scalar import int32 from aesara.scalar import int32
from aesara.tensor import _get_vector_length from aesara.tensor import _get_vector_length, as_tensor_variable
from aesara.tensor import basic as at from aesara.tensor import basic as at
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
...@@ -891,3 +891,38 @@ register_shape_i_c_code( ...@@ -891,3 +891,38 @@ register_shape_i_c_code(
""", """,
version=3, version=3,
) )
def specify_broadcastable(x, *axes):
    """
    Specify the input as being broadcastable in the specified axes.

    For example, ``specify_broadcastable(x, 0)`` will make the first dimension
    of `x` broadcastable. As originally documented, a `ValueError` is raised
    when the function is performed if the length of `x` along that dimension
    is not 1.

    Parameters
    ----------
    x : tensor_like
        Input aesara tensor; it is converted with `as_tensor_variable`.
    *axes : int
        The dimensions along which the tensor `x` should be broadcastable.
        Negative values count from the last dimension.  When no axis is
        given, `x` is returned as is.

    Returns
    -------
    tensor
        The result of `specify_shape` with a length of ``1`` requested along
        each of `axes` and nothing specified for the other dimensions.

    Raises
    ------
    ValueError
        If one of `axes` does not refer to an existing dimension of `x`.

    """
    x = as_tensor_variable(x)

    if not axes:
        return x

    ndim = x.type.ndim
    if any(not -ndim <= axis < ndim for axis in axes):
        raise ValueError("Trying to specify broadcastable of non-existent dimension")

    # Map negative axes onto their non-negative equivalent; otherwise they
    # would pass validation and then be silently ignored when the shape
    # specification is built.  (`ndim` cannot be zero here: the check above
    # rejects every axis of a zero-dimensional input.)
    bcast_axes = {axis % ndim for axis in axes}

    shape_info = [1 if i in bcast_axes else None for i in range(ndim)]
    return specify_shape(x, shape_info)
...@@ -20,7 +20,7 @@ from aesara.misc.safe_asarray import _asarray ...@@ -20,7 +20,7 @@ from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint, set_precedence from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, as_tensor_variable, get_vector_length from aesara.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value from aesara.tensor.basic import alloc, get_scalar_constant_value
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import ( from aesara.tensor.exceptions import (
AdvancedIndexingError, AdvancedIndexingError,
...@@ -28,7 +28,7 @@ from aesara.tensor.exceptions import ( ...@@ -28,7 +28,7 @@ from aesara.tensor.exceptions import (
ShapeError, ShapeError,
) )
from aesara.tensor.math import clip from aesara.tensor.math import clip
from aesara.tensor.shape import Reshape from aesara.tensor.shape import Reshape, specify_broadcastable
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
bscalar, bscalar,
...@@ -1322,8 +1322,8 @@ def inc_subtensor( ...@@ -1322,8 +1322,8 @@ def inc_subtensor(
# It is acceptable to try to increment a subtensor with a # It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable # broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1. # on that dimension. However, its length must then be 1.
# We insert a Rebroadcast Op to make sure it is the case. # We insert a SpecifyShape Op to make sure it is the case.
y = addbroadcast(y, dim) y = specify_broadcastable(y, dim)
if not x.owner: if not x.owner:
raise TypeError("x must be the result of a subtensor operation") raise TypeError("x must be the result of a subtensor operation")
......
...@@ -8,6 +8,7 @@ import aesara.tensor as at ...@@ -8,6 +8,7 @@ import aesara.tensor as at
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet import batchnorm from aesara.tensor.nnet import batchnorm
from aesara.tensor.shape import specify_broadcastable
from aesara.tensor.type import ( from aesara.tensor.type import (
TensorType, TensorType,
matrix, matrix,
...@@ -219,8 +220,8 @@ def test_batch_normalization_train(): ...@@ -219,8 +220,8 @@ def test_batch_normalization_train():
x_mean2 = x.mean(axis=axes2, keepdims=True) x_mean2 = x.mean(axis=axes2, keepdims=True)
x_var2 = x.var(axis=axes2, keepdims=True) x_var2 = x.var(axis=axes2, keepdims=True)
x_invstd2 = at.reciprocal(at.sqrt(x_var2 + eps)) x_invstd2 = at.reciprocal(at.sqrt(x_var2 + eps))
scale2 = at.addbroadcast(scale, *axes2) scale2 = specify_broadcastable(scale, *axes2)
bias2 = at.addbroadcast(bias, *axes2) bias2 = specify_broadcastable(bias, *axes2)
out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2 out2 = (x - x_mean2) * (scale2 * x_invstd2) + bias2
m = at.cast(at.prod(x.shape) / at.prod(scale.shape), aesara.config.floatX) m = at.cast(at.prod(x.shape) / at.prod(scale.shape), aesara.config.floatX)
out_running_mean2 = ( out_running_mean2 = (
...@@ -597,7 +598,7 @@ def test_batch_normalization_test(): ...@@ -597,7 +598,7 @@ def test_batch_normalization_test():
else: else:
axes2 = axes axes2 = axes
scale2, bias2, mean2, var2 = ( scale2, bias2, mean2, var2 = (
at.addbroadcast(t, *axes2) for t in (scale, bias, mean, var) specify_broadcastable(t, *axes2) for t in (scale, bias, mean, var)
) )
out2 = (x - mean2) * (scale2 / at.sqrt(var2 + eps)) + bias2 out2 = (x - mean2) * (scale2 / at.sqrt(var2 + eps)) + bias2
# backward pass # backward pass
......
...@@ -39,7 +39,6 @@ from aesara.tensor.basic import ( ...@@ -39,7 +39,6 @@ from aesara.tensor.basic import (
Split, Split,
TensorFromScalar, TensorFromScalar,
Tri, Tri,
addbroadcast,
alloc, alloc,
arange, arange,
as_tensor_variable, as_tensor_variable,
...@@ -69,7 +68,6 @@ from aesara.tensor.basic import ( ...@@ -69,7 +68,6 @@ from aesara.tensor.basic import (
nonzero_values, nonzero_values,
ogrid, ogrid,
ones_like, ones_like,
patternbroadcast,
permute_row_elements, permute_row_elements,
roll, roll,
scalar_from_tensor, scalar_from_tensor,
...@@ -3226,20 +3224,7 @@ class TestLongTensor: ...@@ -3226,20 +3224,7 @@ class TestLongTensor:
class TestBroadcast: class TestBroadcast:
def test_addbroadcast_validation(self): def test_unbroadcast(self):
x = as_tensor_variable(np.zeros((2, 3)))
with pytest.raises(ValueError, match=".*pattern does not.*"):
addbroadcast(x, 4)
def test_broadcast_bigdim(self):
    # Requesting an axis beyond the number of dimensions of a matrix must
    # be rejected with a `ValueError`.
    with pytest.raises(ValueError):
        x = matrix()
        addbroadcast(x, 2)
def test_unbroadcast_addbroadcast(self):
# test that the unbroadcast fct don't insert not needed broadcast # test that the unbroadcast fct don't insert not needed broadcast
# and fuse consecutive Rebroadcast op # and fuse consecutive Rebroadcast op
...@@ -3249,26 +3234,12 @@ class TestBroadcast: ...@@ -3249,26 +3234,12 @@ class TestBroadcast:
assert unbroadcast(x, 1, 0) is x assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x, 0, 1) is x assert unbroadcast(x, 0, 1) is x
assert addbroadcast(x, 0) is not x
assert addbroadcast(x, 1) is not x
assert addbroadcast(x, 1, 0).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x, 0), 0) is x
assert addbroadcast(unbroadcast(x, 0), 0) is not x
x = row() x = row()
assert unbroadcast(x, 0) is not x assert unbroadcast(x, 0) is not x
assert unbroadcast(x, 1) is x assert unbroadcast(x, 1) is x
assert unbroadcast(x, 1, 0) is not x assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x, 0, 1) is not x assert unbroadcast(x, 0, 1) is not x
assert addbroadcast(x, 0) is x
assert addbroadcast(x, 1).owner.inputs[0] is x
assert addbroadcast(x, 1, 0).owner.inputs[0] is x
assert addbroadcast(x, 0, 1).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x, 1), 1) is x
assert addbroadcast(unbroadcast(x, 1), 1) is not x
# The first broadcast is remove the broadcast, so the second # The first broadcast is remove the broadcast, so the second
# should not make one # should not make one
assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
...@@ -3276,29 +3247,8 @@ class TestBroadcast: ...@@ -3276,29 +3247,8 @@ class TestBroadcast:
# Test that consecutive Rebroadcast op are fused # Test that consecutive Rebroadcast op are fused
x = TensorType(dtype="float64", shape=(True, True))() x = TensorType(dtype="float64", shape=(True, True))()
assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x, 0), 0) is x
def test_patternbroadcast(self):
    # `patternbroadcast` must hand back its input untouched when it is given
    # an empty broadcasting pattern (the pattern of a scalar).
    x = scalar("x")
    m = matrix("m")
    empty_pattern = x.broadcastable

    assert patternbroadcast(m, empty_pattern) is m
    assert patternbroadcast(x, empty_pattern) is x
def test_infer_shape(self): def test_infer_shape(self):
x = matrix()
y = addbroadcast(x, 0)
f = aesara.function([x], y.shape)
assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.fgraph.toposort()
if config.mode != "FAST_COMPILE":
assert len(topo) == 2
assert isinstance(topo[0].op, Shape_i)
assert isinstance(topo[1].op, MakeVector)
x = matrix() x = matrix()
y = unbroadcast(x, 0) y = unbroadcast(x, 0)
f = aesara.function([x], y.shape) f = aesara.function([x], y.shape)
......
...@@ -1911,18 +1911,6 @@ class TestRebroadcast: ...@@ -1911,18 +1911,6 @@ class TestRebroadcast:
assert check_stack_trace(f, ops_to_check="all") assert check_stack_trace(f, ops_to_check="all")
def test_rebroadcast_rebroadcast(self):
    # An `addbroadcast` followed by an `unbroadcast` must be left with a
    # single `Rebroadcast` node, on axis 0 only, after canonicalization.
    canonicalize_mode = get_default_mode().including("canonicalize")
    m = matrix()
    out = at.unbroadcast(at.addbroadcast(m, 0, 1), 1)

    f = function([m], out, mode=canonicalize_mode)
    f([[76]])

    rebroadcast_nodes = [
        node for node in f.maker.fgraph.toposort() if isinstance(node.op, Rebroadcast)
    ]
    assert len(rebroadcast_nodes) == 1
    assert rebroadcast_nodes[0].op.axis == {0: True}
class TestUselessElemwise: class TestUselessElemwise:
def setup_method(self): def setup_method(self):
......
...@@ -1918,6 +1918,9 @@ class TestDot: ...@@ -1918,6 +1918,9 @@ class TestDot:
# These examples should all work. All dimensions of all results have # These examples should all work. All dimensions of all results have
# size 1. # size 1.
# #
def is_super_shape(var1, var2):
    # Return whether `var1.type` is a super-type of `var2.type` once the
    # dtype difference is taken out of the comparison (`var2`'s type is
    # cloned with `var1`'s dtype before the check).
    var2_type_as_var1_dtype = var2.type.clone(dtype=var1.type.dtype)
    return var1.type.is_super(var2_type_as_var1_dtype)
for dtype0 in ("float32", "float64", "complex64"): for dtype0 in ("float32", "float64", "complex64"):
for dtype1 in ("float32", "complex64", "complex128"): for dtype1 in ("float32", "complex64", "complex128"):
...@@ -1944,9 +1947,9 @@ class TestDot: ...@@ -1944,9 +1947,9 @@ class TestDot:
if dtype0.startswith("float") and dtype1.startswith("float"): if dtype0.startswith("float") and dtype1.startswith("float"):
g = grad(z.sum(), x) g = grad(z.sum(), x)
assert g.broadcastable == x.broadcastable assert is_super_shape(x, g)
g = grad(z.sum(), y) g = grad(z.sum(), y)
assert g.broadcastable == y.broadcastable assert is_super_shape(y, g)
class TestTensordot: class TestTensordot:
......
...@@ -19,7 +19,7 @@ from aesara.tensor.opt_uncanonicalize import ( ...@@ -19,7 +19,7 @@ from aesara.tensor.opt_uncanonicalize import (
local_dimshuffle_subtensor, local_dimshuffle_subtensor,
local_reshape_dimshuffle, local_reshape_dimshuffle,
) )
from aesara.tensor.shape import reshape from aesara.tensor.shape import reshape, specify_shape
from aesara.tensor.type import dtensor4, iscalar, matrix, tensor, vector from aesara.tensor.type import dtensor4, iscalar, matrix, tensor, vector
from tests.link.test_link import make_function from tests.link.test_link import make_function
...@@ -179,7 +179,7 @@ def test_local_dimshuffle_subtensor(): ...@@ -179,7 +179,7 @@ def test_local_dimshuffle_subtensor():
dimshuffle_subtensor = out2in(local_dimshuffle_subtensor) dimshuffle_subtensor = out2in(local_dimshuffle_subtensor)
x = dtensor4("x") x = dtensor4("x")
x = at.patternbroadcast(x, (False, True, False, False)) x = specify_shape(x, (None, 1, None, None))
i = iscalar("i") i = iscalar("i")
out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3) out = x[:, :, 10:30, ::i].dimshuffle(0, 2, 3)
...@@ -213,7 +213,7 @@ def test_local_dimshuffle_subtensor(): ...@@ -213,7 +213,7 @@ def test_local_dimshuffle_subtensor():
# Test a corner case that had Aesara return a bug. # Test a corner case that had Aesara return a bug.
x = dtensor4("x") x = dtensor4("x")
x = at.patternbroadcast(x, (False, True, False, False)) x = specify_shape(x, (None, 1, None, None))
assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval( assert x[:, :, 0:3, ::-1].dimshuffle(0, 2, 3).eval(
{x: np.ones((5, 1, 6, 7))} {x: np.ones((5, 1, 6, 7))}
......
...@@ -9,7 +9,7 @@ from aesara.graph.basic import Variable ...@@ -9,7 +9,7 @@ from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import as_tensor_variable, get_vector_length from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
...@@ -21,6 +21,7 @@ from aesara.tensor.shape import ( ...@@ -21,6 +21,7 @@ from aesara.tensor.shape import (
reshape, reshape,
shape, shape,
shape_i, shape_i,
specify_broadcastable,
specify_shape, specify_shape,
) )
from aesara.tensor.subtensor import Subtensor from aesara.tensor.subtensor import Subtensor
...@@ -518,6 +519,23 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -518,6 +519,23 @@ class TestSpecifyShape(utt.InferShapeTester):
assert isinstance(z_grad.owner.op, SpecifyShape) assert isinstance(z_grad.owner.op, SpecifyShape)
class TestSpecifyBroadcastable:
    """Tests for `specify_broadcastable`."""

    def test_basic(self):
        # The requested axes get a static length of 1; the others stay
        # unspecified.
        x = matrix()
        assert specify_broadcastable(x, 0).type.shape == (1, None)
        assert specify_broadcastable(x, 1).type.shape == (None, 1)
        assert specify_broadcastable(x, 0, 1).type.shape == (1, 1)

        # Asking for an axis that is already broadcastable returns the input
        # itself; asking for a new one does not.
        x = row()
        assert specify_broadcastable(x, 0) is x
        assert specify_broadcastable(x, 1) is not x

    def test_validation(self):
        # An axis beyond the number of dimensions is rejected.  The pattern
        # spells out the message: a trailing ``*`` would be a regex
        # quantifier on the preceding character, not a wildcard.
        x = matrix()
        with pytest.raises(
            ValueError,
            match="^Trying to specify broadcastable of non-existent dimension",
        ):
            specify_broadcastable(x, 2)
class TestRopLop(RopLopChecker): class TestRopLop(RopLopChecker):
def test_shape(self): def test_shape(self):
self.check_nondiff_rop(self.x.shape[0]) self.check_nondiff_rop(self.x.shape[0])
......
...@@ -1501,11 +1501,11 @@ class TestIncSubtensor: ...@@ -1501,11 +1501,11 @@ class TestIncSubtensor:
# This one should work # This one should work
f(rng_randX(3, 1), rng_randX(1)) f(rng_randX(3, 1), rng_randX(1))
# These ones should not # These ones should not
with pytest.raises(ValueError): with pytest.raises(AssertionError):
f(rng_randX(3, 1), rng_randX(2)) f(rng_randX(3, 1), rng_randX(2))
with pytest.raises(ValueError): with pytest.raises(AssertionError):
f(rng_randX(3, 1), rng_randX(3)) f(rng_randX(3, 1), rng_randX(3))
with pytest.raises(ValueError): with pytest.raises(AssertionError):
f(rng_randX(3, 1), rng_randX(0)) f(rng_randX(3, 1), rng_randX(0))
def test_simple_3d(self): def test_simple_3d(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论