提交 daae1ad4 authored 作者: Frederic's avatar Frederic

Make ShapeFeature.same_shape(), support comparing only 1 dimensions of the variables

上级 3dd1bf0a
...@@ -1171,14 +1171,22 @@ class ShapeFeature(object): ...@@ -1171,14 +1171,22 @@ class ShapeFeature(object):
self.set_shape_i(v, ii, new_r) self.set_shape_i(v, ii, new_r)
self.shape_of_reverse_index[r] = set() self.shape_of_reverse_index[r] = set()
def same_shape(self, x, y): def same_shape(self, x, y, dim_x=None, dim_y=None):
"""Return True if we are able to assert that x and y have the """Return True if we are able to assert that x and y have the
same shape same shape.
dim_x and dim_y are optional. If used, they should be an index
to compare only 1 shape of x or y.
""" """
sx = self.shape_of[x] sx = self.shape_of[x]
sy = self.shape_of[y] sy = self.shape_of[y]
if sx is None or sy is None: if sx is None or sy is None:
return False return False
if dim_x is not None:
sx = [sx[dim_x]]
if dim_y is not None:
sy = [sy[dim_y]]
assert len(sx) == len(sy) assert len(sx) == len(sy)
for dx, dy in zip(sx, sy): for dx, dy in zip(sx, sy):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论