提交 95f0c13d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make ShapeFeature use static shape values instead of broadcastable

上级 269903aa
...@@ -938,8 +938,8 @@ class ShapeFeature(features.Feature): ...@@ -938,8 +938,8 @@ class ShapeFeature(features.Feature):
def shape_ir(self, i, r): def shape_ir(self, i, r):
"""Return symbolic r.shape[i] for tensor variable r, int i.""" """Return symbolic r.shape[i] for tensor variable r, int i."""
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "shape") and r.type.shape[i] is not None:
return self.lscalar_one return constant(r.type.shape[i], dtype="int64")
else: else:
# Do not call make_node for test_value # Do not call make_node for test_value
s = Shape_i(i)(r) s = Shape_i(i)(r)
...@@ -1079,8 +1079,8 @@ class ShapeFeature(features.Feature): ...@@ -1079,8 +1079,8 @@ class ShapeFeature(features.Feature):
shape_vars = [] shape_vars = []
for i in range(r.type.ndim): for i in range(r.type.ndim):
if hasattr(r.type, "broadcastable") and r.type.broadcastable[i]: if hasattr(r.type, "shape") and r.type.shape[i] is not None:
shape_vars.append(self.lscalar_one) shape_vars.append(constant(r.type.shape[i], dtype="int64"))
else: else:
shape_vars.append(self.unpack(s[i], r)) shape_vars.append(self.unpack(s[i], r))
assert all( assert all(
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论