提交 28922228 authored 作者: Frederic's avatar Frederic

In local_gpu_careduce, remove dimensions that are now detected as…

In local_gpu_careduce, remove dimensions that are now detected as rebroadcastable, but weren't during graph build.
上级 a4643310
......@@ -628,15 +628,31 @@ def local_gpu_careduce(node):
greduce = GpuCAReduce(reduce_mask, scalar_op)
if greduce.supports_c_code([gpu_from_host(x)]):
rval = host_from_gpu(greduce(gpu_from_host(x)))
if rval.type == node.outputs[0].type:
out = node.outputs[0]
if rval.type == out.type:
return [rval]
else:
print >> sys.stderr, (
"WARNING: local_gpu_careduce got type wrong",
rval.type, node.outputs[0].type,
node.inputs[0].type,
node)
return None
for b1, b2 in zip(rval.broadcastable,
out.type.broadcastable):
if b1 is True:
# It can happen that during
# optimization we discover that the
# input can be broadcasted, but didn't
# know that at graph build time.
continue
if b1 is False and b2 is True:
# We should not loose the information
# that one dimensions was
# broadcastable.
print >> sys.stderr, (
"WARNING: local_gpu_careduce got type"
" wrong",
rval.type, out.type,
node.inputs[0].type, x.type,
node)
return None
rval = patternbroadcast(rval,
out.type.broadcastable)
else:
# Try to make a simpler pattern based on reshaping
......@@ -665,21 +681,37 @@ def local_gpu_careduce(node):
if new_greduce.supports_c_code(reshaped_gpu_inputs):
reduce_reshaped_x = host_from_gpu(
new_greduce(gpu_reshaped_x))
out = node.outputs[0]
if reduce_reshaped_x.ndim != node.outputs[0].ndim:
unreshaped_reduce = reduce_reshaped_x.reshape(
tensor.stack(*shape_of[node.outputs[0]]))
if reduce_reshaped_x.ndim != out.ndim:
rval = reduce_reshaped_x.reshape(
tensor.stack(*shape_of[out]))
else:
unreshaped_reduce = reduce_reshaped_x
if unreshaped_reduce.type == node.outputs[0].type:
return [unreshaped_reduce]
rval = reduce_reshaped_x
if rval.type == out.type:
return [rval]
else:
print >> sys.stderr, (
"WARNING: local_gpu_careduce got type wrong",
unreshaped_reduce.type, node.outputs[0].type,
node.inputs[0].type,
node)
return None
for b1, b2 in zip(rval.broadcastable,
out.type.broadcastable):
if b1 is True:
# It can happen that during
# optimization we discover that the
# input can be broadcasted, but didn't
# know that at graph build time.
continue
if b1 is False and b2 is True:
# We should not loose the information
# that one dimensions was
# broadcastable.
print >> sys.stderr, (
"WARNING: local_gpu_careduce got type"
" wrong",
rval.type, out.type,
node.inputs[0].type, x.type,
node)
return None
rval = patternbroadcast(rval,
out.broadcastable)
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论