Commit ce2c8613, authored by Brandon T. Willard, committed by Brandon T. Willard

Add local Shape_i of broadcastable dimension canonicalization

Parent 8ec11777
......@@ -79,7 +79,7 @@ from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import discrete_dtypes, integer_dtypes, lscalar
from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes, lscalar
from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter
......@@ -3553,3 +3553,21 @@ def local_Shape_of_SpecifyShape(fgraph, node):
return False
return [specified_shape.owner.inputs[1].astype(np.int64)]
@register_useless
@register_canonicalize
@local_optimizer([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node):
    """Canonicalize ``Shape_i(x, i)`` to the constant ``1`` when dimension
    ``i`` of ``x`` is broadcastable.

    A broadcastable dimension has length 1 by definition, so the shape
    lookup can be folded into an ``int64`` constant.
    """
    op = node.op

    # Only applicable to `Shape_i` nodes.
    if not isinstance(op, Shape_i):
        return False

    (x,) = node.inputs

    # Broadcast information is only available on `TensorType` inputs.
    if not isinstance(x.type, TensorType):
        return False

    if not x.broadcastable[op.i]:
        # Dimension `i` is not broadcastable; nothing to replace.
        return None

    # A broadcastable dimension always has extent 1.
    return [as_tensor_variable(1, dtype=np.int64)]
......@@ -15,12 +15,13 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant
from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, local_optimizer, out2in
from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.tensor.basic import (
Alloc,
......@@ -3208,6 +3209,35 @@ def test_local_Shape_of_SpecifyShape(shape):
assert shape in fgraph.variables
def test_local_Shape_i_of_broadcastable():
    """Test that ``Shape_i`` on a broadcastable dimension is rewritten to 1,
    and that non-``TensorType`` inputs are left untouched."""
    # ``x`` is broadcastable along its second dimension ([False, True]).
    x = tensor(np.float64, [False, True])
    s = Shape_i(1)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = optimize_graph(fgraph, clone=False)
    # After the rewrite, the `Shape_i` node (and with it `x`) is gone from
    # the graph, replaced by a constant equal to 1.
    assert x not in fgraph.variables
    assert fgraph.outputs[0].data == 1
    # A test for a non-`TensorType`: the rewrite should not fire, because it
    # requires `TensorType` broadcast information.
    class MyType(Type):
        def filter(self, *args, **kwargs):
            raise NotImplementedError()
        def __eq__(self, other):
            # NOTE(review): `thingy` is never defined on `MyType`, so this
            # comparison would raise `AttributeError` if two `MyType`
            # instances were ever compared — presumably never exercised here.
            return isinstance(other, MyType) and other.thingy == self.thingy
    class MyVariable(Variable):
        # `ndim` makes the variable look tensor-like — presumably so
        # `Shape_i(0)` accepts it as input; TODO confirm against `Shape_i`.
        ndim = 1
    x = MyVariable(MyType(), None, None)
    s = Shape_i(0)(x)
    fgraph = FunctionGraph(outputs=[s], clone=False)
    _ = optimize_graph(fgraph, clone=False)
    # The output is unchanged: `x.type` is not a `TensorType`.
    assert fgraph.outputs[0] == s
def test_assert_op_gradient():
x = vector("x")
assert_op = Assert()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论