提交 a482cc6f authored 作者: Frederic's avatar Frederic

Port gpu reduce opt to make reshape the input if needed.

上级 85faa716
...@@ -249,9 +249,52 @@ def local_gpua_incsubtensor(node): ...@@ -249,9 +249,52 @@ def local_gpua_incsubtensor(node):
def local_gpua_careduce(node): def local_gpua_careduce(node):
if (isinstance(node.op.scalar_op, scalar.basic.Add) or if (isinstance(node.op.scalar_op, scalar.basic.Add) or
isinstance(node.op.scalar_op, scalar.basic.Mul)): isinstance(node.op.scalar_op, scalar.basic.Mul)):
return GpuCAReduce(node.op.scalar_op, axis=node.op.axis,
dtype=getattr(node.op, 'dtype', None), greduce = GpuCAReduce(node.op.scalar_op, axis=node.op.axis)
acc_dtype=getattr(node.op, 'acc_dtype', None)) if greduce.supports_c_code([gpu_from_host(x)]):
return greduce
else:
# Try to make a simpler pattern based on reshaping
# The principle is that if two adjacent dimensions have
# the same value in the reduce_mask, then we can reshape
# to make them a single dimension, do the reduction, and
# then reshape to get them back.
if node.op.axis is None:
reduce_mask = [1] * x.type.ndim
else:
reduce_mask = [0] * x.type.ndim
for a in node.op.axis:
assert reduce_mask[a] == 0
reduce_mask[a] = 1
shape_of = node.fgraph.shape_feature.shape_of
x_shape = shape_of[x]
new_in_shp = [x_shape[0]]
new_mask = [reduce_mask[0]]
for i in xrange(1, x.type.ndim):
if reduce_mask[i] == reduce_mask[i - 1]:
new_in_shp[-1] *= x_shape[i]
else:
new_mask.append(reduce_mask[i])
new_in_shp.append(x_shape[i])
new_greduce = GpuCAReduce(new_mask, scalar_op)
reshaped_x = x.reshape(tensor.stack(*new_in_shp))
gpu_reshaped_x = gpu_from_host(reshaped_x)
reshaped_gpu_inputs = [gpu_reshaped_x]
if new_greduce.supports_c_code(reshaped_gpu_inputs):
reduce_reshaped_x = host_from_gpu(
new_greduce(gpu_reshaped_x))
if reduce_reshaped_x.ndim != node.outputs[0].ndim:
unreshaped_reduce = reduce_reshaped_x.reshape(
tensor.stack(*shape_of[node.outputs[0]]))
else:
unreshaped_reduce = reduce_reshaped_x
return [unreshaped_reduce]
@register_opt() @register_opt()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论