提交 ce2c8613 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add local Shape_i of broadcastable dimension canonicalization

上级 8ec11777
...@@ -79,7 +79,7 @@ from aesara.tensor.math import eq ...@@ -79,7 +79,7 @@ from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, shape_padleft
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
from aesara.tensor.subtensor import Subtensor, get_idx_list from aesara.tensor.subtensor import Subtensor, get_idx_list
from aesara.tensor.type import discrete_dtypes, integer_dtypes, lscalar from aesara.tensor.type import TensorType, discrete_dtypes, integer_dtypes, lscalar
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
...@@ -3553,3 +3553,21 @@ def local_Shape_of_SpecifyShape(fgraph, node): ...@@ -3553,3 +3553,21 @@ def local_Shape_of_SpecifyShape(fgraph, node):
return False return False
return [specified_shape.owner.inputs[1].astype(np.int64)] return [specified_shape.owner.inputs[1].astype(np.int64)]
@register_useless
@register_canonicalize
@local_optimizer([Shape_i])
def local_Shape_i_of_broadcastable(fgraph, node):
"""Replace ``shape_i(x, i)`` with ``1`` when ``x.broadcastable[i]`` is ``True``."""
if not isinstance(node.op, Shape_i):
return False
shape_arg = node.inputs[0]
if not isinstance(shape_arg.type, TensorType):
return False
if shape_arg.broadcastable[node.op.i]:
return [as_tensor_variable(1, dtype=np.int64)]
...@@ -15,12 +15,13 @@ from aesara.compile.function import function ...@@ -15,12 +15,13 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode, get_default_mode, get_mode from aesara.compile.mode import Mode, get_default_mode, get_mode
from aesara.compile.ops import DeepCopyOp, deep_copy_op from aesara.compile.ops import DeepCopyOp, deep_copy_op
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.graph.opt import check_stack_trace, local_optimizer, out2in from aesara.graph.opt import check_stack_trace, local_optimizer, out2in
from aesara.graph.opt_utils import optimize_graph from aesara.graph.opt_utils import optimize_graph
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
...@@ -3208,6 +3209,35 @@ def test_local_Shape_of_SpecifyShape(shape): ...@@ -3208,6 +3209,35 @@ def test_local_Shape_of_SpecifyShape(shape):
assert shape in fgraph.variables assert shape in fgraph.variables
def test_local_Shape_i_of_broadcastable():
x = tensor(np.float64, [False, True])
s = Shape_i(1)(x)
fgraph = FunctionGraph(outputs=[s], clone=False)
_ = optimize_graph(fgraph, clone=False)
assert x not in fgraph.variables
assert fgraph.outputs[0].data == 1
# A test for a non-`TensorType`
class MyType(Type):
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __eq__(self, other):
return isinstance(other, MyType) and other.thingy == self.thingy
class MyVariable(Variable):
ndim = 1
x = MyVariable(MyType(), None, None)
s = Shape_i(0)(x)
fgraph = FunctionGraph(outputs=[s], clone=False)
_ = optimize_graph(fgraph, clone=False)
assert fgraph.outputs[0] == s
def test_assert_op_gradient(): def test_assert_op_gradient():
x = vector("x") x = vector("x")
assert_op = Assert() assert_op = Assert()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论