提交 95f0c13d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make ShapeFeature use static shape values instead of broadcastable

上级 269903aa
......@@ -938,8 +938,8 @@ class ShapeFeature(features.Feature):
def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i."""
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
return self.lscalar_one
if hasattr(r.type, "shape") and r.type.shape[i] is not None:
return constant(r.type.shape[i], dtype="int64")
else:
# Do not call make_node for test_value
s = Shape_i(i)(r)
......@@ -1079,8 +1079,8 @@ class ShapeFeature(features.Feature):
shape_vars = []
for i in range(r.type.ndim):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]:
shape_vars.append(self.lscalar_one)
if hasattr(r.type, "shape") and r.type.shape[i] is not None:
shape_vars.append(constant(r.type.shape[i], dtype="int64"))
else:
shape_vars.append(self.unpack(s[i], r))
assert all(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论