提交 810ee8f4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Prevent patternbroadcast from creating no-ops

上级 56539bcd
...@@ -11,7 +11,7 @@ import warnings ...@@ -11,7 +11,7 @@ import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
from typing import Dict, Optional, Tuple, Union from typing import Dict, Iterable, Optional, Tuple, Union
import numpy as np import numpy as np
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
...@@ -2161,33 +2161,30 @@ def unbroadcast(x, *axes): ...@@ -2161,33 +2161,30 @@ def unbroadcast(x, *axes):
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval) return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
def patternbroadcast(x, broadcastable): def patternbroadcast(
""" x: TensorVariable, broadcastable: Iterable[Union[bool, int]]
Make the input adopt a specific broadcasting pattern. ) -> TensorVariable:
"""Make the input adopt a specific broadcasting pattern.
Broadcastable must be iterable. For example,
patternbroadcast(x, (True, False)) will make the first
dimension of x broadcastable and the second dimension
not broadcastable, so x will now be a row.
We apply the opt here not to pollute the graph especially during the gpu For example, ``patternbroadcast(x, (True, False))`` will make the first
optimization. dimension of `x` broadcastable and the second dimension not broadcastable,
so `x` will now be a row.
Parameters Parameters
---------- ----------
x : tensor_like x
Input aesara tensor. Input to re-broadcast.
broadcastable : an iterable object such as list or tuple of bool values broadcastable
A set of boolean values indicating whether a dimension should be Truthy values indicating whether or not a dimension should be
broadcastable or not. If the length of x along these dimensions is broadcastable or not. If the length of `x` along these dimensions is
not 1, a ValueError will be raised. not ``1``, a `ValueError` will be raised.
Returns
-------
tensor
A aesara tensor, which is unbroadcastable along the specified dimensions.
""" """
x = as_tensor_variable(x)
if x.broadcastable == broadcastable:
return x
rval = Rebroadcast(*[(i, broadcastable[i]) for i in range(len(broadcastable))])(x) rval = Rebroadcast(*[(i, broadcastable[i]) for i in range(len(broadcastable))])(x)
return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval) return aesara.tensor.basic_opt.apply_rebroadcast_opt(rval)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论