提交 10d8e7ee authored 作者: Cesar Laurent's avatar Cesar Laurent

Better location for broadcast check.

上级 e9e679cd
......@@ -402,7 +402,7 @@ class Scan(PureOp):
'by using dimshuffle or shape_padleft. '
)
def _check_broadcast(v1, v2):
def check_broadcast(v1, v2):
""" Checks that the broadcast pattern of v1 and v2.
Controls that the broadcast pattern of the variable provided as
......@@ -444,7 +444,6 @@ class Scan(PureOp):
rval = var
if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype)
_check_broadcast(var, as_var)
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
......@@ -459,6 +458,7 @@ class Scan(PureOp):
argoffset = 0
for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs)):
check_broadcast(outer_seq, inner_seq)
new_inputs.append(format(outer_seq, as_var=inner_seq))
argoffset += len(self.outer_seqs(inputs))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论