提交 f4268206 作者: abergeron

Merge pull request #3696 from adbrebs/reduce_mask_101

Fix mask 101 GPU support in GpuCAReduce
...@@ -865,20 +865,31 @@ def local_gpu_careduce(node): ...@@ -865,20 +865,31 @@ def local_gpu_careduce(node):
new_in_shp.append(x_shape[i]) new_in_shp.append(x_shape[i])
new_greduce = GpuCAReduce(new_mask, scalar_op) new_greduce = GpuCAReduce(new_mask, scalar_op)
reshaped_x = x.reshape(tensor.stack(new_in_shp)) new_x = x.reshape(tensor.stack(new_in_shp))
gpu_reshaped_x = as_cuda_ndarray_variable(reshaped_x) gpu_new_x = as_cuda_ndarray_variable(new_x)
reshaped_gpu_inputs = [gpu_reshaped_x] if not new_greduce.supports_c_code([gpu_new_x]):
if new_greduce.supports_c_code(reshaped_gpu_inputs): if not new_mask == [1, 0, 1]:
reduce_reshaped_x = host_from_gpu( return
new_greduce(gpu_reshaped_x)) # The reduced mask [1, 0, 1] is not supported but
# [1, 0, 1, 1] is. Therefore, we add a broadcastable
if reduce_reshaped_x.ndim != out.ndim: # dimension to new_x and change the mask to
rval = reduce_reshaped_x.reshape( # [1, 0, 1, 1].
tensor.stack(shape_of[out])) new_x = new_x.dimshuffle(0, 1, 2, 'x')
else: gpu_new_x = as_cuda_ndarray_variable(new_x)
rval = reduce_reshaped_x
else: new_greduce = GpuCAReduce([1, 0, 1, 1], scalar_op)
return if not new_greduce.supports_c_code([gpu_new_x]):
raise Exception('Reduction mask [1, 0, 1, 1] is'
'supposed to be supported.')
rval = host_from_gpu(
new_greduce(gpu_new_x))
# Restore the expected shape of the output
if rval.ndim != out.ndim:
rval = rval.reshape(
tensor.stack(shape_of[out]))
if rval.type == out.type: if rval.type == out.type:
return [rval] return [rval]
else: else:
......
...@@ -113,7 +113,7 @@ def test_careduce(): ...@@ -113,7 +113,7 @@ def test_careduce():
((4100, 4, 3), [2]), ((5, 4100, 3), [2]), ((5, 4, 4100), [2]), # 001 ((4100, 4, 3), [2]), ((5, 4100, 3), [2]), ((5, 4, 4100), [2]), # 001
((4100, 4, 3), [0, 1]), ((5, 4100, 3), [0, 1]), ((5, 4, 4100), [0, 1]), # 110 ((4100, 4, 3), [0, 1]), ((5, 4100, 3), [0, 1]), ((5, 4, 4100), [0, 1]), # 110
((4100, 4, 3), [1, 2]), ((5, 4100, 3), [1, 2]), ((5, 4, 4100), [1, 2]), # 011 ((4100, 4, 3), [1, 2]), ((5, 4100, 3), [1, 2]), ((5, 4, 4100), [1, 2]), # 011
#((4100,4,3),[0,2]),((5,4100,3),[0,2]),((5,4,4100),[0,2]),#101 ##not implemented ((4100,4,3),[0,2]),((5,4100,3),[0,2]),((5,4,4100),[0,2]),
((4100, 4, 3), [0, 1, 2]), ((5, 4100, 3), [0, 1, 2]), ((5, 4, 4100), [0, 1, 2]), # 111 ((4100, 4, 3), [0, 1, 2]), ((5, 4100, 3), [0, 1, 2]), ((5, 4, 4100), [0, 1, 2]), # 111
((65, 4, 3), [0, 1, 2]), ((5, 65, 3), [0, 1, 2]), ((5, 4, 65), [0, 1, 2]), # 111 ((65, 4, 3), [0, 1, 2]), ((5, 65, 3), [0, 1, 2]), ((5, 4, 65), [0, 1, 2]), # 111
...@@ -133,13 +133,13 @@ def test_careduce(): ...@@ -133,13 +133,13 @@ def test_careduce():
# reduce over 2d # reduce over 2d
((4100, 4, 3, 2), [1, 2]), ((4, 4100, 3, 2), [1, 2]), ((4, 3, 4100, 2), [1, 2]), ((4, 3, 2, 4100), [1, 2]), # 0110 ((4100, 4, 3, 2), [1, 2]), ((4, 4100, 3, 2), [1, 2]), ((4, 3, 4100, 2), [1, 2]), ((4, 3, 2, 4100), [1, 2]), # 0110
# ((4100,4,3,2),[0,3]),((4,4100,3,2),[0,3]),((4,3,4100,2),[0,3]),((4,3,2,4100),[0,3]),#1001 need 101 ((4100,4,3,2),[0,3]),((4,4100,3,2),[0,3]),((4,3,4100,2),[0,3]),((4,3,2,4100),[0,3]),#1001
# ((4100,4,3,2),[0,2]),((4,4100,3,2),[0,2]),((4,3,4100,2),[0,2]),((4,3,2,4100),[0,2]),#1010 not implemented # ((4100,4,3,2),[0,2]),((4,4100,3,2),[0,2]),((4,3,4100,2),[0,2]),((4,3,2,4100),[0,2]),#1010 not implemented
((4100, 4, 3, 2), [0, 1]), ((4, 4100, 3, 2), [0, 1]), ((4, 3, 4100, 2), [0, 1]), ((4, 3, 2, 4100), [0, 1]), # 1100 ((4100, 4, 3, 2), [0, 1]), ((4, 4100, 3, 2), [0, 1]), ((4, 3, 4100, 2), [0, 1]), ((4, 3, 2, 4100), [0, 1]), # 1100
# reduce over 3d # reduce over 3d
# 3d not tested: 1101, 1110, 1111 # 3d not tested: 1101, 1110, 1111
# ((4100,4,3,2),[0,1,3]),((4,4100,3,2),[0,1,3]),((4,3,4100,2),[0,1,3]),((4,3,2,4100),[0,1,3]),#1101 need 101 ((4100,4,3,2),[0,1,3]),((4,4100,3,2),[0,1,3]),((4,3,4100,2),[0,1,3]),((4,3,2,4100),[0,1,3]),#1101
((4100, 4, 3, 2), [0, 1, 2]), ((4, 4100, 3, 2), [0, 1, 2]), ((4, 3, 4100, 2), [0, 1, 2]), ((4, 3, 2, 4100), [0, 1, 2]), # 1110 ((4100, 4, 3, 2), [0, 1, 2]), ((4, 4100, 3, 2), [0, 1, 2]), ((4, 3, 4100, 2), [0, 1, 2]), ((4, 3, 2, 4100), [0, 1, 2]), # 1110
# reduce over 4d # reduce over 4d
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论