提交 10d8e7ee authored 作者: Cesar Laurent's avatar Cesar Laurent

Better location for broadcast check.

上级 e9e679cd
...@@ -402,7 +402,7 @@ class Scan(PureOp): ...@@ -402,7 +402,7 @@ class Scan(PureOp):
'by using dimshuffle or shape_padleft. ' 'by using dimshuffle or shape_padleft. '
) )
def _check_broadcast(v1, v2): def check_broadcast(v1, v2):
""" Checks that the broadcast pattern of v1 and v2. """ Checks that the broadcast pattern of v1 and v2.
Controls that the broadcast pattern of the variable provided as Controls that the broadcast pattern of the variable provided as
...@@ -444,7 +444,6 @@ class Scan(PureOp): ...@@ -444,7 +444,6 @@ class Scan(PureOp):
rval = var rval = var
if rval.type.dtype != as_var.type.dtype: if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype) rval = rval.astype(as_var.type.dtype)
_check_broadcast(var, as_var)
if rval.ndim == as_var.ndim: if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval) rval = as_var.type.filter_variable(rval)
else: else:
...@@ -459,6 +458,7 @@ class Scan(PureOp): ...@@ -459,6 +458,7 @@ class Scan(PureOp):
argoffset = 0 argoffset = 0
for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs), for inner_seq, outer_seq in zip(self.inner_seqs(self.inputs),
self.outer_seqs(inputs)): self.outer_seqs(inputs)):
check_broadcast(outer_seq, inner_seq)
new_inputs.append(format(outer_seq, as_var=inner_seq)) new_inputs.append(format(outer_seq, as_var=inner_seq))
argoffset += len(self.outer_seqs(inputs)) argoffset += len(self.outer_seqs(inputs))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论