提交 807c0c97 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use static shape information instead of broadcastable in Scan

上级 e59cef18
...@@ -200,7 +200,7 @@ def copy_var_format(var, as_var): ...@@ -200,7 +200,7 @@ def copy_var_format(var, as_var):
rval = as_var.type.filter_variable(rval) rval = as_var.type.filter_variable(rval)
else: else:
tmp = as_var.type.clone( tmp = as_var.type.clone(
shape=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable)) shape=(tuple(var.type.shape[:1]) + tuple(as_var.type.shape))
) )
rval = tmp.filter_variable(rval) rval = tmp.filter_variable(rval)
return rval return rval
...@@ -805,7 +805,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -805,7 +805,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# output sequence # output sequence
o = outputs[idx] o = outputs[idx]
self.output_types.append( self.output_types.append(
typeConstructor((False,) + o.type.broadcastable, o.type.dtype) # TODO: What can we actually say about the shape of this
# added dimension?
typeConstructor((None,) + o.type.shape, o.type.dtype)
) )
idx += len(info.mit_mot_out_slices[jdx]) idx += len(info.mit_mot_out_slices[jdx])
...@@ -816,7 +818,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -816,7 +818,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for o in outputs[idx:end]: for o in outputs[idx:end]:
self.output_types.append( self.output_types.append(
typeConstructor((False,) + o.type.broadcastable, o.type.dtype) # TODO: What can we actually say about the shape of this
# added dimension?
typeConstructor((None,) + o.type.shape, o.type.dtype)
) )
# shared outputs + possibly the ending condition # shared outputs + possibly the ending condition
...@@ -2320,8 +2324,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2320,8 +2324,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# equivalent (if False). Here, we only need the variable. # equivalent (if False). Here, we only need the variable.
v_shp_i = validator.check(shp_i) v_shp_i = validator.check(shp_i)
if v_shp_i is None: if v_shp_i is None:
if hasattr(r, "broadcastable") and r.broadcastable[i]: if r.type.shape[i] is not None:
shp.append(1) shp.append(r.type.shape[i])
else: else:
shp.append(Shape_i(i)(r)) shp.append(Shape_i(i)(r))
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论