提交 f4268206 authored 作者: abergeron's avatar abergeron

Merge pull request #3696 from adbrebs/reduce_mask_101

Fix mask 101 GPU support in GpuCAReduce
......@@ -865,20 +865,31 @@ def local_gpu_careduce(node):
new_in_shp.append(x_shape[i])
new_greduce = GpuCAReduce(new_mask, scalar_op)
reshaped_x = x.reshape(tensor.stack(new_in_shp))
gpu_reshaped_x = as_cuda_ndarray_variable(reshaped_x)
reshaped_gpu_inputs = [gpu_reshaped_x]
if new_greduce.supports_c_code(reshaped_gpu_inputs):
reduce_reshaped_x = host_from_gpu(
new_greduce(gpu_reshaped_x))
if reduce_reshaped_x.ndim != out.ndim:
rval = reduce_reshaped_x.reshape(
tensor.stack(shape_of[out]))
else:
rval = reduce_reshaped_x
else:
return
new_x = x.reshape(tensor.stack(new_in_shp))
gpu_new_x = as_cuda_ndarray_variable(new_x)
if not new_greduce.supports_c_code([gpu_new_x]):
if not new_mask == [1, 0, 1]:
return
# The reduced mask [1, 0, 1] is not supported but
# [1, 0, 1, 1] is. Therefore, we add a broadcastable
# dimension to new_x and change the mask to
# [1, 0, 1, 1].
new_x = new_x.dimshuffle(0, 1, 2, 'x')
gpu_new_x = as_cuda_ndarray_variable(new_x)
new_greduce = GpuCAReduce([1, 0, 1, 1], scalar_op)
if not new_greduce.supports_c_code([gpu_new_x]):
raise Exception('Reduction mask [1, 0, 1, 1] is'
'supposed to be supported.')
rval = host_from_gpu(
new_greduce(gpu_new_x))
# Restore the expected shape of the output
if rval.ndim != out.ndim:
rval = rval.reshape(
tensor.stack(shape_of[out]))
if rval.type == out.type:
return [rval]
else:
......
......@@ -113,7 +113,7 @@ def test_careduce():
((4100, 4, 3), [2]), ((5, 4100, 3), [2]), ((5, 4, 4100), [2]), # 001
((4100, 4, 3), [0, 1]), ((5, 4100, 3), [0, 1]), ((5, 4, 4100), [0, 1]), # 110
((4100, 4, 3), [1, 2]), ((5, 4100, 3), [1, 2]), ((5, 4, 4100), [1, 2]), # 011
#((4100,4,3),[0,2]),((5,4100,3),[0,2]),((5,4,4100),[0,2]),#101 ##not implemented
((4100,4,3),[0,2]),((5,4100,3),[0,2]),((5,4,4100),[0,2]),
((4100, 4, 3), [0, 1, 2]), ((5, 4100, 3), [0, 1, 2]), ((5, 4, 4100), [0, 1, 2]), # 111
((65, 4, 3), [0, 1, 2]), ((5, 65, 3), [0, 1, 2]), ((5, 4, 65), [0, 1, 2]), # 111
......@@ -133,13 +133,13 @@ def test_careduce():
# reduce over 2d
((4100, 4, 3, 2), [1, 2]), ((4, 4100, 3, 2), [1, 2]), ((4, 3, 4100, 2), [1, 2]), ((4, 3, 2, 4100), [1, 2]), # 0110
# ((4100,4,3,2),[0,3]),((4,4100,3,2),[0,3]),((4,3,4100,2),[0,3]),((4,3,2,4100),[0,3]),#1001 need 101
((4100,4,3,2),[0,3]),((4,4100,3,2),[0,3]),((4,3,4100,2),[0,3]),((4,3,2,4100),[0,3]),#1001
# ((4100,4,3,2),[0,2]),((4,4100,3,2),[0,2]),((4,3,4100,2),[0,2]),((4,3,2,4100),[0,2]),#1010 not implemented
((4100, 4, 3, 2), [0, 1]), ((4, 4100, 3, 2), [0, 1]), ((4, 3, 4100, 2), [0, 1]), ((4, 3, 2, 4100), [0, 1]), # 1100
# reduce over 3d
# 3d not tested: 1101, 1110, 1111
# ((4100,4,3,2),[0,1,3]),((4,4100,3,2),[0,1,3]),((4,3,4100,2),[0,1,3]),((4,3,2,4100),[0,1,3]),#1101 need 101
((4100,4,3,2),[0,1,3]),((4,4100,3,2),[0,1,3]),((4,3,4100,2),[0,1,3]),((4,3,2,4100),[0,1,3]),#1101
((4100, 4, 3, 2), [0, 1, 2]), ((4, 4100, 3, 2), [0, 1, 2]), ((4, 3, 4100, 2), [0, 1, 2]), ((4, 3, 2, 4100), [0, 1, 2]), # 1110
# reduce over 4d
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论