提交 807c0c97 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use static shape information instead of broadcastable in Scan

上级 e59cef18
......@@ -200,7 +200,7 @@ def copy_var_format(var, as_var):
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.clone(
shape=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
shape=(tuple(var.type.shape[:1]) + tuple(as_var.type.shape))
)
rval = tmp.filter_variable(rval)
return rval
......@@ -805,7 +805,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# output sequence
o = outputs[idx]
self.output_types.append(
typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
# TODO: What can we actually say about the shape of this
# added dimension?
typeConstructor((None,) + o.type.shape, o.type.dtype)
)
idx += len(info.mit_mot_out_slices[jdx])
......@@ -816,7 +818,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for o in outputs[idx:end]:
self.output_types.append(
typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
# TODO: What can we actually say about the shape of this
# added dimension?
typeConstructor((None,) + o.type.shape, o.type.dtype)
)
# shared outputs + possibly the ending condition
......@@ -2320,8 +2324,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# equivalent (if False). Here, we only need the variable.
v_shp_i = validator.check(shp_i)
if v_shp_i is None:
if hasattr(r, "broadcastable") and r.broadcastable[i]:
shp.append(1)
if r.type.shape[i] is not None:
shp.append(r.type.shape[i])
else:
shp.append(Shape_i(i)(r))
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论