提交 0c4f1f05 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up ShapeFeature.same_shape and remove an overly restrictive assert

上级 105bcc8c
...@@ -6,6 +6,7 @@ import time ...@@ -6,6 +6,7 @@ import time
import traceback import traceback
from collections import defaultdict from collections import defaultdict
from io import StringIO from io import StringIO
from typing import Optional
import numpy as np import numpy as np
...@@ -1355,51 +1356,70 @@ class ShapeFeature(features.Feature): ...@@ -1355,51 +1356,70 @@ class ShapeFeature(features.Feature):
self.set_shape_i(v, ii, new_r) self.set_shape_i(v, ii, new_r)
self.shape_of_reverse_index[r] = set() self.shape_of_reverse_index[r] = set()
def same_shape(self, x, y, dim_x=None, dim_y=None): def same_shape(
"""Return True if we are able to assert that x and y have the self,
same shape. x: Variable,
y: Variable,
dim_x: Optional[int] = None,
dim_y: Optional[int] = None,
) -> bool:
"""Return ``True`` if `x` and `y` have the same shape.
dim_x and dim_y are optional. If used, they should be an index Parameters
to compare only 1 dimension of x and y. ==========
x
The `Variable` for which its shape is to be compared with `y`'s shape.
y
The `Variable` for which its shape is to be compared with `x`'s shape.
dim_x
If non ``None``, compare only the dimension of `x` equal to
`dim_x`.
dim_y
If non ``None``, compare only the dimension of `y` equal to
`dim_y`.
""" """
sx = self.shape_of[x] sx = self.shape_of[x]
sy = self.shape_of[y] sy = self.shape_of[y]
if sx is None or sy is None: if sx is None or sy is None:
return False return False
if dim_x is not None: if dim_x is not None:
sx = [sx[dim_x]] sx = [sx[dim_x]]
if dim_y is not None: if dim_y is not None:
sy = [sy[dim_y]] sy = [sy[dim_y]]
assert len(sx) == len(sy)
# We look on each dimensions we want to compare. if len(sx) != len(sy):
# If any of them can't be asserted to be equal, return False. return False
# Otherwise, we return True at the end.
for dx, dy in zip(sx, sy): for dx, dy in zip(sx, sy):
if dx is dy: if dx is dy:
continue continue
# Need to try to find that they are the same shape. We
# need to compare the full graph. It could be slow. So I # For now, only the `Shape_i` case is (explicitly) supported.
# just implement for now the case of Shape_i. # TODO: How necessary is this with the `equal_computations` below?
if not dx.owner or not dy.owner: if not dx.owner or not dy.owner:
return False return False
if not isinstance(dx.owner.op, Shape_i) or not isinstance(
dy.owner.op, Shape_i
):
return False
opx = dx.owner.op opx = dx.owner.op
opy = dy.owner.op opy = dy.owner.op
if not isinstance(opx, Shape_i) or not isinstance(opy, Shape_i):
return False
if opx.i != opy.i: if opx.i != opy.i:
return False return False
# FB I'm not sure if this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]: if dx.owner.inputs[0] == dy.owner.inputs[0]:
continue continue
# To be sure to cover all case, call equal_computation. # To be sure to cover all case, call equal_computation.
# Can't use aesara.graph.basic.is_same_graph(dx, dy)
# As it currently expect that dx and dy aren't in a FunctionGraph
if not equal_computations([dx], [dy]): if not equal_computations([dx], [dy]):
return False return False
return True return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论