提交 8d2f4abe authored 作者: Frederic Bastien's avatar Frederic Bastien

fix to allow more gpu sum to be done on the gpu.

上级 ba0db2f3
......@@ -259,7 +259,11 @@ def local_gpu_sum(node):
if hasattr(new_gsum, 'c_code_reduce_%s'%pattern):
reshaped_x = x.reshape(tensor.stack(*new_in_shp))
sum_reshaped_x = host_from_gpu(new_gsum(gpu_from_host(reshaped_x)))
unreshaped_sum = sum_reshaped_x.reshape(tensor.stack(*shape_of[node.outputs[0]]))
if sum_reshaped_x.ndim != node.outputs[0].ndim:
unreshaped_sum = sum_reshaped_x.reshape(tensor.stack(*shape_of[node.outputs[0]]))
else:
unreshaped_sum = sum_reshaped_x
if unreshaped_sum.type == node.outputs[0].type:
return [unreshaped_sum]
else:
......
......@@ -36,6 +36,14 @@ def tes_use():
def test_sum():
"""
test sum pattern 1, 11, 10, 01, 100, 110, 011, 001, 111, 0011, 0101, 0111, 1011, 1111
test sum pattern implemented with reshape:
1000, 0100, 0010, 0001, 11111
others implemented by reshape that are not tested
0011,0101,0110,1001,1010,1100
1110,1101,1011
TODO: test with broadcast
"""
for shape, pattern in [((100,3,1300),[1]),
......@@ -66,6 +74,15 @@ def test_sum():
((4100,4,3,2),[0,2,3]),((4,4100,3,2),[0,2,3]),((4,3,4100,2),[0,2,3]),#((4,3,2,4100),[0,2,3]),#1011
((4100,4,3,2),[1,2,3]),((4,4100,3,2),[1,2,3]),((4,3,4100,2),[1,2,3]),((4,3,2,4100),[1,2,3]),#0111
((4100,2,3,4),[0,1,2,3]),((2,4100,3,4),[0,1,2,3]),((2,3,4100,4),[0,1,2,3]),((2,3,4,4100),[0,1,2,3]),#1111
#test pattern implemented by reshape
((4100,4,3,2),[0]),((4,4100,3,2),[0]),((4,3,4100,2),[0]),((4,3,2,4100),[0]),#1000
((4100,4,3,2),[1]),((4,4100,3,2),[1]),((4,3,4100,2),[1]),((4,3,2,4100),[1]),#0100
((4100,4,3,2),[2]),((4,4100,3,2),[2]),((4,3,4100,2),[2]),((4,3,2,4100),[2]),#0010
((4100,4,3,2),[3]),((4,4100,3,2),[3]),((4,3,4100,2),[3]),((4,3,2,4100),[3]),#0001
((1100,2,3,4,5),[0,1,2,3,4]),((2,1100,3,4,5),[0,1,2,3,4]),((2,3,1100,4,5),[0,1,2,3,4]),((2,3,4,1100,5),[0,1,2,3,4]),((2,3,4,5,1100),[0,1,2,3,4]),#11111
]:
a = tensor.TensorType('float32',(False,)*len(shape))()
b = T.Sum(pattern)(a)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论