提交 0c4f1f05 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clean up ShapeFeature.same_shape and remove an overly restrictive assert

上级 105bcc8c
......@@ -6,6 +6,7 @@ import time
import traceback
from collections import defaultdict
from io import StringIO
from typing import Optional
import numpy as np
......@@ -1355,51 +1356,70 @@ class ShapeFeature(features.Feature):
self.set_shape_i(v, ii, new_r)
self.shape_of_reverse_index[r] = set()
def same_shape(self, x, y, dim_x=None, dim_y=None):
"""Return True if we are able to assert that x and y have the
same shape.
def same_shape(
self,
x: Variable,
y: Variable,
dim_x: Optional[int] = None,
dim_y: Optional[int] = None,
) -> bool:
"""Return ``True`` if `x` and `y` have the same shape.
dim_x and dim_y are optional. If used, they should be an index
to compare only 1 dimension of x and y.
Parameters
==========
x
The `Variable` for which its shape is to be compared with `y`'s shape.
y
The `Variable` for which its shape is to be compared with `x`'s shape.
dim_x
If non ``None``, compare only the dimension of `x` equal to
`dim_x`.
dim_y
If non ``None``, compare only the dimension of `y` equal to
`dim_y`.
"""
sx = self.shape_of[x]
sy = self.shape_of[y]
if sx is None or sy is None:
return False
if dim_x is not None:
sx = [sx[dim_x]]
if dim_y is not None:
sy = [sy[dim_y]]
assert len(sx) == len(sy)
# We look on each dimensions we want to compare.
# If any of them can't be asserted to be equal, return False.
# Otherwise, we return True at the end.
if len(sx) != len(sy):
return False
for dx, dy in zip(sx, sy):
if dx is dy:
continue
# Need to try to find that they are the same shape. We
# need to compare the full graph. It could be slow. So I
# just implement for now the case of Shape_i.
# For now, only the `Shape_i` case is (explicitly) supported.
# TODO: How necessary is this with the `equal_computations` below?
if not dx.owner or not dy.owner:
return False
if not isinstance(dx.owner.op, Shape_i) or not isinstance(
dy.owner.op, Shape_i
):
return False
opx = dx.owner.op
opy = dy.owner.op
if not isinstance(opx, Shape_i) or not isinstance(opy, Shape_i):
return False
if opx.i != opy.i:
return False
# FB I'm not sure if this handle correctly constants.
if dx.owner.inputs[0] == dy.owner.inputs[0]:
continue
# To be sure to cover all case, call equal_computation.
# Can't use aesara.graph.basic.is_same_graph(dx, dy)
# As it currently expect that dx and dy aren't in a FunctionGraph
if not equal_computations([dx], [dy]):
return False
return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论