Commit 1acdafac authored by f0k

add check before patternbroadcast in local_gpualloc_memset_0

Parent 08c6d742
@@ -1699,10 +1699,13 @@ def local_gpualloc_memset_0(node):
             inp.data.size == 1 and
             (numpy.asarray(inp.data) == 0).all()):
         new_out = GpuAlloc(memset_0=True)(*node.inputs)
-        if new_out.type.broadcastable != node.outputs[0].type.broadcastable:
+        old_bcast = node.outputs[0].type.broadcastable
+        if new_out.type.broadcastable != old_bcast:
+            # check that we did not try discarding a broadcastable dimension
+            assert not any(b_old and not b_new for b_old, b_new in zip(
+                old_bcast, new_out.type.broadcastable))
             # force old broadcasting pattern; we must not change it here
-            new_out = tensor.patternbroadcast(new_out,
-                                              node.outputs[0].broadcastable)
+            new_out = tensor.patternbroadcast(new_out, old_bcast)
         return [new_out]
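The added assertion guards against the optimizer silently discarding a broadcastable dimension before `patternbroadcast` re-imposes the old pattern: a dimension marked broadcastable (`True`) in the old output type must not become non-broadcastable (`False`) in the new one. A minimal sketch of that check in isolation, with plain tuples standing in for Theano `broadcastable` patterns (the helper name is hypothetical, not part of the patch):

```python
def no_discarded_broadcastable(old_bcast, new_bcast):
    # Each entry is True where that dimension is broadcastable (length 1).
    # The check fails only if a dimension that was broadcastable (True)
    # would become non-broadcastable (False) in the new output's type.
    return not any(b_old and not b_new
                   for b_old, b_new in zip(old_bcast, new_bcast))

# Pattern unchanged: fine.
print(no_discarded_broadcastable((True, False), (True, False)))   # True
# New output gains a broadcastable dimension: also fine, since
# patternbroadcast can force it back to the old pattern.
print(no_discarded_broadcastable((False, False), (True, False)))  # True
# A formerly broadcastable dimension is lost: the assert would fire.
print(no_discarded_broadcastable((True, False), (False, False)))  # False
```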