Commit bbc450e8 authored by James Bergstra

Added a warning to cuda.opt and a check that we don't return the wrong type in local_gpu_sum.
Parent 46001e3e
...@@ -226,7 +226,12 @@ def local_gpu_sum(node): ...@@ -226,7 +226,12 @@ def local_gpu_sum(node):
gsum=GpuSum(reduce_mask) gsum=GpuSum(reduce_mask)
pattern=(''.join(str(i) for i in reduce_mask)) pattern=(''.join(str(i) for i in reduce_mask))
if hasattr(gsum, 'c_code_reduce_%s'%pattern): if hasattr(gsum, 'c_code_reduce_%s'%pattern):
return [host_from_gpu(gsum(gpu_from_host(x)))] rval = host_from_gpu(gsum(gpu_from_host(x)))
if rval.type == node.outputs[0].type:
return [rval]
else:
print >> sys.stderr, "WARNING: local_gpu_sum got type wrong"
return None
else: else:
# Try to make a simpler pattern based on reshaping # Try to make a simpler pattern based on reshaping
...@@ -253,9 +258,13 @@ def local_gpu_sum(node): ...@@ -253,9 +258,13 @@ def local_gpu_sum(node):
reshaped_x = x.reshape(tensor.stack(*new_in_shp)) reshaped_x = x.reshape(tensor.stack(*new_in_shp))
sum_reshaped_x = host_from_gpu(new_gsum(gpu_from_host(reshaped_x))) sum_reshaped_x = host_from_gpu(new_gsum(gpu_from_host(reshaped_x)))
unreshaped_sum = sum_reshaped_x.reshape(tensor.stack(*shape_of[node.outputs[0]])) unreshaped_sum = sum_reshaped_x.reshape(tensor.stack(*shape_of[node.outputs[0]]))
return [unreshaped_sum] if unreshaped_sum.type == node.outputs[0].type:
return [unreshaped_sum]
else:
print >> sys.stderr, "WARNING: local_gpu_sum got type wrong"
return None
raise Exception("GpuSum don't have implemented the pattern",pattern) raise Exception("GpuSum don't have implemented the pattern",pattern)
return False return False
@register_opt() @register_opt()
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment