Commit 1acdafac authored by f0k

add check before patternbroadcast in local_gpualloc_memset_0

Parent 08c6d742
...@@ -1699,10 +1699,13 @@ def local_gpualloc_memset_0(node):
             inp.data.size == 1 and
             (numpy.asarray(inp.data) == 0).all()):
         new_out = GpuAlloc(memset_0=True)(*node.inputs)
-        if new_out.type.broadcastable != node.outputs[0].type.broadcastable:
+        old_bcast = node.outputs[0].type.broadcastable
+        if new_out.type.broadcastable != old_bcast:
+            # check that we did not try discarding a broadcastable dimension
+            assert not any(b_old and not b_new for b_old, b_new in zip(
+                old_bcast, new_out.type.broadcastable))
             # force old broadcasting pattern; we must not change it here
-            new_out = tensor.patternbroadcast(new_out,
-                node.outputs[0].broadcastable)
+            new_out = tensor.patternbroadcast(new_out, old_bcast)
         return [new_out]
...
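The new assertion guards against silently discarding a broadcastable dimension when forcing the old broadcast pattern back onto the rewritten output. A minimal standalone sketch of that check, outside of Theano (the helper name and example tuples are hypothetical, for illustration only):

```python
def discards_broadcastable_dim(old_bcast, new_bcast):
    """Return True if any dimension that was broadcastable (True) in
    old_bcast became non-broadcastable (False) in new_bcast -- the
    condition the commit's assert rules out before patternbroadcast."""
    return any(b_old and not b_new
               for b_old, b_new in zip(old_bcast, new_bcast))

# A dimension may safely *gain* broadcastability or stay the same...
assert not discards_broadcastable_dim((False, False), (False, True))
# ...but losing broadcastability is exactly what the new assert catches:
assert discards_broadcastable_dim((True, False), (False, False))
```

Re-applying the old pattern is then safe, because `patternbroadcast` only needs to mask broadcastable flags that the new node added, never to invent one the new node lacks.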