提交 810ee8f4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent patternbroadcast from creating no-ops

上级 56539bcd
......@@ -11,7 +11,7 @@ import warnings
from collections.abc import Sequence
from functools import partial
from numbers import Number
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Iterable, Optional, Tuple, Union
import numpy as np
from numpy.core.multiarray import normalize_axis_index
......@@ -2161,33 +2161,30 @@ def unbroadcast(x, *axes):
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
def patternbroadcast(x, broadcastable):
"""
Make the input adopt a specific broadcasting pattern.
Broadcastable must be iterable. For example,
patternbroadcast(x, (True, False)) will make the first
dimension of x broadcastable and the second dimension
not broadcastable, so x will now be a row.
def patternbroadcast(
x: TensorVariable, broadcastable: Iterable[Union[bool, int]]
) -> TensorVariable:
"""Make the input adopt a specific broadcasting pattern.
We apply the opt here not to pollute the graph especially during the gpu
optimization.
For example, ``patternbroadcast(x, (True, False))`` will make the first
dimension of `x` broadcastable and the second dimension not broadcastable,
so `x` will now be a row.
Parameters
----------
x : tensor_like
Input aesara tensor.
broadcastable : an iterable object such as list or tuple of bool values
A set of boolean values indicating whether a dimension should be
broadcastable or not. If the length of x along these dimensions is
not 1, a ValueError will be raised.
Returns
-------
tensor
A aesara tensor, which is unbroadcastable along the specified dimensions.
x
Input to re-broadcast.
broadcastable
Truthy values indicating whether or not a dimension should be
broadcastable or not. If the length of `x` along these dimensions is
not ``1``, a `ValueError` will be raised.
"""
x = as_tensor_variable(x)
if x.broadcastable == broadcastable:
return x
rval = Rebroadcast(*[(i, broadcastable[i]) for i in range(len(broadcastable))])(x)
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论