提交 316ce0ba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Refactor encompasses_broadcastable to broadcasted_by

上级 c946160a
...@@ -49,7 +49,7 @@ from pytensor.tensor.math import eq ...@@ -49,7 +49,7 @@ from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
from pytensor.tensor.sort import TopKOp from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.var import TensorConstant from pytensor.tensor.var import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter from pytensor.utils import NoDuplicateOptWarningFilter
...@@ -61,27 +61,26 @@ _logger = logging.getLogger("pytensor.tensor.rewriting.basic") ...@@ -61,27 +61,26 @@ _logger = logging.getLogger("pytensor.tensor.rewriting.basic")
_logger.addFilter(NoDuplicateOptWarningFilter()) _logger.addFilter(NoDuplicateOptWarningFilter())
def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool:
    """Check whether x would be broadcasted by y in an Elemwise operation.

    Parameters
    ----------
    x: TensorVariable
        The variable that may be broadcasted by y
    y: TensorVariable
        The variable that may broadcast x

    Returns
    -------
    broadcasted_by: bool
    """
    x_bcast = x.type.broadcastable
    y_bcast = y.type.broadcastable

    # y having more dims than x means x gains new (broadcasted) leading dims.
    if len(x_bcast) < len(y_bcast):
        return True

    # Align on the trailing dims and look for any dim where x is
    # broadcastable (length 1) but y is not: there x is stretched to y's size.
    for x_dim, y_dim in zip(x_bcast[-len(y_bcast) :], y_bcast):
        if x_dim and not y_dim:
            return True
    return False
def merge_broadcastables(broadcastables): def merge_broadcastables(broadcastables):
......
...@@ -85,7 +85,7 @@ from pytensor.tensor.math import sum as at_sum ...@@ -85,7 +85,7 @@ from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.math import true_div from pytensor.tensor.math import true_div
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
broadcast_like, broadcast_like,
encompasses_broadcastable, broadcasted_by,
local_fill_sink, local_fill_sink,
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
...@@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node): ...@@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node):
xsym = node.inputs[0] xsym = node.inputs[0]
ysym = node.inputs[1] ysym = node.inputs[1]
y = get_constant(ysym) y = get_constant(ysym)
if (y is not None) and encompasses_broadcastable( if (y is not None) and not broadcasted_by(xsym, ysym):
xsym.type.broadcastable, ysym.type.broadcastable
):
rval = None rval = None
if np.all(y == 2): if np.all(y == 2):
...@@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node):
y = y[0] y = y[0]
except IndexError: except IndexError:
pass pass
if (y is not None) and encompasses_broadcastable( if (y is not None) and not broadcasted_by(xsym, ysym):
xsym.type.broadcastable, ysym.type.broadcastable
):
rval = None rval = None
# 512 is too small for the cpu and too big for some gpu! # 512 is too small for the cpu and too big for some gpu!
if abs(y) == int(abs(y)) and abs(y) <= 512: if abs(y) == int(abs(y)) and abs(y) <= 512:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论