Commit 316ce0ba authored by Ricardo Vieira, committed by Ricardo Vieira

Refactor encompasses_broadcastable to broadcasted_by

Parent: c946160a
......@@ -49,7 +49,7 @@ from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i
from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.var import TensorConstant
from pytensor.tensor.var import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter
......@@ -61,27 +61,26 @@ _logger = logging.getLogger("pytensor.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter())
def broadcasted_by(x: "TensorVariable", y: "TensorVariable") -> bool:
    """Check whether x would be broadcasted by y in an Elemwise operation

    Parameters
    ----------
    x: TensorVariable
        The variable that may be broadcasted by y
    y: TensorVariable
        The variable that may broadcast x

    Returns
    -------
    broadcasted_by: bool
    """
    x_bcast = x.type.broadcastable
    y_bcast = y.type.broadcastable

    # y has more dimensions than x: x necessarily gets new dims added
    # (and thus is broadcasted) when the two meet in an Elemwise.
    if len(x_bcast) < len(y_bcast):
        return True

    # Align x's trailing dims with y's dims and look for any position
    # where x is broadcastable (length 1) but y is not.
    aligned = x_bcast[len(x_bcast) - len(y_bcast) :]
    for x_dim_bcast, y_dim_bcast in zip(aligned, y_bcast):
        if x_dim_bcast and not y_dim_bcast:
            return True
    return False
def merge_broadcastables(broadcastables):
......
......@@ -85,7 +85,7 @@ from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import (
broadcast_like,
encompasses_broadcastable,
broadcasted_by,
local_fill_sink,
register_canonicalize,
register_specialize,
......@@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node):
xsym = node.inputs[0]
ysym = node.inputs[1]
y = get_constant(ysym)
if (y is not None) and encompasses_broadcastable(
xsym.type.broadcastable, ysym.type.broadcastable
):
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
if np.all(y == 2):
......@@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node):
y = y[0]
except IndexError:
pass
if (y is not None) and encompasses_broadcastable(
xsym.type.broadcastable, ysym.type.broadcastable
):
if (y is not None) and not broadcasted_by(xsym, ysym):
rval = None
# 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512:
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment