提交 6f2a60ba authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix attribute access error in local_rebroadcast_lift

上级 900d9f87
...@@ -26,7 +26,7 @@ from aesara.graph.basic import ( ...@@ -26,7 +26,7 @@ from aesara.graph.basic import (
equal_computations, equal_computations,
io_toposort, io_toposort,
) )
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import FunctionGraph, InconsistencyError
from aesara.graph.op import get_test_value from aesara.graph.op import get_test_value
from aesara.graph.opt import ( from aesara.graph.opt import (
GlobalOptimizer, GlobalOptimizer,
...@@ -3622,11 +3622,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node): ...@@ -3622,11 +3622,6 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
return [r] return [r]
####################
# Rebroadcast opts #
####################
@register_useless @register_useless
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
...@@ -3682,7 +3677,7 @@ def local_rebroadcast_lift(fgraph, node): ...@@ -3682,7 +3677,7 @@ def local_rebroadcast_lift(fgraph, node):
# is called from `apply_rebroadcast_opt`, which in particular is used # is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function # by the `unbroadcast` function before we are in the actual function
# compilation phase. # compilation phase.
if len(fgraph.clients[input]) == 1: if len(fgraph.clients.get(input, ())) == 1:
rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0]) rebroadcasted = Rebroadcast(*list(op.axis.items()))(inode.inputs[0])
# Copy over stacktrace from previous output (after rebroadcasting) # Copy over stacktrace from previous output (after rebroadcasting)
# to new output, because an error in the new graph right after # to new output, because an error in the new graph right after
...@@ -3730,16 +3725,17 @@ def apply_rebroadcast_opt(rval): ...@@ -3730,16 +3725,17 @@ def apply_rebroadcast_opt(rval):
""" """
fg = FunctionGraph([], [])
changed = True changed = True
while changed and rval.owner: while changed and rval.owner:
changed = False changed = False
rval2 = local_useless_rebroadcast.transform(None, rval.owner) rval2 = local_useless_rebroadcast.transform(fg, rval.owner)
if rval2: if rval2:
assert len(rval2) == 1 assert len(rval2) == 1
rval = rval2[0] rval = rval2[0]
changed = True changed = True
if rval.owner: if rval.owner:
rval2 = local_rebroadcast_lift.transform(None, rval.owner) rval2 = local_rebroadcast_lift.transform(fg, rval.owner)
if rval2: if rval2:
assert len(rval2) == 1 assert len(rval2) == 1
rval = rval2[0] rval = rval2[0]
...@@ -3747,9 +3743,6 @@ def apply_rebroadcast_opt(rval): ...@@ -3747,9 +3743,6 @@ def apply_rebroadcast_opt(rval):
return rval return rval
#############
# Join opts #
#############
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@register_useless @register_useless
......
...@@ -38,6 +38,7 @@ from aesara.tensor.basic import ( ...@@ -38,6 +38,7 @@ from aesara.tensor.basic import (
) )
from aesara.tensor.basic_opt import ( from aesara.tensor.basic_opt import (
ShapeFeature, ShapeFeature,
apply_rebroadcast_opt,
assert_op, assert_op,
local_canonicalize_alloc, local_canonicalize_alloc,
local_dimshuffle_lift, local_dimshuffle_lift,
...@@ -5184,3 +5185,18 @@ def test_local_useless_alloc(): ...@@ -5184,3 +5185,18 @@ def test_local_useless_alloc():
assert len(topo) == 3 assert len(topo) == 3
assert isinstance(topo[-2].op, Assert) assert isinstance(topo[-2].op, Assert)
assert isinstance(topo[-1].op, Alloc) assert isinstance(topo[-1].op, Alloc)
def test_apply_rebroadcast_opt():
# Test the `Elemwise` case in `local_rebroadcast_lift` with `fgraph=None`.
# This is called by in `apply_rebroadcast_opt`.
a = vector(dtype="float32")
b = tensor("float64", [True])
x = b.astype(a.dtype)
broadcastable = (False,)
axis = [(i, broadcastable[i]) for i in range(len(broadcastable))]
rval = Rebroadcast(*axis)(x)
res = apply_rebroadcast_opt(rval)
assert res is rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论