提交 6f2a60ba authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix attribute access error in local_rebroadcast_lift

上级 900d9f87
......@@ -26,7 +26,7 @@ from aesara.graph.basic import (
equal_computations,
io_toposort,
)
from aesara.graph.fg import InconsistencyError
from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import get_test_value
from aesara.graph.opt import (
GlobalOptimizer,
......@@ -3622,11 +3622,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
return [r]
####################
# Rebroadcast opts #
####################
@register_useless
@register_canonicalize
@register_specialize
......@@ -3682,7 +3677,7 @@ def local_rebroadcast_lift(fgraph, node):
# is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function
# compilation phase.
if len(fgraph.clients[input]) == 1:
if len(fgraph.clients.get(input, ())) == 1:
rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0])
# Copy over stacktrace from previous output (after rebroadcasting)
# to new output, because an error in the new graph right after
......@@ -3730,16 +3725,17 @@ def apply_rebroadcast_opt(rval):
"""
fg = FunctionGraph([], [])
changed = True
while changed and rval.owner:
changed = False
rval2 = local_useless_rebroadcast.transform(None, rval.owner)
rval2 = local_useless_rebroadcast.transform(fg, rval.owner)
if rval2:
assert len(rval2) == 1
rval = rval2[0]
changed = True
if rval.owner:
rval2 = local_rebroadcast_lift.transform(None, rval.owner)
rval2 = local_rebroadcast_lift.transform(fg, rval.owner)
if rval2:
assert len(rval2) == 1
rval = rval2[0]
......@@ -3747,9 +3743,6 @@ def apply_rebroadcast_opt(rval):
return rval
#############
# Join opts #
#############
@register_specialize
@register_canonicalize
@register_useless
......
......@@ -38,6 +38,7 @@ from aesara.tensor.basic import (
)
from aesara.tensor.basic_opt import (
ShapeFeature,
apply_rebroadcast_opt,
assert_op,
local_canonicalize_alloc,
local_dimshuffle_lift,
......@@ -5184,3 +5185,18 @@ def test_local_useless_alloc():
assert len(topo) == 3
assert isinstance(topo[-2].op, Assert)
assert isinstance(topo[-1].op, Alloc)
def test_apply_rebroadcast_opt():
# Test the `Elemwise` case in `local_rebroadcast_lift` with `fgraph=None`.
# This is called by in `apply_rebroadcast_opt`.
a = vector(dtype="float32")
b = tensor("float64", [True])
x = b.astype(a.dtype)
broadcastable = (False,)
axis = [(i, broadcastable[i]) for i in range(len(broadcastable))]
rval = Rebroadcast(*axis)(x)
res = apply_rebroadcast_opt(rval)
assert res is rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论