提交 d29cdd8d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use canonicalized forms in ShapeFeature.same_shape

上级 e44edc3e
...@@ -1426,29 +1426,27 @@ class ShapeFeature(features.Feature): ...@@ -1426,29 +1426,27 @@ class ShapeFeature(features.Feature):
if len(sx) != len(sy): if len(sx) != len(sy):
return False return False
for dx, dy in zip(sx, sy): # Canonicalize the graphs so that comparisons are reasonable
# TODO FIXME: This should *not* need to be performed manually here.
if dx is dy: # Instead, the shape information in `self.shape_of` should be operated
continue # upon alongside all the other elements in a `FunctionGraph` (e.g. as
# if `self.shape_of.values()` were additional outputs).
shapes_fg = FunctionGraph(
outputs=sx + sy,
# features=[self],
clone=True,
# copy_inputs=False,
)
from aesara.graph.opt_utils import optimize_graph
canon_shapes = optimize_graph(
shapes_fg, custom_opt=topo_constant_folding
).outputs
sx = canon_shapes[: len(sx)]
sy = canon_shapes[len(sx) :]
# For now, only the `Shape_i` case is (explicitly) supported. for dx, dy in zip(sx, sy):
# TODO: How necessary is this with the `equal_computations` below?
if not dx.owner or not dy.owner:
return False
opx = dx.owner.op
opy = dy.owner.op
if not isinstance(opx, Shape_i) or not isinstance(opy, Shape_i):
return False
if opx.i != opy.i:
return False
if dx.owner.inputs[0] == dy.owner.inputs[0]:
continue
# To be sure to cover all case, call equal_computation.
if not equal_computations([dx], [dy]): if not equal_computations([dx], [dy]):
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论